buleprint.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. import asyncio
  2. import json
  3. import traceback
  4. import uuid
  5. from typing import Dict, Any
  6. from quart import Blueprint, jsonify, request
  7. from quart_cors import cors
  8. from applications.api import get_basic_embedding
  9. from applications.api import get_img_embedding
  10. from applications.async_task import AutoRechunkTask, BuildGraph
  11. from applications.async_task import ChunkEmbeddingTask, DeleteTask
  12. from applications.config import (
  13. DEFAULT_MODEL,
  14. LOCAL_MODEL_CONFIG,
  15. BASE_MILVUS_SEARCH_PARAMS,
  16. )
  17. from applications.resource import get_resource_manager
  18. from applications.search import HybridSearch
  19. from applications.utils.chat import RAGChatAgent
  20. from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult
  21. server_bp = Blueprint("api", __name__, url_prefix="/api")
  22. server_bp = cors(server_bp, allow_origin="*")
  23. @server_bp.route("/embed", methods=["POST"])
  24. async def embed():
  25. body = await request.get_json()
  26. text = body.get("text")
  27. model_name = body.get("model", DEFAULT_MODEL)
  28. if not LOCAL_MODEL_CONFIG.get(model_name):
  29. return jsonify({"error": "error model"})
  30. embedding = await get_basic_embedding(text, model_name)
  31. return jsonify({"embedding": embedding})
  32. @server_bp.route("/img_embed", methods=["POST"])
  33. async def img_embed():
  34. body = await request.get_json()
  35. url_list = body.get("url_list")
  36. if not url_list:
  37. return jsonify({"error": "error url_list"})
  38. embedding = await get_img_embedding(url_list)
  39. return jsonify(embedding)
  40. @server_bp.route("/delete", methods=["POST"])
  41. async def delete():
  42. body = await request.get_json()
  43. level = body.get("level")
  44. params = body.get("params")
  45. if not level or not params:
  46. return jsonify({"error": "error level or params"})
  47. resource = get_resource_manager()
  48. del_task = DeleteTask(resource)
  49. response = await del_task.deal(level, params)
  50. return jsonify(response)
  51. @server_bp.route("/chunk", methods=["POST"])
  52. async def chunk():
  53. body = await request.get_json()
  54. text = body.get("text", "")
  55. ori_doc_id = body.get("doc_id")
  56. text = text.strip()
  57. if not text:
  58. return jsonify({"error": "error text"})
  59. resource = get_resource_manager()
  60. # generate doc id
  61. if ori_doc_id:
  62. body["re_chunk"] = True
  63. doc_id = ori_doc_id
  64. else:
  65. doc_id = f"doc-{uuid.uuid4()}"
  66. chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
  67. doc_id = await chunk_task.deal(body)
  68. return jsonify({"doc_id": doc_id})
  69. @server_bp.route("/search", methods=["POST"])
  70. async def search():
  71. """
  72. filters: Dict[str, Any], # 条件过滤
  73. query_vec: List[float], # query 的向量
  74. anns_field: str = "vector_text", # query指定的向量空间
  75. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  76. query_text: str = None, #是否通过 topic 倒排
  77. _source=False, # 是否返回元数据
  78. es_size: int = 10000, #es 第一层过滤数量
  79. sort_by: str = None, # 排序
  80. milvus_size: int = 10 # milvus粗排返回数量
  81. :return:
  82. """
  83. body = await request.get_json()
  84. # 解析数据
  85. search_type: str = body.get("search_type")
  86. filters: Dict[str, Any] = body.get("filters", {})
  87. anns_field: str = body.get("anns_field", "vector_text")
  88. search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
  89. query_text: str = body.get("query_text")
  90. _source: bool = body.get("_source", False)
  91. es_size: int = body.get("es_size", 10000)
  92. sort_by: str = body.get("sort_by")
  93. milvus_size: int = body.get("milvus", 20)
  94. limit: int = body.get("limit", 10)
  95. path_between_chunks: dict = body.get("path_between_chunks", {})
  96. if not query_text:
  97. return jsonify({"error": "error query_text"})
  98. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  99. resource = get_resource_manager()
  100. search_engine = HybridSearch(
  101. milvus_pool=resource.milvus_client,
  102. es_pool=resource.es_client,
  103. graph_pool=resource.graph_client,
  104. )
  105. try:
  106. match search_type:
  107. case "base":
  108. response = await search_engine.base_vector_search(
  109. query_vec=query_vector,
  110. anns_field=anns_field,
  111. search_params=search_params,
  112. limit=limit,
  113. )
  114. return jsonify(response), 200
  115. case "hybrid":
  116. response = await search_engine.hybrid_search(
  117. filters=filters,
  118. query_vec=query_vector,
  119. anns_field=anns_field,
  120. search_params=search_params,
  121. es_size=es_size,
  122. sort_by=sort_by,
  123. milvus_size=milvus_size,
  124. )
  125. return jsonify(response), 200
  126. case "hybrid2":
  127. co_fields = {"Entity": filters["entities"][0]}
  128. response = await search_engine.hybrid_search_with_graph(
  129. filters=filters,
  130. query_vec=query_vector,
  131. anns_field=anns_field,
  132. search_params=search_params,
  133. es_size=es_size,
  134. sort_by=sort_by,
  135. milvus_size=milvus_size,
  136. co_occurrence_fields=co_fields,
  137. shortest_path_fields=path_between_chunks,
  138. )
  139. return jsonify(response), 200
  140. case "strategy":
  141. return jsonify({"error": "strategy not implemented"}), 405
  142. case _:
  143. return jsonify({"error": "error search_type"}), 200
  144. except Exception as e:
  145. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  146. @server_bp.route("/dataset/list", methods=["GET"])
  147. async def dataset_list():
  148. resource = get_resource_manager()
  149. datasets = await Dataset(resource.mysql_client).select_dataset()
  150. # 创建所有任务
  151. tasks = [
  152. Contents(resource.mysql_client).select_count(dataset["id"])
  153. for dataset in datasets
  154. ]
  155. counts = await asyncio.gather(*tasks)
  156. # 组装数据
  157. data_list = [
  158. {
  159. "dataset_id": dataset["id"],
  160. "name": dataset["name"],
  161. "count": count,
  162. "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
  163. }
  164. for dataset, count in zip(datasets, counts)
  165. ]
  166. return jsonify({"status_code": 200, "detail": "success", "data": data_list})
  167. @server_bp.route("/dataset/add", methods=["POST"])
  168. async def add_dataset():
  169. resource = get_resource_manager()
  170. dataset_mapper = Dataset(resource.mysql_client)
  171. # 从请求体里取参数
  172. body = await request.get_json()
  173. name = body.get("name")
  174. if not name:
  175. return jsonify({"status_code": 400, "detail": "name is required"})
  176. # 执行新增
  177. dataset = await dataset_mapper.select_dataset_by_name(name)
  178. if dataset:
  179. return jsonify({"status_code": 400, "detail": "name is exist"})
  180. await dataset_mapper.add_dataset(name)
  181. new_dataset = await dataset_mapper.select_dataset_by_name(name)
  182. return jsonify(
  183. {
  184. "status_code": 200,
  185. "detail": "success",
  186. "data": {"datasetId": new_dataset[0]["id"]},
  187. }
  188. )
  189. @server_bp.route("/content/get", methods=["GET"])
  190. async def get_content():
  191. resource = get_resource_manager()
  192. contents = Contents(resource.mysql_client)
  193. # 获取请求参数
  194. doc_id = request.args.get("docId")
  195. if not doc_id:
  196. return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
  197. # 查询内容
  198. rows = await contents.select_content_by_doc_id(doc_id)
  199. if not rows:
  200. return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
  201. row = rows[0]
  202. return jsonify(
  203. {
  204. "status_code": 200,
  205. "detail": "success",
  206. "data": {
  207. "title": row.get("title", ""),
  208. "text": row.get("text", ""),
  209. "doc_id": row.get("doc_id", ""),
  210. },
  211. }
  212. )
  213. @server_bp.route("/content/list", methods=["GET"])
  214. async def content_list():
  215. resource = get_resource_manager()
  216. contents = Contents(resource.mysql_client)
  217. # 从 URL 查询参数获取分页和过滤参数
  218. page_num = int(request.args.get("page", 1))
  219. page_size = int(request.args.get("pageSize", 10))
  220. dataset_id = request.args.get("datasetId")
  221. doc_status = int(request.args.get("doc_status", 1))
  222. # order_by 可以用 JSON 字符串传递
  223. import json
  224. order_by_str = request.args.get("order_by", '{"id":"desc"}')
  225. try:
  226. order_by = json.loads(order_by_str)
  227. except Exception:
  228. order_by = {"id": "desc"}
  229. # 调用 select_contents,获取分页字典
  230. result = await contents.select_contents(
  231. page_num=page_num,
  232. page_size=page_size,
  233. dataset_id=dataset_id,
  234. doc_status=doc_status,
  235. order_by=order_by,
  236. )
  237. # 格式化 entities,只保留必要字段
  238. entities = [
  239. {
  240. "doc_id": row["doc_id"],
  241. "title": row.get("title") or "",
  242. "text": row.get("text") or "",
  243. "statusDesc": "可用" if row.get("status") == 2 else "不可用",
  244. }
  245. for row in result["entities"]
  246. ]
  247. return jsonify(
  248. {
  249. "status_code": 200,
  250. "detail": "success",
  251. "data": {
  252. "entities": entities,
  253. "total_count": result["total_count"],
  254. "page": result["page"],
  255. "page_size": result["page_size"],
  256. "total_pages": result["total_pages"],
  257. },
  258. }
  259. )
  260. async def query_search(
  261. query_text,
  262. filters=None,
  263. search_type="",
  264. anns_field="vector_text",
  265. search_params=BASE_MILVUS_SEARCH_PARAMS,
  266. _source=False,
  267. es_size=10000,
  268. sort_by=None,
  269. milvus_size=20,
  270. limit=10,
  271. ):
  272. if filters is None:
  273. filters = {}
  274. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  275. resource = get_resource_manager()
  276. search_engine = HybridSearch(
  277. milvus_pool=resource.milvus_client,
  278. es_pool=resource.es_client,
  279. graph_pool=resource.graph_client,
  280. )
  281. try:
  282. match search_type:
  283. case "base":
  284. response = await search_engine.base_vector_search(
  285. query_vec=query_vector,
  286. anns_field=anns_field,
  287. search_params=search_params,
  288. limit=limit,
  289. )
  290. return response
  291. case "hybrid":
  292. response = await search_engine.hybrid_search(
  293. filters=filters,
  294. query_vec=query_vector,
  295. anns_field=anns_field,
  296. search_params=search_params,
  297. es_size=es_size,
  298. sort_by=sort_by,
  299. milvus_size=milvus_size,
  300. )
  301. case "strategy":
  302. return None
  303. case _:
  304. return None
  305. except Exception as e:
  306. return None
  307. if response is None:
  308. return None
  309. resource = get_resource_manager()
  310. content_chunk_mapper = ContentChunks(resource.mysql_client)
  311. res = []
  312. for result in response["results"]:
  313. content_chunks = await content_chunk_mapper.select_chunk_content(
  314. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  315. )
  316. if content_chunks:
  317. content_chunk = content_chunks[0]
  318. res.append(
  319. {
  320. "docId": content_chunk["doc_id"],
  321. "content": content_chunk["text"],
  322. "contentSummary": content_chunk["summary"],
  323. "score": result["score"],
  324. "datasetId": content_chunk["dataset_id"],
  325. }
  326. )
  327. return res
  328. @server_bp.route("/query", methods=["GET"])
  329. async def query():
  330. query_text = request.args.get("query")
  331. dataset_ids = request.args.get("datasetIds").split(",")
  332. search_type = request.args.get("search_type", "hybrid")
  333. query_results = await query_search(
  334. query_text=query_text,
  335. filters={"dataset_id": dataset_ids},
  336. search_type=search_type,
  337. )
  338. resource = get_resource_manager()
  339. dataset_mapper = Dataset(resource.mysql_client)
  340. for result in query_results:
  341. datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
  342. if datasets:
  343. dataset_name = datasets[0]["name"]
  344. result["datasetName"] = dataset_name
  345. data = {"results": query_results}
  346. return jsonify({"status_code": 200, "detail": "success", "data": data})
  347. @server_bp.route("/chat", methods=["GET"])
  348. async def chat():
  349. query_text = request.args.get("query")
  350. dataset_id_strs = request.args.get("datasetIds")
  351. dataset_ids = dataset_id_strs.split(",")
  352. search_type = request.args.get("search_type", "hybrid")
  353. query_results = await query_search(
  354. query_text=query_text,
  355. filters={"dataset_id": dataset_ids},
  356. search_type=search_type,
  357. )
  358. resource = get_resource_manager()
  359. chat_result_mapper = ChatResult(resource.mysql_client)
  360. resource = get_resource_manager()
  361. dataset_mapper = Dataset(resource.mysql_client)
  362. for result in query_results:
  363. datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
  364. if datasets:
  365. dataset_name = datasets[0]["name"]
  366. result["datasetName"] = dataset_name
  367. rag_chat_agent = RAGChatAgent()
  368. chat_res = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
  369. deepseek_search = await rag_chat_agent.search_with_deepseek(query_text)
  370. select = await rag_chat_agent.select_with_deepseek(chat_res, deepseek_search)
  371. data = {"results": query_results, "chat_res": select["result"]}
  372. await chat_result_mapper.insert_chat_result(
  373. query_text,
  374. dataset_id_strs,
  375. json.dumps(data, ensure_ascii=False),
  376. chat_res["summary"],
  377. chat_res["relevance_score"],
  378. chat_res["status"],
  379. deepseek_search["answer"],
  380. deepseek_search["source"],
  381. deepseek_search["status"],
  382. )
  383. return jsonify({"status_code": 200, "detail": "success", "data": data})
  384. @server_bp.route("/chunk/list", methods=["GET"])
  385. async def chunk_list():
  386. resource = get_resource_manager()
  387. content_chunk = ContentChunks(resource.mysql_client)
  388. # 从 URL 查询参数获取分页和过滤参数
  389. page_num = int(request.args.get("page", 1))
  390. page_size = int(request.args.get("pageSize", 10))
  391. doc_id = request.args.get("docId")
  392. if not doc_id:
  393. return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
  394. # 调用 select_contents,获取分页字典
  395. result = await content_chunk.select_chunk_contents(
  396. page_num=page_num, page_size=page_size, doc_id=doc_id
  397. )
  398. if not result:
  399. return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
  400. # 格式化 entities,只保留必要字段
  401. entities = [
  402. {
  403. "id": row["id"],
  404. "chunk_id": row["chunk_id"],
  405. "doc_id": row["doc_id"],
  406. "summary": row.get("summary") or "",
  407. "text": row.get("text") or "",
  408. "statusDesc": "可用" if row.get("chunk_status") == 2 else "不可用",
  409. }
  410. for row in result["entities"]
  411. ]
  412. return jsonify(
  413. {
  414. "status_code": 200,
  415. "detail": "success",
  416. "data": {
  417. "entities": entities,
  418. "total_count": result["total_count"],
  419. "page": result["page"],
  420. "page_size": result["page_size"],
  421. "total_pages": result["total_pages"],
  422. },
  423. }
  424. )
  425. @server_bp.route("/auto_rechunk", methods=["GET"])
  426. async def auto_rechunk():
  427. resource = get_resource_manager()
  428. auto_rechunk_task = AutoRechunkTask(mysql_client=resource.mysql_client)
  429. process_cnt = await auto_rechunk_task.deal()
  430. return jsonify({"status_code": 200, "detail": "success", "cnt": process_cnt})
  431. @server_bp.route("/build_graph", methods=["POST"])
  432. async def delete_task():
  433. body = await request.get_json()
  434. doc_id: str = body.get("doc_id")
  435. if not doc_id:
  436. return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
  437. resource = get_resource_manager()
  438. build_graph_task = BuildGraph(
  439. neo4j=resource.graph_client,
  440. es_client=resource.es_client,
  441. mysql_client=resource.mysql_client,
  442. )
  443. await build_graph_task.deal(doc_id)
  444. return jsonify({"status_code": 200, "detail": "success", "data": {}})