buleprint.py

import asyncio
import json
import traceback
import uuid
from typing import Dict, Any

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    ChunkerConfig,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.api import get_basic_embedding
from applications.api import get_img_embedding
from applications.async_task import ChunkEmbeddingTask, DeleteTask
from applications.search import HybridSearch
from applications.utils.chat import ChatClassifier
from applications.utils.mysql import Dataset, Contents, ContentChunks
from applications.utils.mysql.mapper import ChatRes

server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")


@server_bp.route("/embed", methods=["POST"])
async def embed():
    body = await request.get_json()
    text = body.get("text")
    model_name = body.get("model", DEFAULT_MODEL)
    if not LOCAL_MODEL_CONFIG.get(model_name):
        return jsonify({"error": "error model"})
    embedding = await get_basic_embedding(text, model_name)
    return jsonify({"embedding": embedding})
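
# Example request sketch for POST /api/embed (field names taken from the handler
# above; the text and model values are illustrative placeholders):
#   {"text": "some text to embed", "model": "<a key of LOCAL_MODEL_CONFIG>"}
# On success the handler returns {"embedding": [...]}.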


@server_bp.route("/img_embed", methods=["POST"])
async def img_embed():
    body = await request.get_json()
    url_list = body.get("url_list")
    if not url_list:
        return jsonify({"error": "error url_list"})
    embedding = await get_img_embedding(url_list)
    return jsonify(embedding)
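
# Example request sketch for POST /api/img_embed (field name taken from the
# handler above; the URLs are illustrative placeholders):
#   {"url_list": ["https://example.com/a.jpg", "https://example.com/b.jpg"]}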


@server_bp.route("/delete", methods=["POST"])
async def delete():
    body = await request.get_json()
    level = body.get("level")
    params = body.get("params")
    if not level or not params:
        return jsonify({"error": "error level or params"})
    resource = get_resource_manager()
    delete_task = DeleteTask(resource)
    response = await delete_task.deal(level, params)
    return jsonify(response)
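
# Example request sketch for POST /api/delete (field names taken from the handler
# above; the accepted values of "level" and the shape of "params" are defined by
# DeleteTask.deal, so the values shown here are assumptions):
#   {"level": "<delete level>", "params": {"doc_id": "<doc id>"}}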


@server_bp.route("/chunk", methods=["POST"])
async def chunk():
    body = await request.get_json()
    text = body.get("text", "")
    text = text.strip()
    if not text:
        return jsonify({"error": "error text"})
    resource = get_resource_manager()
    doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})
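
# Example request sketch for POST /api/chunk (only "text" is read explicitly here;
# the full body is forwarded to ChunkEmbeddingTask.deal, so any additional fields
# depend on that task's contract):
#   {"text": "long document text to be chunked and embedded"}
# The handler responds with {"doc_id": "doc-<uuid>"}.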


@server_bp.route("/search", methods=["POST"])
async def search():
    """
    filters: Dict[str, Any],                        # filter conditions
    query_vec: List[float],                         # embedding vector of the query
    anns_field: str = "vector_text",                # vector field to search against
    search_params: Optional[Dict[str, Any]] = None, # vector distance / search parameters
    query_text: str = None,                         # text used for the topic inverted-index lookup
    _source=False,                                  # whether to return source metadata
    es_size: int = 10000,                           # size of the first-stage ES filter
    sort_by: str = None,                            # sort order
    milvus_size: int = 10                           # number of candidates returned by Milvus coarse ranking
    :return:
    """
    body = await request.get_json()
    # Parse request data
    search_type: str = body.get("search_type")
    filters: Dict[str, Any] = body.get("filters", {})
    anns_field: str = body.get("anns_field", "vector_text")
    search_params: Dict[str, Any] = body.get(
        "search_params", BASE_MILVUS_SEARCH_PARAMS
    )
    query_text: str = body.get("query_text")
    _source: bool = body.get("_source", False)
    es_size: int = body.get("es_size", 10000)
    sort_by: str = body.get("sort_by")
    milvus_size: int = body.get("milvus", 20)
    limit: int = body.get("limit", 10)
    if not query_text:
        return jsonify({"error": "error query_text"})
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client, es_pool=resource.es_client
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return jsonify(response), 200
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return jsonify(response), 200
            case "strategy":
                return jsonify({"error": "strategy not implemented"}), 405
            case _:
                return jsonify({"error": "error search_type"}), 200
    except Exception as e:
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
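
# Example request sketches for POST /api/search (field names taken from the
# handler above; values are illustrative). Note the coarse-ranking size is read
# from the "milvus" key:
#   base vector search:
#     {"search_type": "base", "query_text": "how are chunks embedded", "limit": 10}
#   hybrid search:
#     {"search_type": "hybrid", "query_text": "how are chunks embedded",
#      "filters": {"dataset_id": ["1"]}, "es_size": 1000, "milvus": 20}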


@server_bp.route("/dataset/list", methods=["GET"])
async def dataset_list():
    resource = get_resource_manager()
    datasets = await Dataset(resource.mysql_client).select_dataset()
    # Create a count task for every dataset
    tasks = [
        Contents(resource.mysql_client).select_count(dataset["id"])
        for dataset in datasets
    ]
    counts = await asyncio.gather(*tasks)
    # Assemble the response payload
    data_list = [
        {
            "dataset_id": dataset["id"],
            "name": dataset["name"],
            "count": count,
            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
        }
        for dataset, count in zip(datasets, counts)
    ]
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})
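
# Example request sketch: GET /api/dataset/list takes no parameters and returns
# every dataset with its content count and creation date, as built above.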


@server_bp.route("/dataset/add", methods=["POST"])
async def add_dataset():
    resource = get_resource_manager()
    dataset = Dataset(resource.mysql_client)
    # Read parameters from the request body
    body = await request.get_json()
    name = body.get("name")
    if not name:
        return jsonify({"status_code": 400, "detail": "name is required"})
    # Insert the new dataset
    await dataset.add_dataset(name)
    return jsonify({"status_code": 200, "detail": "success"})
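
# Example request sketch for POST /api/dataset/add (field name taken from the
# handler above; the dataset name is an illustrative placeholder):
#   {"name": "my_dataset"}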


@server_bp.route("/content/get", methods=["GET"])
async def get_content():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # Read request parameters
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
    # Query the content
    rows = await contents.select_content_by_doc_id(doc_id)
    if not rows:
        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
    row = rows[0]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "title": row.get("title", ""),
                "text": row.get("text", ""),
                "doc_id": row.get("doc_id", ""),
            },
        }
    )
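
# Example request sketch for GET /api/content/get (parameter name taken from the
# handler above; the doc id is an illustrative placeholder):
#   /api/content/get?docId=doc-123e4567-e89b-12d3-a456-426614174000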


@server_bp.route("/content/list", methods=["GET"])
async def content_list():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # Read pagination and filter parameters from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    dataset_id = request.args.get("datasetId")
    doc_status = int(request.args.get("doc_status", 1))
    # order_by may be passed as a JSON string
    order_by_str = request.args.get("order_by", '{"id":"desc"}')
    try:
        order_by = json.loads(order_by_str)
    except Exception:
        order_by = {"id": "desc"}
    # Call select_contents to get the paginated result dict
    result = await contents.select_contents(
        page_num=page_num,
        page_size=page_size,
        dataset_id=dataset_id,
        doc_status=doc_status,
        order_by=order_by,
    )
    # Format entities, keeping only the required fields
    entities = [
        {
            "doc_id": row["doc_id"],
            "title": row.get("title") or "",
            "text": row.get("text") or "",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )
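
# Example request sketch for GET /api/content/list (parameter names taken from the
# handler above; values are illustrative). Note that order_by is a JSON string:
#   /api/content/list?page=1&pageSize=10&datasetId=1&doc_status=1&order_by={"id":"desc"}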


async def query_search(
    query_text,
    filters=None,
    search_type="",
    anns_field="vector_text",
    search_params=BASE_MILVUS_SEARCH_PARAMS,
    _source=False,
    es_size=10000,
    sort_by=None,
    milvus_size=20,
    limit=10,
):
    if filters is None:
        filters = {}
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client, es_pool=resource.es_client
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return response
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return response
            case "strategy":
                return None
            case _:
                return None
    except Exception:
        return None
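
# Illustrative programmatic use of query_search (a sketch; the dataset id values
# are placeholders). The function returns the raw search payload, or None when the
# search type is unknown/unimplemented or an exception was swallowed:
#   results = await query_search(
#       query_text="what is hybrid retrieval",
#       filters={"dataset_id": ["1", "2"]},
#       search_type="hybrid",
#   )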


@server_bp.route("/query", methods=["GET"])
async def query():
    query_text = request.args.get("query")
    dataset_ids = request.args.get("datasetIds").split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    if not query_results:
        return jsonify(
            {"status_code": 500, "detail": "query search failed", "data": {}}
        )
    resource = get_resource_manager()
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    res = []
    for result in query_results["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if not content_chunks:
            return jsonify(
                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
            )
        content_chunk = content_chunks[0]
        datasets = await dataset_mapper.select_dataset_by_id(
            content_chunk["dataset_id"]
        )
        if not datasets:
            return jsonify(
                {"status_code": 500, "detail": "dataset not found", "data": {}}
            )
        dataset = datasets[0]
        dataset_name = None
        if dataset:
            dataset_name = dataset["name"]
        res.append(
            {
                "docId": content_chunk["doc_id"],
                "content": content_chunk["text"],
                "contentSummary": content_chunk["summary"],
                "score": result["score"],
                "datasetName": dataset_name,
            }
        )
    data = {"results": res}
    return jsonify({"status_code": 200, "detail": "success", "data": data})
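
# Example request sketch for GET /api/query (parameter names taken from the
# handler above; datasetIds is a comma-separated list and values are illustrative):
#   /api/query?query=what is hybrid retrieval&datasetIds=1,2&search_type=hybrid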


@server_bp.route("/chat", methods=["GET"])
async def chat():
    query_text = request.args.get("query")
    dataset_id_strs = request.args.get("datasetIds")
    dataset_ids = dataset_id_strs.split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    if not query_results:
        return jsonify(
            {"status_code": 500, "detail": "query search failed", "data": {}}
        )
    resource = get_resource_manager()
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    chat_res_mapper = ChatRes(resource.mysql_client)
    res = []
    for result in query_results["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if not content_chunks:
            return jsonify(
                {"status_code": 500, "detail": "content_chunk not found", "data": {}}
            )
        content_chunk = content_chunks[0]
        datasets = await dataset_mapper.select_dataset_by_id(
            content_chunk["dataset_id"]
        )
        if not datasets:
            return jsonify(
                {"status_code": 500, "detail": "dataset not found", "data": {}}
            )
        dataset = datasets[0]
        dataset_name = None
        if dataset:
            dataset_name = dataset["name"]
        res.append(
            {
                "docId": content_chunk["doc_id"],
                "content": content_chunk["text"],
                "contentSummary": content_chunk["summary"],
                "score": result["score"],
                "datasetName": dataset_name,
            }
        )
    chat_classifier = ChatClassifier()
    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
    data = {"results": res, "chat_res": chat_res["summary"]}
    await chat_res_mapper.insert_chat_res(
        query_text,
        dataset_id_strs,
        json.dumps(data, ensure_ascii=False),
        chat_res["summary"],
        chat_res["relevance_score"],
    )
    return jsonify({"status_code": 200, "detail": "success", "data": data})
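
# Example request sketch for GET /api/chat (same query parameters as /api/query;
# values are illustrative). Besides the retrieved chunks, the response data also
# carries "chat_res", the DeepSeek summary produced by ChatClassifier:
#   /api/chat?query=summarize the onboarding docs&datasetIds=1&search_type=hybrid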


@server_bp.route("/chunk/list", methods=["GET"])
async def chunk_list():
    resource = get_resource_manager()
    content_chunk = ContentChunks(resource.mysql_client)
    # Read pagination and filter parameters from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
    # Call select_chunk_contents to get the paginated result dict
    result = await content_chunk.select_chunk_contents(
        page_num=page_num, page_size=page_size, doc_id=doc_id
    )
    if not result:
        return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
    # Format entities, keeping only the required fields
    entities = [
        {
            "id": row["id"],
            "chunk_id": row["chunk_id"],
            "doc_id": row["doc_id"],
            "summary": row.get("summary") or "",
            "text": row.get("text") or "",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )
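
# Example request sketch for GET /api/chunk/list (parameter names taken from the
# handler above; the doc id is an illustrative placeholder):
#   /api/chunk/list?page=1&pageSize=10&docId=doc-123e4567-e89b-12d3-a456-426614174000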