buleprint.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. import asyncio
  2. import json
  3. import traceback
  4. import uuid
  5. from typing import Dict, Any
  6. from quart import Blueprint, jsonify, request
  7. from quart_cors import cors
  8. from applications.config import (
  9. DEFAULT_MODEL,
  10. LOCAL_MODEL_CONFIG,
  11. BASE_MILVUS_SEARCH_PARAMS,
  12. )
  13. from applications.resource import get_resource_manager
  14. from applications.api import get_basic_embedding
  15. from applications.api import get_img_embedding
  16. from applications.async_task import ChunkEmbeddingTask, DeleteTask
  17. from applications.search import HybridSearch
  18. from applications.utils.chat import ChatClassifier
  19. from applications.utils.mysql import Dataset, Contents, ContentChunks
  20. from applications.utils.mysql.mapper import ChatRes
  21. server_bp = Blueprint("api", __name__, url_prefix="/api")
  22. server_bp = cors(server_bp, allow_origin="*")
  23. @server_bp.route("/embed", methods=["POST"])
  24. async def embed():
  25. body = await request.get_json()
  26. text = body.get("text")
  27. model_name = body.get("model", DEFAULT_MODEL)
  28. if not LOCAL_MODEL_CONFIG.get(model_name):
  29. return jsonify({"error": "error model"})
  30. embedding = await get_basic_embedding(text, model_name)
  31. return jsonify({"embedding": embedding})
  32. @server_bp.route("/img_embed", methods=["POST"])
  33. async def img_embed():
  34. body = await request.get_json()
  35. url_list = body.get("url_list")
  36. if not url_list:
  37. return jsonify({"error": "error url_list"})
  38. embedding = await get_img_embedding(url_list)
  39. return jsonify(embedding)
  40. @server_bp.route("/delete", methods=["POST"])
  41. async def delete():
  42. body = await request.get_json()
  43. level = body.get("level")
  44. params = body.get("params")
  45. if not level or not params:
  46. return jsonify({"error": "error level or params"})
  47. resource = get_resource_manager()
  48. delete_task = DeleteTask(resource)
  49. response = await delete_task.deal(level, params)
  50. return jsonify(response)
  51. @server_bp.route("/chunk", methods=["POST"])
  52. async def chunk():
  53. body = await request.get_json()
  54. text = body.get("text", "")
  55. ori_doc_id = body.get("doc_id")
  56. text = text.strip()
  57. if not text:
  58. return jsonify({"error": "error text"})
  59. resource = get_resource_manager()
  60. # generate doc id
  61. if ori_doc_id:
  62. body["re_chunk"] = True
  63. doc_id = ori_doc_id
  64. else:
  65. doc_id = f"doc-{uuid.uuid4()}"
  66. chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
  67. doc_id = await chunk_task.deal(body)
  68. return jsonify({"doc_id": doc_id})
  69. @server_bp.route("/search", methods=["POST"])
  70. async def search():
  71. """
  72. filters: Dict[str, Any], # 条件过滤
  73. query_vec: List[float], # query 的向量
  74. anns_field: str = "vector_text", # query指定的向量空间
  75. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  76. query_text: str = None, #是否通过 topic 倒排
  77. _source=False, # 是否返回元数据
  78. es_size: int = 10000, #es 第一层过滤数量
  79. sort_by: str = None, # 排序
  80. milvus_size: int = 10 # milvus粗排返回数量
  81. :return:
  82. """
  83. body = await request.get_json()
  84. # 解析数据
  85. search_type: str = body.get("search_type")
  86. filters: Dict[str, Any] = body.get("filters", {})
  87. anns_field: str = body.get("anns_field", "vector_text")
  88. search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
  89. query_text: str = body.get("query_text")
  90. _source: bool = body.get("_source", False)
  91. es_size: int = body.get("es_size", 10000)
  92. sort_by: str = body.get("sort_by")
  93. milvus_size: int = body.get("milvus", 20)
  94. limit: int = body.get("limit", 10)
  95. if not query_text:
  96. return jsonify({"error": "error query_text"})
  97. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  98. resource = get_resource_manager()
  99. search_engine = HybridSearch(
  100. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  101. )
  102. try:
  103. match search_type:
  104. case "base":
  105. response = await search_engine.base_vector_search(
  106. query_vec=query_vector,
  107. anns_field=anns_field,
  108. search_params=search_params,
  109. limit=limit,
  110. )
  111. return jsonify(response), 200
  112. case "hybrid":
  113. response = await search_engine.hybrid_search(
  114. filters=filters,
  115. query_vec=query_vector,
  116. anns_field=anns_field,
  117. search_params=search_params,
  118. es_size=es_size,
  119. sort_by=sort_by,
  120. milvus_size=milvus_size,
  121. )
  122. return jsonify(response), 200
  123. case "strategy":
  124. return jsonify({"error": "strategy not implemented"}), 405
  125. case _:
  126. return jsonify({"error": "error search_type"}), 200
  127. except Exception as e:
  128. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  129. @server_bp.route("/dataset/list", methods=["GET"])
  130. async def dataset_list():
  131. resource = get_resource_manager()
  132. datasets = await Dataset(resource.mysql_client).select_dataset()
  133. # 创建所有任务
  134. tasks = [
  135. Contents(resource.mysql_client).select_count(dataset["id"])
  136. for dataset in datasets
  137. ]
  138. counts = await asyncio.gather(*tasks)
  139. # 组装数据
  140. data_list = [
  141. {
  142. "dataset_id": dataset["id"],
  143. "name": dataset["name"],
  144. "count": count,
  145. "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
  146. }
  147. for dataset, count in zip(datasets, counts)
  148. ]
  149. return jsonify({"status_code": 200, "detail": "success", "data": data_list})
  150. @server_bp.route("/dataset/add", methods=["POST"])
  151. async def add_dataset():
  152. resource = get_resource_manager()
  153. dataset = Dataset(resource.mysql_client)
  154. # 从请求体里取参数
  155. body = await request.get_json()
  156. name = body.get("name")
  157. if not name:
  158. return jsonify({"status_code": 400, "detail": "name is required"})
  159. # 执行新增
  160. await dataset.add_dataset(name)
  161. return jsonify({"status_code": 200, "detail": "success"})
  162. @server_bp.route("/content/get", methods=["GET"])
  163. async def get_content():
  164. resource = get_resource_manager()
  165. contents = Contents(resource.mysql_client)
  166. # 获取请求参数
  167. doc_id = request.args.get("docId")
  168. if not doc_id:
  169. return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
  170. # 查询内容
  171. rows = await contents.select_content_by_doc_id(doc_id)
  172. if not rows:
  173. return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
  174. row = rows[0]
  175. return jsonify({
  176. "status_code": 200,
  177. "detail": "success",
  178. "data": {
  179. "title": row.get("title", ""),
  180. "text": row.get("text", ""),
  181. "doc_id": row.get("doc_id", "")
  182. }
  183. })
  184. @server_bp.route("/content/list", methods=["GET"])
  185. async def content_list():
  186. resource = get_resource_manager()
  187. contents = Contents(resource.mysql_client)
  188. # 从 URL 查询参数获取分页和过滤参数
  189. page_num = int(request.args.get("page", 1))
  190. page_size = int(request.args.get("pageSize", 10))
  191. dataset_id = request.args.get("datasetId")
  192. doc_status = int(request.args.get("doc_status", 1))
  193. # order_by 可以用 JSON 字符串传递
  194. import json
  195. order_by_str = request.args.get("order_by", '{"id":"desc"}')
  196. try:
  197. order_by = json.loads(order_by_str)
  198. except Exception:
  199. order_by = {"id": "desc"}
  200. # 调用 select_contents,获取分页字典
  201. result = await contents.select_contents(
  202. page_num=page_num,
  203. page_size=page_size,
  204. dataset_id=dataset_id,
  205. doc_status=doc_status,
  206. order_by=order_by,
  207. )
  208. # 格式化 entities,只保留必要字段
  209. entities = [
  210. {
  211. "doc_id": row["doc_id"],
  212. "title": row.get("title") or "",
  213. "text": row.get("text") or "",
  214. }
  215. for row in result["entities"]
  216. ]
  217. return jsonify(
  218. {
  219. "status_code": 200,
  220. "detail": "success",
  221. "data": {
  222. "entities": entities,
  223. "total_count": result["total_count"],
  224. "page": result["page"],
  225. "page_size": result["page_size"],
  226. "total_pages": result["total_pages"],
  227. },
  228. }
  229. )
  230. async def query_search(
  231. query_text,
  232. filters=None,
  233. search_type="",
  234. anns_field="vector_text",
  235. search_params=BASE_MILVUS_SEARCH_PARAMS,
  236. _source=False,
  237. es_size=10000,
  238. sort_by=None,
  239. milvus_size=20,
  240. limit=10,
  241. ):
  242. if filters is None:
  243. filters = {}
  244. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  245. resource = get_resource_manager()
  246. search_engine = HybridSearch(
  247. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  248. )
  249. try:
  250. match search_type:
  251. case "base":
  252. response = await search_engine.base_vector_search(
  253. query_vec=query_vector,
  254. anns_field=anns_field,
  255. search_params=search_params,
  256. limit=limit,
  257. )
  258. return response
  259. case "hybrid":
  260. response = await search_engine.hybrid_search(
  261. filters=filters,
  262. query_vec=query_vector,
  263. anns_field=anns_field,
  264. search_params=search_params,
  265. es_size=es_size,
  266. sort_by=sort_by,
  267. milvus_size=milvus_size,
  268. )
  269. return response
  270. case "strategy":
  271. return None
  272. case _:
  273. return None
  274. except Exception as e:
  275. return None
  276. @server_bp.route("/query", methods=["GET"])
  277. async def query():
  278. query_text = request.args.get("query")
  279. dataset_ids = request.args.get("datasetIds").split(",")
  280. search_type = request.args.get("search_type", "hybrid")
  281. query_results = await query_search(
  282. query_text=query_text,
  283. filters={"dataset_id": dataset_ids},
  284. search_type=search_type,
  285. )
  286. resource = get_resource_manager()
  287. content_chunk_mapper = ContentChunks(resource.mysql_client)
  288. dataset_mapper = Dataset(resource.mysql_client)
  289. res = []
  290. for result in query_results["results"]:
  291. content_chunks = await content_chunk_mapper.select_chunk_content(
  292. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  293. )
  294. if not content_chunks:
  295. return jsonify(
  296. {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  297. )
  298. content_chunk = content_chunks[0]
  299. datasets = await dataset_mapper.select_dataset_by_id(
  300. content_chunk["dataset_id"]
  301. )
  302. if not datasets:
  303. return jsonify(
  304. {"status_code": 500, "detail": "dataset not found", "data": {}}
  305. )
  306. dataset = datasets[0]
  307. dataset_name = None
  308. if dataset:
  309. dataset_name = dataset["name"]
  310. res.append(
  311. {
  312. "docId": content_chunk["doc_id"],
  313. "content": content_chunk["text"],
  314. "contentSummary": content_chunk["summary"],
  315. "score": result["score"],
  316. "datasetName": dataset_name,
  317. }
  318. )
  319. data = {"results": res}
  320. return jsonify({"status_code": 200, "detail": "success", "data": data})
  321. @server_bp.route("/chat", methods=["GET"])
  322. async def chat():
  323. query_text = request.args.get("query")
  324. dataset_id_strs = request.args.get("datasetIds")
  325. dataset_ids = dataset_id_strs.split(",")
  326. search_type = request.args.get("search_type", "hybrid")
  327. query_results = await query_search(
  328. query_text=query_text,
  329. filters={"dataset_id": dataset_ids},
  330. search_type=search_type,
  331. )
  332. resource = get_resource_manager()
  333. content_chunk_mapper = ContentChunks(resource.mysql_client)
  334. dataset_mapper = Dataset(resource.mysql_client)
  335. chat_res_mapper = ChatRes(resource.mysql_client)
  336. res = []
  337. for result in query_results["results"]:
  338. content_chunks = await content_chunk_mapper.select_chunk_content(
  339. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  340. )
  341. if not content_chunks:
  342. return jsonify(
  343. {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  344. )
  345. content_chunk = content_chunks[0]
  346. datasets = await dataset_mapper.select_dataset_by_id(
  347. content_chunk["dataset_id"]
  348. )
  349. if not datasets:
  350. return jsonify(
  351. {"status_code": 500, "detail": "dataset not found", "data": {}}
  352. )
  353. dataset = datasets[0]
  354. dataset_name = None
  355. if dataset:
  356. dataset_name = dataset["name"]
  357. res.append(
  358. {
  359. "docId": content_chunk["doc_id"],
  360. "content": content_chunk["text"],
  361. "contentSummary": content_chunk["summary"],
  362. "score": result["score"],
  363. "datasetName": dataset_name,
  364. }
  365. )
  366. chat_classifier = ChatClassifier()
  367. chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
  368. data = {"results": res, "chat_res": chat_res["summary"]}
  369. await chat_res_mapper.insert_chat_res(
  370. query_text,
  371. dataset_id_strs,
  372. json.dumps(data, ensure_ascii=False),
  373. chat_res["summary"],
  374. chat_res["relevance_score"],
  375. )
  376. return jsonify({"status_code": 200, "detail": "success", "data": data})
  377. @server_bp.route("/chunk/list", methods=["GET"])
  378. async def chunk_list():
  379. resource = get_resource_manager()
  380. content_chunk = ContentChunks(resource.mysql_client)
  381. # 从 URL 查询参数获取分页和过滤参数
  382. page_num = int(request.args.get("page", 1))
  383. page_size = int(request.args.get("pageSize", 10))
  384. doc_id = request.args.get("docId")
  385. if not doc_id:
  386. return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
  387. # 调用 select_contents,获取分页字典
  388. result = await content_chunk.select_chunk_contents(
  389. page_num=page_num, page_size=page_size, doc_id=doc_id
  390. )
  391. if not result:
  392. return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
  393. # 格式化 entities,只保留必要字段
  394. entities = [
  395. {
  396. "id": row["id"],
  397. "chunk_id": row["chunk_id"],
  398. "doc_id": row["doc_id"],
  399. "summary": row.get("summary") or "",
  400. "text": row.get("text") or "",
  401. }
  402. for row in result["entities"]
  403. ]
  404. return jsonify(
  405. {
  406. "status_code": 200,
  407. "detail": "success",
  408. "data": {
  409. "entities": entities,
  410. "total_count": result["total_count"],
  411. "page": result["page"],
  412. "page_size": result["page_size"],
  413. "total_pages": result["total_pages"],
  414. },
  415. }
  416. )