# blueprint.py

import asyncio
import json
import traceback
import uuid
from typing import Dict, Any

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.api import get_basic_embedding, get_img_embedding
from applications.async_task import (
    AutoRechunkTask,
    BuildGraph,
    ChunkBooksTask,
    ChunkEmbeddingTask,
    DeleteTask,
)
from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.search import HybridSearch
from applications.utils.chat import RAGChatAgent
from applications.utils.mysql import Dataset, Contents, ContentChunks, ChatResult
from applications.utils.spider.study import study

server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")
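
# A minimal sketch of mounting this blueprint on a Quart app. The app factory
# lives elsewhere in the project; only `server_bp` above is real.
#
#     from quart import Quart
#
#     app = Quart(__name__)
#     app.register_blueprint(server_bp)  # endpoints are served under /api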


@server_bp.route("/embed", methods=["POST"])
async def embed():
    body = await request.get_json()
    text = body.get("text")
    model_name = body.get("model", DEFAULT_MODEL)
    if not LOCAL_MODEL_CONFIG.get(model_name):
        return jsonify({"error": "unsupported model"})
    embedding = await get_basic_embedding(text, model_name)
    return jsonify({"embedding": embedding})
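
# Example request (field names from the handler above; the model name is
# illustrative and must be a key of LOCAL_MODEL_CONFIG):
#
#     POST /api/embed
#     {"text": "hello world", "model": "<local-model-name>"}
#     -> {"embedding": [...]}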


@server_bp.route("/img_embed", methods=["POST"])
async def img_embed():
    body = await request.get_json()
    url_list = body.get("url_list")
    if not url_list:
        return jsonify({"error": "url_list is required"})
    embedding = await get_img_embedding(url_list)
    return jsonify(embedding)
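
# Example request (the handler expects a non-empty list of image URLs; the
# URL itself is illustrative):
#
#     POST /api/img_embed
#     {"url_list": ["https://example.com/a.jpg"]}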


@server_bp.route("/delete", methods=["POST"])
async def delete():
    body = await request.get_json()
    level = body.get("level")
    params = body.get("params")
    if not level or not params:
        return jsonify({"error": "level and params are required"})
    resource = get_resource_manager()
    del_task = DeleteTask(resource)
    response = await del_task.deal(level, params)
    return jsonify(response)
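
# Example request. The accepted "level" values and the shape of "params" are
# defined by DeleteTask.deal; the values below are assumptions:
#
#     POST /api/delete
#     {"level": "doc", "params": {"doc_id": "doc-..."}}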


@server_bp.route("/chunk", methods=["POST"])
async def chunk():
    body = await request.get_json()
    text = body.get("text", "").strip()
    ori_doc_id = body.get("doc_id")
    if not text:
        return jsonify({"error": "text is required"})
    resource = get_resource_manager()
    # reuse the caller's doc_id when re-chunking, otherwise generate a new one
    if ori_doc_id:
        body["re_chunk"] = True
        doc_id = ori_doc_id
    else:
        doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})
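
# Example requests (field names from the handler above; passing an existing
# doc_id re-chunks that document instead of creating a new one):
#
#     POST /api/chunk
#     {"text": "some long document ..."}                    # new document
#     {"text": "updated text ...", "doc_id": "doc-..."}     # re-chunk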


@server_bp.route("/chunk_book", methods=["POST"])
async def chunk_book():
    body = await request.get_json()
    resource = get_resource_manager()
    doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkBooksTask(doc_id=doc_id, resource=resource)
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})


@server_bp.route("/search", methods=["POST"])
async def search():
    """
    filters: Dict[str, Any],          # filter conditions
    query_vec: List[float],           # embedding of the query
    anns_field: str = "vector_text",  # vector field to search against
    search_params: Optional[Dict[str, Any]] = None,  # distance metric settings
    query_text: str = None,           # query for the topic inverted index
    _source=False,                    # whether to return source metadata
    es_size: int = 10000,             # size of the first-stage ES filter
    sort_by: str = None,              # sort order
    milvus_size: int = 20             # number of candidates from Milvus coarse ranking
    :return:
    """
    body = await request.get_json()
    # parse request parameters
    search_type: str = body.get("search_type")
    filters: Dict[str, Any] = body.get("filters", {})
    anns_field: str = body.get("anns_field", "vector_text")
    search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
    query_text: str = body.get("query_text")
    _source: bool = body.get("_source", False)
    es_size: int = body.get("es_size", 10000)
    sort_by: str = body.get("sort_by")
    # accept the documented "milvus_size" key, falling back to the original "milvus" key
    milvus_size: int = body.get("milvus_size", body.get("milvus", 20))
    limit: int = body.get("limit", 10)
    path_between_chunks: dict = body.get("path_between_chunks", {})
    if not query_text:
        return jsonify({"error": "query_text is required"})
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return jsonify(response), 200
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return jsonify(response), 200
            case "hybrid2":
                co_fields = {"Entity": filters["entities"][0]}
                response = await search_engine.hybrid_search_with_graph(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                    co_occurrence_fields=co_fields,
                    shortest_path_fields=path_between_chunks,
                )
                return jsonify(response), 200
            case "strategy":
                return jsonify({"error": "strategy not implemented"}), 405
            case _:
                return jsonify({"error": "unknown search_type"}), 400
    except Exception as e:
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
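
# Example request (field names from the handler above; "base" is a pure
# vector search, "hybrid" adds ES filtering, "hybrid2" adds graph signals):
#
#     POST /api/search
#     {
#         "search_type": "hybrid",
#         "query_text": "how does chunking work",
#         "filters": {"dataset_id": ["11"]},
#         "limit": 10
#     }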


@server_bp.route("/dataset/list", methods=["GET"])
async def dataset_list():
    resource = get_resource_manager()
    datasets = await Dataset(resource.mysql_client).select_dataset()
    # count the contents of every dataset concurrently
    tasks = [
        Contents(resource.mysql_client).select_count(dataset["id"])
        for dataset in datasets
    ]
    counts = await asyncio.gather(*tasks)
    # assemble the response payload
    data_list = [
        {
            "dataset_id": dataset["id"],
            "name": dataset["name"],
            "count": count,
            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
        }
        for dataset, count in zip(datasets, counts)
    ]
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})


@server_bp.route("/dataset/add", methods=["POST"])
async def add_dataset():
    resource = get_resource_manager()
    dataset_mapper = Dataset(resource.mysql_client)
    # read parameters from the request body
    body = await request.get_json()
    name = body.get("name")
    if not name:
        return jsonify({"status_code": 400, "detail": "name is required"})
    # reject duplicate names, then insert
    dataset = await dataset_mapper.select_dataset_by_name(name)
    if dataset:
        return jsonify({"status_code": 400, "detail": "name already exists"})
    await dataset_mapper.add_dataset(name)
    new_dataset = await dataset_mapper.select_dataset_by_name(name)
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {"datasetId": new_dataset[0]["id"]},
        }
    )
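
# Example request (field name from the handler above):
#
#     POST /api/dataset/add
#     {"name": "my-dataset"}
#     -> {"status_code": 200, "detail": "success", "data": {"datasetId": ...}}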


@server_bp.route("/content/get", methods=["GET"])
async def get_content():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # read query parameters
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
    # look up the content
    rows = await contents.select_content_by_doc_id(doc_id)
    if not rows:
        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
    row = rows[0]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "title": row.get("title", ""),
                "text": row.get("text", ""),
                "doc_id": row.get("doc_id", ""),
            },
        }
    )


@server_bp.route("/content/list", methods=["GET"])
async def content_list():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # pagination and filter parameters from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    dataset_id = request.args.get("datasetId")
    doc_status = int(request.args.get("doc_status", 1))
    # order_by may be passed as a JSON string, e.g. '{"id":"desc"}'
    order_by_str = request.args.get("order_by", '{"id":"desc"}')
    try:
        order_by = json.loads(order_by_str)
    except json.JSONDecodeError:
        order_by = {"id": "desc"}
    # fetch one page of contents
    result = await contents.select_contents(
        page_num=page_num,
        page_size=page_size,
        dataset_id=dataset_id,
        doc_status=doc_status,
        order_by=order_by,
    )
    # keep only the fields the frontend needs
    entities = [
        {
            "doc_id": row["doc_id"],
            "title": row.get("title") or "",
            "text": row.get("text") or "",
            # "可用" = usable, "不可用" = unusable
            "statusDesc": "可用" if row.get("status") == 2 else "不可用",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )


async def query_search(
    query_text,
    filters=None,
    search_type="",
    anns_field="vector_text",
    search_params=BASE_MILVUS_SEARCH_PARAMS,
    _source=False,
    es_size=10000,
    sort_by=None,
    milvus_size=20,
    limit=10,
):
    if filters is None:
        filters = {}
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    try:
        match search_type:
            case "base":
                # fall through to the chunk enrichment below instead of
                # returning the raw engine response
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
            case "strategy":
                return None
            case _:
                return None
    except Exception:
        return None
    if response is None:
        return None
    # enrich each hit with the stored chunk contents
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    res = []
    for result in response["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if content_chunks:
            content_chunk = content_chunks[0]
            res.append(
                {
                    "docId": content_chunk["doc_id"],
                    "content": content_chunk["text"],
                    "contentSummary": content_chunk["summary"],
                    "score": result["score"],
                    "datasetId": content_chunk["dataset_id"],
                }
            )
    return res[:limit]
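
# Minimal usage sketch for the helper above (values are illustrative):
#
#     results = await query_search(
#         query_text="how does chunking work",
#         filters={"dataset_id": ["11", "12"]},
#         search_type="hybrid",
#     )
#     # -> [{"docId": ..., "content": ..., "contentSummary": ...,
#     #      "score": ..., "datasetId": ...}, ...] or None on failure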


@server_bp.route("/query", methods=["GET"])
async def query():
    query_text = request.args.get("query")
    dataset_ids = request.args.get("datasetIds", "").split(",")
    search_type = request.args.get("search_type", "hybrid")
    # query_search returns None on failure
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    ) or []
    resource = get_resource_manager()
    dataset_mapper = Dataset(resource.mysql_client)
    # attach the dataset name to every hit
    for result in query_results:
        datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
        if datasets:
            result["datasetName"] = datasets[0]["name"]
    data = {"results": query_results}
    return jsonify({"status_code": 200, "detail": "success", "data": data})
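
# Example request (query parameters from the handler above):
#
#     GET /api/query?query=what+is+rag&datasetIds=11,12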


@server_bp.route("/chat", methods=["GET"])
async def chat():
    query_text = request.args.get("query")
    dataset_id_strs = request.args.get("datasetIds", "")
    dataset_ids = dataset_id_strs.split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    ) or []
    resource = get_resource_manager()
    chat_result_mapper = ChatResult(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    # attach the dataset name to every hit
    for result in query_results:
        datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
        if datasets:
            result["datasetName"] = datasets[0]["name"]
    rag_chat_agent = RAGChatAgent()
    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
    # study_task_id = None
    # if chat_result["status"] == 0:
    #     study_task_id = study(query_text)["task_id"]
    llm_search = await rag_chat_agent.llm_search(query_text)
    decision = await rag_chat_agent.make_decision(chat_result, llm_search)
    data = {
        "results": query_results,
        "chat_res": decision["result"],
        "rag_summary": chat_result["summary"],
        "llm_summary": llm_search["answer"],
        # "used_tools": decision["used_tools"],
    }
    await chat_result_mapper.insert_chat_result(
        query_text,
        dataset_id_strs,
        json.dumps(query_results, ensure_ascii=False),
        chat_result["summary"],
        chat_result["relevance_score"],
        chat_result["status"],
        llm_search["answer"],
        llm_search["source"],
        llm_search["status"],
        decision["result"],
        is_web=1,
    )
    return jsonify({"status_code": 200, "detail": "success", "data": data})
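
# Example request (query parameters from the handler above; dataset ids are
# illustrative):
#
#     GET /api/chat?query=what+is+rag&datasetIds=11,12&search_type=hybrid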


@server_bp.route("/chunk/list", methods=["GET"])
async def chunk_list():
    resource = get_resource_manager()
    content_chunk = ContentChunks(resource.mysql_client)
    # pagination and filter parameters from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 400, "detail": "docId is required", "data": {}})
    # fetch one page of chunks
    result = await content_chunk.select_chunk_contents(
        page_num=page_num, page_size=page_size, doc_id=doc_id
    )
    if not result:
        return jsonify({"status_code": 404, "detail": "chunk is empty", "data": {}})
    # keep only the fields the frontend needs
    entities = [
        {
            "id": row["id"],
            "chunk_id": row["chunk_id"],
            "doc_id": row["doc_id"],
            "summary": row.get("summary") or "",
            "text": row.get("text") or "",
            # "可用" = usable, "不可用" = unusable
            "statusDesc": "可用" if row.get("chunk_status") == 2 else "不可用",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )


@server_bp.route("/auto_rechunk", methods=["GET"])
async def auto_rechunk():
    resource = get_resource_manager()
    auto_rechunk_task = AutoRechunkTask(mysql_client=resource.mysql_client)
    process_cnt = await auto_rechunk_task.deal()
    return jsonify({"status_code": 200, "detail": "success", "cnt": process_cnt})


@server_bp.route("/build_graph", methods=["POST"])
async def build_graph():
    body = await request.get_json()
    doc_id: str = body.get("doc_id")
    dataset_id: int = body.get("dataset_id", 12)
    batch: bool = body.get("batch_process", False)
    resource = get_resource_manager()
    build_graph_task = BuildGraph(
        neo4j=resource.graph_client,
        es_client=resource.es_client,
        mysql_client=resource.mysql_client,
    )
    if batch:
        await build_graph_task.deal_batch(dataset_id)
    else:
        await build_graph_task.deal(doc_id)
    return jsonify({"status_code": 200, "detail": "success", "data": {}})
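
# Example requests (field names from the handler above; set batch_process to
# build the graph for a whole dataset instead of a single document):
#
#     POST /api/build_graph
#     {"doc_id": "doc-..."}                        # single document
#     {"dataset_id": 12, "batch_process": true}    # whole dataset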


@server_bp.route("/rag/search", methods=["POST"])
async def rag_search():
    body = await request.get_json()
    query_text = body.get("queryText")
    rag_chat_agent = RAGChatAgent()
    split_query_res = await rag_chat_agent.split_query(query_text)
    split_questions = split_query_res["split_questions"]
    split_questions.append(query_text)
    # answer every sub-question (plus the original query) concurrently
    tasks = [
        process_question(question, query_text, rag_chat_agent)
        for question in split_questions
    ]
    # wait for all tasks and collect the results
    data_list = await asyncio.gather(*tasks)
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})
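
# Example request (field name from the handler above; the query is split into
# sub-questions by RAGChatAgent.split_query and each is answered in parallel):
#
#     POST /api/rag/search
#     {"queryText": "compare milvus and elasticsearch for rag"}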


async def process_question(question, query_text, rag_chat_agent):
    try:
        dataset_id_strs = "11,12"
        dataset_ids = dataset_id_strs.split(",")
        search_type = "hybrid"
        # run the retrieval for this sub-question
        query_results = await query_search(
            query_text=question,
            filters={"dataset_id": dataset_ids},
            search_type=search_type,
        )
        resource = get_resource_manager()
        chat_result_mapper = ChatResult(resource.mysql_client)
        # chat with deepseek over the retrieved chunks
        chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
        # decide whether a study task needs to be triggered
        study_task_id = None
        if chat_result["status"] == 0:
            study_task_id = study(question)["task_id"]
        # fetch the plain LLM search result
        llm_search_result = await rag_chat_agent.llm_search(question)
        # decide which answer to surface
        decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)
        # build the response payload
        data = {
            "query": question,
            "result": decision["result"],
            "status": decision["status"],
            "relevance_score": decision["relevance_score"],
            # "used_tools": decision["used_tools"],
        }
        # persist the full interaction
        await chat_result_mapper.insert_chat_result(
            question,
            dataset_id_strs,
            json.dumps(query_results, ensure_ascii=False),
            chat_result["summary"],
            chat_result["relevance_score"],
            chat_result["status"],
            llm_search_result["answer"],
            llm_search_result["source"],
            llm_search_result["status"],
            decision["result"],
            study_task_id,
        )
        return data
    except Exception as e:
        print(f"Error processing question: {question}. Error: {str(e)}")
        return {"query": question, "error": str(e)}


@server_bp.route("/chat/history", methods=["GET"])
async def chat_history():
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    resource = get_resource_manager()
    chat_result_mapper = ChatResult(resource.mysql_client)
    result = await chat_result_mapper.select_chat_results(page_num, page_size)
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": result["entities"],
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )