# buleprint.py
import asyncio
import json
import traceback
import uuid
from typing import Dict, Any

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.api import get_basic_embedding
from applications.api import get_img_embedding
from applications.async_task import ChunkEmbeddingTask, DeleteTask
from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.search import HybridSearch
from applications.utils.chat import ChatClassifier
from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult

server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")


@server_bp.route("/embed", methods=["POST"])
async def embed():
    body = await request.get_json()
    text = body.get("text")
    model_name = body.get("model", DEFAULT_MODEL)
    if not LOCAL_MODEL_CONFIG.get(model_name):
        return jsonify({"error": "error model"})
    embedding = await get_basic_embedding(text, model_name)
    return jsonify({"embedding": embedding})
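
# Illustrative request for /api/embed (values are placeholders; "model" must be a
# key present in LOCAL_MODEL_CONFIG, otherwise the handler returns an error):
#   POST /api/embed  {"text": "hello world", "model": "<model name>"}
#   -> {"embedding": [...]}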


@server_bp.route("/img_embed", methods=["POST"])
async def img_embed():
    body = await request.get_json()
    url_list = body.get("url_list")
    if not url_list:
        return jsonify({"error": "error url_list"})
    embedding = await get_img_embedding(url_list)
    return jsonify(embedding)
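
# Illustrative request for /api/img_embed (the URL is a placeholder):
#   POST /api/img_embed  {"url_list": ["https://example.com/image-1.jpg"]}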


@server_bp.route("/delete", methods=["POST"])
async def delete():
    body = await request.get_json()
    level = body.get("level")
    params = body.get("params")
    if not level or not params:
        return jsonify({"error": "error level or params"})
    resource = get_resource_manager()
    delete_task = DeleteTask(resource)
    response = await delete_task.deal(level, params)
    return jsonify(response)
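
# Illustrative request for /api/delete; the exact shape of "level" and "params"
# is defined by DeleteTask.deal and is assumed here, not taken from this file:
#   POST /api/delete  {"level": "<level>", "params": {...}}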


@server_bp.route("/chunk", methods=["POST"])
async def chunk():
    body = await request.get_json()
    text = body.get("text", "")
    ori_doc_id = body.get("doc_id")
    text = text.strip()
    if not text:
        return jsonify({"error": "error text"})
    resource = get_resource_manager()
    # generate doc id
    if ori_doc_id:
        body["re_chunk"] = True
        doc_id = ori_doc_id
    else:
        doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})
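
# Illustrative request for /api/chunk (values are placeholders). Passing an existing
# "doc_id" re-chunks that document (the handler sets re_chunk=True); omitting it
# creates a fresh "doc-<uuid4>" id.
#   POST /api/chunk  {"text": "raw document text", "doc_id": "doc-..."}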


@server_bp.route("/search", methods=["POST"])
async def search():
    """
    filters: Dict[str, Any],                         # filter conditions
    query_vec: List[float],                          # embedding vector of the query
    anns_field: str = "vector_text",                 # vector field to search against
    search_params: Optional[Dict[str, Any]] = None,  # vector distance / search parameters
    query_text: str = None,                          # whether to retrieve via the topic inverted index
    _source=False,                                   # whether to return source metadata
    es_size: int = 10000,                            # number of results from the first-stage ES filter
    sort_by: str = None,                             # sort order
    milvus_size: int = 10                            # number of candidates returned by Milvus coarse ranking
    :return:
    """
    body = await request.get_json()
    # Parse the request payload
    search_type: str = body.get("search_type")
    filters: Dict[str, Any] = body.get("filters", {})
    anns_field: str = body.get("anns_field", "vector_text")
    search_params: Dict[str, Any] = body.get(
        "search_params", BASE_MILVUS_SEARCH_PARAMS
    )
    query_text: str = body.get("query_text")
    _source: bool = body.get("_source", False)
    es_size: int = body.get("es_size", 10000)
    sort_by: str = body.get("sort_by")
    milvus_size: int = body.get("milvus_size", 20)
    limit: int = body.get("limit", 10)
    if not query_text:
        return jsonify({"error": "error query_text"})
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client, es_pool=resource.es_client
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return jsonify(response), 200
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return jsonify(response), 200
            case "strategy":
                return jsonify({"error": "strategy not implemented"}), 405
            case _:
                return jsonify({"error": "error search_type"}), 400
    except Exception as e:
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500


@server_bp.route("/dataset/list", methods=["GET"])
async def dataset_list():
    resource = get_resource_manager()
    datasets = await Dataset(resource.mysql_client).select_dataset()
    # Create one count query per dataset and run them concurrently
    tasks = [
        Contents(resource.mysql_client).select_count(dataset["id"])
        for dataset in datasets
    ]
    counts = await asyncio.gather(*tasks)
    # Assemble the response payload
    data_list = [
        {
            "dataset_id": dataset["id"],
            "name": dataset["name"],
            "count": count,
            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
        }
        for dataset, count in zip(datasets, counts)
    ]
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})


@server_bp.route("/dataset/add", methods=["POST"])
async def add_dataset():
    resource = get_resource_manager()
    dataset = Dataset(resource.mysql_client)
    # Read parameters from the request body
    body = await request.get_json()
    name = body.get("name")
    if not name:
        return jsonify({"status_code": 400, "detail": "name is required"})
    # Insert the new dataset
    await dataset.add_dataset(name)
    return jsonify({"status_code": 200, "detail": "success"})


@server_bp.route("/content/get", methods=["GET"])
async def get_content():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # Read query-string parameters
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
    # Look up the content
    rows = await contents.select_content_by_doc_id(doc_id)
    if not rows:
        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
    row = rows[0]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "title": row.get("title", ""),
                "text": row.get("text", ""),
                "doc_id": row.get("doc_id", ""),
            },
        }
    )
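
# Illustrative call (the doc id is a placeholder):
#   GET /api/content/get?docId=doc-123e4567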


@server_bp.route("/content/list", methods=["GET"])
async def content_list():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # Pagination and filter parameters come from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    dataset_id = request.args.get("datasetId")
    doc_status = int(request.args.get("doc_status", 1))
    # order_by can be passed as a JSON string
    order_by_str = request.args.get("order_by", '{"id":"desc"}')
    try:
        order_by = json.loads(order_by_str)
    except Exception:
        order_by = {"id": "desc"}
    # Call select_contents to get a paginated result dict
    result = await contents.select_contents(
        page_num=page_num,
        page_size=page_size,
        dataset_id=dataset_id,
        doc_status=doc_status,
        order_by=order_by,
    )
    # Format entities, keeping only the required fields
    entities = [
        {
            "doc_id": row["doc_id"],
            "title": row.get("title") or "",
            "text": row.get("text") or "",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )
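
# Illustrative call; order_by is passed as a JSON string (values are placeholders):
#   GET /api/content/list?page=1&pageSize=10&datasetId=1&doc_status=1&order_by={"id":"desc"}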


async def query_search(
    query_text,
    filters=None,
    search_type="",
    anns_field="vector_text",
    search_params=BASE_MILVUS_SEARCH_PARAMS,
    _source=False,
    es_size=10000,
    sort_by=None,
    milvus_size=20,
    limit=10,
):
    if filters is None:
        filters = {}
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client, es_pool=resource.es_client
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return response
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return response
            case "strategy":
                return None
            case _:
                return None
    except Exception:
        return None


@server_bp.route("/query", methods=["GET"])
async def query():
    query_text = request.args.get("query")
    dataset_ids = request.args.get("datasetIds").split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    resource = get_resource_manager()
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    res = []
    for result in query_results["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if not content_chunks:
            return jsonify(
                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
            )
        content_chunk = content_chunks[0]
        datasets = await dataset_mapper.select_dataset_by_id(
            content_chunk["dataset_id"]
        )
        if not datasets:
            return jsonify(
                {"status_code": 500, "detail": "dataset not found", "data": {}}
            )
        dataset = datasets[0]
        dataset_name = None
        if dataset:
            dataset_name = dataset["name"]
        res.append(
            {
                "docId": content_chunk["doc_id"],
                "content": content_chunk["text"],
                "contentSummary": content_chunk["summary"],
                "score": result["score"],
                "datasetName": dataset_name,
            }
        )
    data = {"results": res}
    return jsonify({"status_code": 200, "detail": "success", "data": data})
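
# Illustrative call (query text and dataset ids are placeholders):
#   GET /api/query?query=<question>&datasetIds=1,2&search_type=hybrid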


@server_bp.route("/chat", methods=["GET"])
async def chat():
    query_text = request.args.get("query")
    dataset_id_strs = request.args.get("datasetIds")
    dataset_ids = dataset_id_strs.split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    resource = get_resource_manager()
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    chat_result_mapper = ChatResult(resource.mysql_client)
    res = []
    for result in query_results["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if not content_chunks:
            return jsonify(
                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
            )
        content_chunk = content_chunks[0]
        datasets = await dataset_mapper.select_dataset_by_id(
            content_chunk["dataset_id"]
        )
        if not datasets:
            return jsonify(
                {"status_code": 500, "detail": "dataset not found", "data": {}}
            )
        dataset = datasets[0]
        dataset_name = None
        if dataset:
            dataset_name = dataset["name"]
        res.append(
            {
                "docId": content_chunk["doc_id"],
                "content": content_chunk["text"],
                "contentSummary": content_chunk["summary"],
                "score": result["score"],
                "datasetName": dataset_name,
            }
        )
    chat_classifier = ChatClassifier()
    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
    data = {"results": res, "chat_res": chat_res["summary"]}
    await chat_result_mapper.insert_chat_result(
        query_text,
        dataset_id_strs,
        json.dumps(data, ensure_ascii=False),
        chat_res["summary"],
        chat_res["relevance_score"],
    )
    return jsonify({"status_code": 200, "detail": "success", "data": data})
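
# Illustrative call (query text and dataset ids are placeholders); the generated
# answer is also persisted via ChatResult.insert_chat_result:
#   GET /api/chat?query=<question>&datasetIds=1,2&search_type=hybrid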


@server_bp.route("/chunk/list", methods=["GET"])
async def chunk_list():
    resource = get_resource_manager()
    content_chunk = ContentChunks(resource.mysql_client)
    # Pagination and filter parameters come from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
    # Call select_chunk_contents to get a paginated result dict
    result = await content_chunk.select_chunk_contents(
        page_num=page_num, page_size=page_size, doc_id=doc_id
    )
    if not result:
        return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
    # Format entities, keeping only the required fields
    entities = [
        {
            "id": row["id"],
            "chunk_id": row["chunk_id"],
            "doc_id": row["doc_id"],
            "summary": row.get("summary") or "",
            "text": row.get("text") or "",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )
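
# Illustrative call (the doc id is a placeholder):
#   GET /api/chunk/list?docId=doc-123e4567&page=1&pageSize=10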