buleprint.py

import asyncio
import json
import traceback
import uuid
from typing import Dict, Any

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.api import get_basic_embedding, get_img_embedding
from applications.async_task import AutoRechunkTask, BuildGraph, ChunkEmbeddingTask, DeleteTask
from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.search import HybridSearch
from applications.utils.chat import RAGChatAgent
from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult
from applications.utils.spider.study import study

server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")


@server_bp.route("/embed", methods=["POST"])
async def embed():
    body = await request.get_json()
    text = body.get("text")
    model_name = body.get("model", DEFAULT_MODEL)
    if not LOCAL_MODEL_CONFIG.get(model_name):
        return jsonify({"error": "error model"})
    embedding = await get_basic_embedding(text, model_name)
    return jsonify({"embedding": embedding})

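# Usage sketch (the host/port and model name below are assumptions, not taken
# from this file; DEFAULT_MODEL is used when "model" is omitted):
#   curl -X POST http://localhost:5000/api/embed \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello world", "model": "my-local-model"}'
# -> {"embedding": [...]}
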
@server_bp.route("/img_embed", methods=["POST"])
async def img_embed():
    body = await request.get_json()
    url_list = body.get("url_list")
    if not url_list:
        return jsonify({"error": "error url_list"})
    embedding = await get_img_embedding(url_list)
    return jsonify(embedding)

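# Usage sketch (the image URL is hypothetical):
#   curl -X POST http://localhost:5000/api/img_embed \
#        -H "Content-Type: application/json" \
#        -d '{"url_list": ["https://example.com/cat.jpg"]}'
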
@server_bp.route("/delete", methods=["POST"])
async def delete():
    body = await request.get_json()
    level = body.get("level")
    params = body.get("params")
    if not level or not params:
        return jsonify({"error": "error level or params"})
    resource = get_resource_manager()
    del_task = DeleteTask(resource)
    response = await del_task.deal(level, params)
    return jsonify(response)

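# Usage sketch; the exact "level"/"params" shapes DeleteTask.deal accepts are
# not defined in this file, so the values below are illustrative only:
#   curl -X POST http://localhost:5000/api/delete \
#        -H "Content-Type: application/json" \
#        -d '{"level": "doc", "params": {"doc_id": "doc-123"}}'
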
@server_bp.route("/chunk", methods=["POST"])
async def chunk():
    body = await request.get_json()
    text = body.get("text", "")
    ori_doc_id = body.get("doc_id")
    text = text.strip()
    if not text:
        return jsonify({"error": "error text"})
    resource = get_resource_manager()
    # generate doc id: reuse the caller's doc_id (re-chunk) or create a new one
    if ori_doc_id:
        body["re_chunk"] = True
        doc_id = ori_doc_id
    else:
        doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})

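# Usage sketch. POSTing without "doc_id" chunks new text under a generated id;
# including "doc_id" re-chunks an existing document:
#   curl -X POST http://localhost:5000/api/chunk \
#        -H "Content-Type: application/json" \
#        -d '{"text": "long document text ..."}'
# -> {"doc_id": "doc-<uuid4>"}
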
@server_bp.route("/search", methods=["POST"])
async def search():
    """
    filters: Dict[str, Any],              # filter conditions
    query_vec: List[float],               # embedding of the query
    anns_field: str = "vector_text",      # vector field to search against
    search_params: Optional[Dict[str, Any]] = None,  # vector distance settings
    query_text: str = None,               # whether to use the topic inverted index
    _source=False,                        # whether to return source metadata
    es_size: int = 10000,                 # first-stage ES filter size
    sort_by: str = None,                  # sort key
    milvus_size: int = 10                 # number of Milvus coarse-ranking hits
    :return:
    """
    body = await request.get_json()
    # parse request parameters
    search_type: str = body.get("search_type")
    filters: Dict[str, Any] = body.get("filters", {})
    anns_field: str = body.get("anns_field", "vector_text")
    search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
    query_text: str = body.get("query_text")
    _source: bool = body.get("_source", False)
    es_size: int = body.get("es_size", 10000)
    sort_by: str = body.get("sort_by")
    milvus_size: int = body.get("milvus", 20)  # NOTE: the request key is "milvus"
    limit: int = body.get("limit", 10)
    path_between_chunks: dict = body.get("path_between_chunks", {})
    if not query_text:
        return jsonify({"error": "error query_text"})
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return jsonify(response), 200
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return jsonify(response), 200
            case "hybrid2":
                co_fields = {"Entity": filters["entities"][0]}
                response = await search_engine.hybrid_search_with_graph(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                    co_occurrence_fields=co_fields,
                    shortest_path_fields=path_between_chunks,
                )
                return jsonify(response), 200
            case "strategy":
                return jsonify({"error": "strategy not implemented"}), 405
            case _:
                return jsonify({"error": "error search_type"}), 400
    except Exception as e:
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500

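# Usage sketch for the hybrid path (filter values are illustrative):
#   curl -X POST http://localhost:5000/api/search \
#        -H "Content-Type: application/json" \
#        -d '{"search_type": "hybrid", "query_text": "vector databases",
#             "filters": {"dataset_id": ["11"]}, "limit": 5}'
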
@server_bp.route("/dataset/list", methods=["GET"])
async def dataset_list():
    resource = get_resource_manager()
    datasets = await Dataset(resource.mysql_client).select_dataset()
    # create one count task per dataset and run them concurrently
    tasks = [
        Contents(resource.mysql_client).select_count(dataset["id"])
        for dataset in datasets
    ]
    counts = await asyncio.gather(*tasks)
    # assemble the response payload
    data_list = [
        {
            "dataset_id": dataset["id"],
            "name": dataset["name"],
            "count": count,
            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
        }
        for dataset, count in zip(datasets, counts)
    ]
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})

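# Usage sketch (field values in the response are illustrative):
#   curl http://localhost:5000/api/dataset/list
# -> {"status_code": 200, "detail": "success",
#     "data": [{"dataset_id": 11, "name": "...", "count": 42,
#               "created_at": "2024-01-01"}]}
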
@server_bp.route("/dataset/add", methods=["POST"])
async def add_dataset():
    resource = get_resource_manager()
    dataset_mapper = Dataset(resource.mysql_client)
    # read parameters from the request body
    body = await request.get_json()
    name = body.get("name")
    if not name:
        return jsonify({"status_code": 400, "detail": "name is required"})
    # reject duplicate names, then insert
    dataset = await dataset_mapper.select_dataset_by_name(name)
    if dataset:
        return jsonify({"status_code": 400, "detail": "name already exists"})
    await dataset_mapper.add_dataset(name)
    new_dataset = await dataset_mapper.select_dataset_by_name(name)
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {"datasetId": new_dataset[0]["id"]},
        }
    )

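# Usage sketch:
#   curl -X POST http://localhost:5000/api/dataset/add \
#        -H "Content-Type: application/json" \
#        -d '{"name": "my-dataset"}'
# -> {"status_code": 200, "detail": "success", "data": {"datasetId": ...}}
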
@server_bp.route("/content/get", methods=["GET"])
async def get_content():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # read request parameters
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
    # look up the content
    rows = await contents.select_content_by_doc_id(doc_id)
    if not rows:
        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
    row = rows[0]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "title": row.get("title", ""),
                "text": row.get("text", ""),
                "doc_id": row.get("doc_id", ""),
            },
        }
    )

@server_bp.route("/content/list", methods=["GET"])
async def content_list():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # pagination and filter parameters come from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    dataset_id = request.args.get("datasetId")
    doc_status = int(request.args.get("doc_status", 1))
    # order_by may be passed as a JSON string, e.g. '{"id":"desc"}'
    order_by_str = request.args.get("order_by", '{"id":"desc"}')
    try:
        order_by = json.loads(order_by_str)
    except Exception:
        order_by = {"id": "desc"}
    # call select_contents to get a pagination dict
    result = await contents.select_contents(
        page_num=page_num,
        page_size=page_size,
        dataset_id=dataset_id,
        doc_status=doc_status,
        order_by=order_by,
    )
    # format entities, keeping only the required fields
    entities = [
        {
            "doc_id": row["doc_id"],
            "title": row.get("title") or "",
            "text": row.get("text") or "",
            "statusDesc": "可用" if row.get("status") == 2 else "不可用",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )

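# Usage sketch (parameter values are illustrative; order_by is a JSON string):
#   curl 'http://localhost:5000/api/content/list?page=1&pageSize=10&datasetId=11'
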
async def query_search(
    query_text,
    filters=None,
    search_type="",
    anns_field="vector_text",
    search_params=BASE_MILVUS_SEARCH_PARAMS,
    _source=False,
    es_size=10000,
    sort_by=None,
    milvus_size=20,
    limit=10,
):
    if filters is None:
        filters = {}
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return response
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
            case "strategy":
                return None
            case _:
                return None
    except Exception:
        return None
    if response is None:
        return None
    # hydrate each hit with its chunk text from MySQL
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    res = []
    for result in response["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if content_chunks:
            content_chunk = content_chunks[0]
            res.append(
                {
                    "docId": content_chunk["doc_id"],
                    "content": content_chunk["text"],
                    "contentSummary": content_chunk["summary"],
                    "score": result["score"],
                    "datasetId": content_chunk["dataset_id"],
                }
            )
    return res[:limit]

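# Usage sketch (called from an async context; the dataset id is illustrative):
#   results = await query_search(
#       query_text="vector databases",
#       filters={"dataset_id": ["11"]},
#       search_type="hybrid",
#       limit=5,
#   )
#   # -> [{"docId": ..., "content": ..., "contentSummary": ...,
#   #      "score": ..., "datasetId": ...}, ...]
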
@server_bp.route("/query", methods=["GET"])
async def query():
    query_text = request.args.get("query")
    dataset_ids = (request.args.get("datasetIds") or "").split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    query_results = query_results or []
    resource = get_resource_manager()
    dataset_mapper = Dataset(resource.mysql_client)
    # attach the dataset name to each hit
    for result in query_results:
        datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
        if datasets:
            dataset_name = datasets[0]["name"]
            result["datasetName"] = dataset_name
    data = {"results": query_results}
    return jsonify({"status_code": 200, "detail": "success", "data": data})

@server_bp.route("/chat", methods=["GET"])
async def chat():
    query_text = request.args.get("query")
    dataset_id_strs = request.args.get("datasetIds")
    dataset_ids = dataset_id_strs.split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    resource = get_resource_manager()
    chat_result_mapper = ChatResult(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    # attach the dataset name to each hit
    for result in query_results:
        datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
        if datasets:
            dataset_name = datasets[0]["name"]
            result["datasetName"] = dataset_name
    rag_chat_agent = RAGChatAgent()
    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
    if chat_result["status"] == 0:
        study(query_text)
    llm_search = await rag_chat_agent.llm_search(query_text)
    decision = await rag_chat_agent.make_decision(chat_result, llm_search)
    data = {
        "results": query_results,
        "chat_res": decision["result"],
        "rag_summary": chat_result["summary"],
        "llm_summary": llm_search["answer"],
    }
    await chat_result_mapper.insert_chat_result(
        query_text,
        dataset_id_strs,
        json.dumps(query_results, ensure_ascii=False),
        chat_result["summary"],
        chat_result["relevance_score"],
        chat_result["status"],
        llm_search["answer"],
        llm_search["source"],
        llm_search["status"],
        decision["result"],
        is_web=1,
    )
    return jsonify({"status_code": 200, "detail": "success", "data": data})

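# Usage sketch (dataset ids are illustrative):
#   curl 'http://localhost:5000/api/chat?query=what+is+milvus&datasetIds=11,12'
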
@server_bp.route("/chunk/list", methods=["GET"])
async def chunk_list():
    resource = get_resource_manager()
    content_chunk = ContentChunks(resource.mysql_client)
    # pagination and filter parameters come from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
    # call select_chunk_contents to get a pagination dict
    result = await content_chunk.select_chunk_contents(
        page_num=page_num, page_size=page_size, doc_id=doc_id
    )
    if not result:
        return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
    # format entities, keeping only the required fields
    entities = [
        {
            "id": row["id"],
            "chunk_id": row["chunk_id"],
            "doc_id": row["doc_id"],
            "summary": row.get("summary") or "",
            "text": row.get("text") or "",
            "statusDesc": "可用" if row.get("chunk_status") == 2 else "不可用",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )

@server_bp.route("/auto_rechunk", methods=["GET"])
async def auto_rechunk():
    resource = get_resource_manager()
    auto_rechunk_task = AutoRechunkTask(mysql_client=resource.mysql_client)
    process_cnt = await auto_rechunk_task.deal()
    return jsonify({"status_code": 200, "detail": "success", "cnt": process_cnt})

@server_bp.route("/build_graph", methods=["POST"])
async def build_graph():
    body = await request.get_json()
    doc_id: str = body.get("doc_id")
    if not doc_id:
        return jsonify({"status_code": 500, "detail": "doc_id is required", "data": {}})
    resource = get_resource_manager()
    build_graph_task = BuildGraph(
        neo4j=resource.graph_client,
        es_client=resource.es_client,
        mysql_client=resource.mysql_client,
    )
    await build_graph_task.deal(doc_id)
    return jsonify({"status_code": 200, "detail": "success", "data": {}})

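# Usage sketch:
#   curl -X POST http://localhost:5000/api/build_graph \
#        -H "Content-Type: application/json" \
#        -d '{"doc_id": "doc-123"}'
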
@server_bp.route("/rag/search", methods=["POST"])
async def rag_search():
    body = await request.get_json()
    query_text = body.get("queryText")
    # NOTE: dataset ids and search type are currently hardcoded
    dataset_id_strs = "11,12"
    dataset_ids = dataset_id_strs.split(",")
    search_type = "hybrid"
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
        limit=5,
    )
    resource = get_resource_manager()
    chat_result_mapper = ChatResult(resource.mysql_client)
    rag_chat_agent = RAGChatAgent()
    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
    if chat_result["status"] == 0:
        study(query_text)
    llm_search = await rag_chat_agent.llm_search(query_text)
    decision = await rag_chat_agent.make_decision(chat_result, llm_search)
    data = {
        "result": decision["result"],
        "status": decision["status"],
        "relevance_score": decision["relevance_score"],
    }
    await chat_result_mapper.insert_chat_result(
        query_text,
        dataset_id_strs,
        json.dumps(query_results, ensure_ascii=False),
        chat_result["summary"],
        chat_result["relevance_score"],
        chat_result["status"],
        llm_search["answer"],
        llm_search["source"],
        llm_search["status"],
        decision["result"],
    )
    return jsonify({"status_code": 200, "detail": "success", "data": data})

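# Usage sketch:
#   curl -X POST http://localhost:5000/api/rag/search \
#        -H "Content-Type: application/json" \
#        -d '{"queryText": "what is milvus"}'
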
@server_bp.route("/chat/history", methods=["GET"])
async def chat_history():
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    resource = get_resource_manager()
    chat_result_mapper = ChatResult(resource.mysql_client)
    result = await chat_result_mapper.select_chat_results(page_num, page_size)
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": result["entities"],
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )