# blueprint.py — Quart API blueprint (original file name misspelled "buleprint.py")
import asyncio
import traceback
import uuid
from typing import Dict, Any

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.api import get_basic_embedding
from applications.api import get_img_embedding
from applications.async_task import ChunkEmbeddingTask, DeleteTask
from applications.search import HybridSearch
from applications.utils.chat import ChatClassifier
from applications.utils.mysql import Dataset, Contents, ContentChunks

# Blueprint for every /api route; CORS is opened to all origins.
server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")
  21. @server_bp.route("/embed", methods=["POST"])
  22. async def embed():
  23. body = await request.get_json()
  24. text = body.get("text")
  25. model_name = body.get("model", DEFAULT_MODEL)
  26. if not LOCAL_MODEL_CONFIG.get(model_name):
  27. return jsonify({"error": "error model"})
  28. embedding = await get_basic_embedding(text, model_name)
  29. return jsonify({"embedding": embedding})
  30. @server_bp.route("/img_embed", methods=["POST"])
  31. async def img_embed():
  32. body = await request.get_json()
  33. url_list = body.get("url_list")
  34. if not url_list:
  35. return jsonify({"error": "error url_list"})
  36. embedding = await get_img_embedding(url_list)
  37. return jsonify(embedding)
  38. @server_bp.route("/delete", methods=["POST"])
  39. async def delete():
  40. body = await request.get_json()
  41. level = body.get("level")
  42. params = body.get("params")
  43. if not level or not params:
  44. return jsonify({"error": "error level or params"})
  45. resource = get_resource_manager()
  46. delete_task = DeleteTask(resource)
  47. response = await delete_task.deal(level, params)
  48. return jsonify(response)
  49. @server_bp.route("/chunk", methods=["POST"])
  50. async def chunk():
  51. body = await request.get_json()
  52. text = body.get("text", "")
  53. ori_doc_id = body.get("doc_id")
  54. text = text.strip()
  55. if not text:
  56. return jsonify({"error": "error text"})
  57. resource = get_resource_manager()
  58. # generate doc id
  59. if ori_doc_id:
  60. body["re_chunk"] = True
  61. doc_id = ori_doc_id
  62. else:
  63. doc_id = f"doc-{uuid.uuid4()}"
  64. chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
  65. doc_id = await chunk_task.deal(body)
  66. return jsonify({"doc_id": doc_id})
  67. @server_bp.route("/search", methods=["POST"])
  68. async def search():
  69. """
  70. filters: Dict[str, Any], # 条件过滤
  71. query_vec: List[float], # query 的向量
  72. anns_field: str = "vector_text", # query指定的向量空间
  73. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  74. query_text: str = None, #是否通过 topic 倒排
  75. _source=False, # 是否返回元数据
  76. es_size: int = 10000, #es 第一层过滤数量
  77. sort_by: str = None, # 排序
  78. milvus_size: int = 10 # milvus粗排返回数量
  79. :return:
  80. """
  81. body = await request.get_json()
  82. # 解析数据
  83. search_type: str = body.get("search_type")
  84. filters: Dict[str, Any] = body.get("filters", {})
  85. anns_field: str = body.get("anns_field", "vector_text")
  86. search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
  87. query_text: str = body.get("query_text")
  88. _source: bool = body.get("_source", False)
  89. es_size: int = body.get("es_size", 10000)
  90. sort_by: str = body.get("sort_by")
  91. milvus_size: int = body.get("milvus", 20)
  92. limit: int = body.get("limit", 10)
  93. if not query_text:
  94. return jsonify({"error": "error query_text"})
  95. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  96. resource = get_resource_manager()
  97. search_engine = HybridSearch(
  98. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  99. )
  100. try:
  101. match search_type:
  102. case "base":
  103. response = await search_engine.base_vector_search(
  104. query_vec=query_vector,
  105. anns_field=anns_field,
  106. search_params=search_params,
  107. limit=limit,
  108. )
  109. return jsonify(response), 200
  110. case "hybrid":
  111. response = await search_engine.hybrid_search(
  112. filters=filters,
  113. query_vec=query_vector,
  114. anns_field=anns_field,
  115. search_params=search_params,
  116. es_size=es_size,
  117. sort_by=sort_by,
  118. milvus_size=milvus_size,
  119. )
  120. return jsonify(response), 200
  121. case "strategy":
  122. return jsonify({"error": "strategy not implemented"}), 405
  123. case _:
  124. return jsonify({"error": "error search_type"}), 200
  125. except Exception as e:
  126. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  127. @server_bp.route("/dataset/list", methods=["GET"])
  128. async def dataset_list():
  129. resource = get_resource_manager()
  130. datasets = await Dataset(resource.mysql_client).select_dataset()
  131. # 创建所有任务
  132. tasks = [
  133. Contents(resource.mysql_client).select_count(dataset["id"])
  134. for dataset in datasets
  135. ]
  136. counts = await asyncio.gather(*tasks)
  137. # 组装数据
  138. data_list = [
  139. {
  140. "dataset_id": dataset["id"],
  141. "name": dataset["name"],
  142. "count": count,
  143. "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
  144. }
  145. for dataset, count in zip(datasets, counts)
  146. ]
  147. return jsonify({"status_code": 200, "detail": "success", "data": data_list})
  148. @server_bp.route("/dataset/add", methods=["POST"])
  149. async def add_dataset():
  150. resource = get_resource_manager()
  151. dataset = Dataset(resource.mysql_client)
  152. # 从请求体里取参数
  153. body = await request.get_json()
  154. name = body.get("name")
  155. if not name:
  156. return jsonify({"status_code": 400, "detail": "name is required"})
  157. # 执行新增
  158. await dataset.add_dataset(name)
  159. return jsonify({"status_code": 200, "detail": "success"})
  160. @server_bp.route("/content/get", methods=["GET"])
  161. async def get_content():
  162. resource = get_resource_manager()
  163. contents = Contents(resource.mysql_client)
  164. # 获取请求参数
  165. doc_id = request.args.get("docId")
  166. if not doc_id:
  167. return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
  168. # 查询内容
  169. rows = await contents.select_content_by_doc_id(doc_id)
  170. if not rows:
  171. return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
  172. row = rows[0]
  173. return jsonify(
  174. {
  175. "status_code": 200,
  176. "detail": "success",
  177. "data": {
  178. "title": row.get("title", ""),
  179. "text": row.get("text", ""),
  180. "doc_id": row.get("doc_id", ""),
  181. },
  182. }
  183. )
  184. @server_bp.route("/content/list", methods=["GET"])
  185. async def content_list():
  186. resource = get_resource_manager()
  187. contents = Contents(resource.mysql_client)
  188. # 从 URL 查询参数获取分页和过滤参数
  189. page_num = int(request.args.get("page", 1))
  190. page_size = int(request.args.get("pageSize", 10))
  191. dataset_id = request.args.get("datasetId")
  192. doc_status = int(request.args.get("doc_status", 1))
  193. # order_by 可以用 JSON 字符串传递
  194. import json
  195. order_by_str = request.args.get("order_by", '{"id":"desc"}')
  196. try:
  197. order_by = json.loads(order_by_str)
  198. except Exception:
  199. order_by = {"id": "desc"}
  200. # 调用 select_contents,获取分页字典
  201. result = await contents.select_contents(
  202. page_num=page_num,
  203. page_size=page_size,
  204. dataset_id=dataset_id,
  205. doc_status=doc_status,
  206. order_by=order_by,
  207. )
  208. # 格式化 entities,只保留必要字段
  209. entities = [
  210. {
  211. "doc_id": row["doc_id"],
  212. "title": row.get("title") or "",
  213. "text": row.get("text") or "",
  214. }
  215. for row in result["entities"]
  216. ]
  217. return jsonify(
  218. {
  219. "status_code": 200,
  220. "detail": "success",
  221. "data": {
  222. "entities": entities,
  223. "total_count": result["total_count"],
  224. "page": result["page"],
  225. "page_size": result["page_size"],
  226. "total_pages": result["total_pages"],
  227. },
  228. }
  229. )
  230. async def query_search(
  231. query_text,
  232. filters=None,
  233. search_type="",
  234. anns_field="vector_text",
  235. search_params=BASE_MILVUS_SEARCH_PARAMS,
  236. _source=False,
  237. es_size=10000,
  238. sort_by=None,
  239. milvus_size=20,
  240. limit=10,
  241. ):
  242. if filters is None:
  243. filters = {}
  244. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  245. resource = get_resource_manager()
  246. search_engine = HybridSearch(
  247. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  248. )
  249. try:
  250. match search_type:
  251. case "base":
  252. response = await search_engine.base_vector_search(
  253. query_vec=query_vector,
  254. anns_field=anns_field,
  255. search_params=search_params,
  256. limit=limit,
  257. )
  258. return response
  259. case "hybrid":
  260. response = await search_engine.hybrid_search(
  261. filters=filters,
  262. query_vec=query_vector,
  263. anns_field=anns_field,
  264. search_params=search_params,
  265. es_size=es_size,
  266. sort_by=sort_by,
  267. milvus_size=milvus_size,
  268. )
  269. return response
  270. case "strategy":
  271. return None
  272. case _:
  273. return None
  274. except Exception as e:
  275. return None
  276. @server_bp.route("/query", methods=["GET"])
  277. async def query():
  278. query_text = request.args.get("query")
  279. dataset_ids = request.args.get("datasetIds").split(",")
  280. search_type = request.args.get("search_type", "hybrid")
  281. query_results = await query_search(
  282. query_text=query_text,
  283. filters={"dataset_id": dataset_ids},
  284. search_type=search_type,
  285. )
  286. resource = get_resource_manager()
  287. content_chunk_mapper = ContentChunks(resource.mysql_client)
  288. dataset_mapper = Dataset(resource.mysql_client)
  289. res = []
  290. for result in query_results["results"]:
  291. content_chunks = await content_chunk_mapper.select_chunk_content(
  292. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  293. )
  294. if not content_chunks:
  295. return jsonify(
  296. {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  297. )
  298. content_chunk = content_chunks[0]
  299. datasets = await dataset_mapper.select_dataset_by_id(
  300. content_chunk["dataset_id"]
  301. )
  302. if not datasets:
  303. return jsonify(
  304. {"status_code": 500, "detail": "dataset not found", "data": {}}
  305. )
  306. dataset = datasets[0]
  307. dataset_name = None
  308. if dataset:
  309. dataset_name = dataset["name"]
  310. res.append(
  311. {
  312. "docId": content_chunk["doc_id"],
  313. "content": content_chunk["text"],
  314. "contentSummary": content_chunk["summary"],
  315. "score": result["score"],
  316. "datasetName": dataset_name,
  317. }
  318. )
  319. data = {"results": res}
  320. return jsonify({"status_code": 200, "detail": "success", "data": data})
  321. @server_bp.route("/chat", methods=["GET"])
  322. async def chat():
  323. query_text = request.args.get("query")
  324. dataset_ids = request.args.get("datasetIds").split(",")
  325. search_type = request.args.get("search_type", "hybrid")
  326. query_results = await query_search(
  327. query_text=query_text,
  328. filters={"dataset_id": dataset_ids},
  329. search_type=search_type,
  330. )
  331. resource = get_resource_manager()
  332. content_chunk_mapper = ContentChunks(resource.mysql_client)
  333. dataset_mapper = Dataset(resource.mysql_client)
  334. res = []
  335. for result in query_results["results"]:
  336. content_chunks = await content_chunk_mapper.select_chunk_content(
  337. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  338. )
  339. if not content_chunks:
  340. return jsonify(
  341. {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  342. )
  343. content_chunk = content_chunks[0]
  344. datasets = await dataset_mapper.select_dataset_by_id(
  345. content_chunk["dataset_id"]
  346. )
  347. if not datasets:
  348. return jsonify(
  349. {"status_code": 500, "detail": "dataset not found", "data": {}}
  350. )
  351. dataset = datasets[0]
  352. dataset_name = None
  353. if dataset:
  354. dataset_name = dataset["name"]
  355. res.append(
  356. {
  357. "docId": content_chunk["doc_id"],
  358. "content": content_chunk["text"],
  359. "contentSummary": content_chunk["summary"],
  360. "score": result["score"],
  361. "datasetName": dataset_name,
  362. }
  363. )
  364. chat_classifier = ChatClassifier()
  365. chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
  366. data = {"results": res, "chat_res": chat_res}
  367. return jsonify({"status_code": 200, "detail": "success", "data": data})