# blueprint.py
import asyncio
import json
import traceback
import uuid
from typing import Any, Dict

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.api import get_basic_embedding
from applications.api import get_img_embedding
from applications.async_task import ChunkEmbeddingTask, DeleteTask
from applications.config import (
    BASE_MILVUS_SEARCH_PARAMS,
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
)
from applications.resource import get_resource_manager
from applications.search import HybridSearch
from applications.utils.chat import ChatClassifier
from applications.utils.mysql import ContentChunks, Contents, Dataset
# Blueprint for every /api route in this module; CORS is opened to any
# origin so a frontend served from another host/port can call these routes.
server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")
  21. @server_bp.route("/embed", methods=["POST"])
  22. async def embed():
  23. body = await request.get_json()
  24. text = body.get("text")
  25. model_name = body.get("model", DEFAULT_MODEL)
  26. if not LOCAL_MODEL_CONFIG.get(model_name):
  27. return jsonify({"error": "error model"})
  28. embedding = await get_basic_embedding(text, model_name)
  29. return jsonify({"embedding": embedding})
  30. @server_bp.route("/img_embed", methods=["POST"])
  31. async def img_embed():
  32. body = await request.get_json()
  33. url_list = body.get("url_list")
  34. if not url_list:
  35. return jsonify({"error": "error url_list"})
  36. embedding = await get_img_embedding(url_list)
  37. return jsonify(embedding)
  38. @server_bp.route("/delete", methods=["POST"])
  39. async def delete():
  40. body = await request.get_json()
  41. level = body.get("level")
  42. params = body.get("params")
  43. if not level or not params:
  44. return jsonify({"error": "error level or params"})
  45. resource = get_resource_manager()
  46. delete_task = DeleteTask(resource)
  47. response = await delete_task.deal(level, params)
  48. return jsonify(response)
  49. @server_bp.route("/chunk", methods=["POST"])
  50. async def chunk():
  51. body = await request.get_json()
  52. text = body.get("text", "")
  53. text = text.strip()
  54. if not text:
  55. return jsonify({"error": "error text"})
  56. resource = get_resource_manager()
  57. doc_id = f"doc-{uuid.uuid4()}"
  58. chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
  59. doc_id = await chunk_task.deal(body)
  60. return jsonify({"doc_id": doc_id})
  61. @server_bp.route("/search", methods=["POST"])
  62. async def search():
  63. """
  64. filters: Dict[str, Any], # 条件过滤
  65. query_vec: List[float], # query 的向量
  66. anns_field: str = "vector_text", # query指定的向量空间
  67. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  68. query_text: str = None, #是否通过 topic 倒排
  69. _source=False, # 是否返回元数据
  70. es_size: int = 10000, #es 第一层过滤数量
  71. sort_by: str = None, # 排序
  72. milvus_size: int = 10 # milvus粗排返回数量
  73. :return:
  74. """
  75. body = await request.get_json()
  76. # 解析数据
  77. search_type: str = body.get("search_type")
  78. filters: Dict[str, Any] = body.get("filters", {})
  79. anns_field: str = body.get("anns_field", "vector_text")
  80. search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
  81. query_text: str = body.get("query_text")
  82. _source: bool = body.get("_source", False)
  83. es_size: int = body.get("es_size", 10000)
  84. sort_by: str = body.get("sort_by")
  85. milvus_size: int = body.get("milvus", 20)
  86. limit: int = body.get("limit", 10)
  87. if not query_text:
  88. return jsonify({"error": "error query_text"})
  89. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  90. resource = get_resource_manager()
  91. search_engine = HybridSearch(
  92. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  93. )
  94. try:
  95. match search_type:
  96. case "base":
  97. response = await search_engine.base_vector_search(
  98. query_vec=query_vector,
  99. anns_field=anns_field,
  100. search_params=search_params,
  101. limit=limit,
  102. )
  103. return jsonify(response), 200
  104. case "hybrid":
  105. response = await search_engine.hybrid_search(
  106. filters=filters,
  107. query_vec=query_vector,
  108. anns_field=anns_field,
  109. search_params=search_params,
  110. es_size=es_size,
  111. sort_by=sort_by,
  112. milvus_size=milvus_size,
  113. )
  114. return jsonify(response), 200
  115. case "strategy":
  116. return jsonify({"error": "strategy not implemented"}), 405
  117. case _:
  118. return jsonify({"error": "error search_type"}), 200
  119. except Exception as e:
  120. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  121. @server_bp.route("/dataset/list", methods=["GET"])
  122. async def dataset_list():
  123. resource = get_resource_manager()
  124. datasets = await Dataset(resource.mysql_client).select_dataset()
  125. # 创建所有任务
  126. tasks = [
  127. Contents(resource.mysql_client).select_count(dataset["id"])
  128. for dataset in datasets
  129. ]
  130. counts = await asyncio.gather(*tasks)
  131. # 组装数据
  132. data_list = [
  133. {
  134. "dataset_id": dataset["id"],
  135. "name": dataset["name"],
  136. "count": count,
  137. "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
  138. }
  139. for dataset, count in zip(datasets, counts)
  140. ]
  141. return jsonify({"status_code": 200, "detail": "success", "data": data_list})
  142. @server_bp.route("/dataset/add", methods=["POST"])
  143. async def add_dataset():
  144. resource = get_resource_manager()
  145. dataset = Dataset(resource.mysql_client)
  146. # 从请求体里取参数
  147. body = await request.get_json()
  148. name = body.get("name")
  149. if not name:
  150. return jsonify({"status_code": 400, "detail": "name is required"})
  151. # 执行新增
  152. await dataset.add_dataset(name)
  153. return jsonify({"status_code": 200, "detail": "success"})
  154. @server_bp.route("/content/get", methods=["GET"])
  155. async def get_content():
  156. resource = get_resource_manager()
  157. contents = Contents(resource.mysql_client)
  158. # 获取请求参数
  159. doc_id = request.args.get("docId")
  160. if not doc_id:
  161. return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
  162. # 查询内容
  163. rows = await contents.select_content_by_doc_id(doc_id)
  164. if not rows:
  165. return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
  166. row = rows[0]
  167. return jsonify(
  168. {
  169. "status_code": 200,
  170. "detail": "success",
  171. "data": {
  172. "title": row.get("title", ""),
  173. "text": row.get("text", ""),
  174. "doc_id": row.get("doc_id", ""),
  175. },
  176. }
  177. )
  178. @server_bp.route("/content/list", methods=["GET"])
  179. async def content_list():
  180. resource = get_resource_manager()
  181. contents = Contents(resource.mysql_client)
  182. # 从 URL 查询参数获取分页和过滤参数
  183. page_num = int(request.args.get("page", 1))
  184. page_size = int(request.args.get("pageSize", 10))
  185. dataset_id = request.args.get("datasetId")
  186. doc_status = int(request.args.get("doc_status", 1))
  187. # order_by 可以用 JSON 字符串传递
  188. import json
  189. order_by_str = request.args.get("order_by", '{"id":"desc"}')
  190. try:
  191. order_by = json.loads(order_by_str)
  192. except Exception:
  193. order_by = {"id": "desc"}
  194. # 调用 select_contents,获取分页字典
  195. result = await contents.select_contents(
  196. page_num=page_num,
  197. page_size=page_size,
  198. dataset_id=dataset_id,
  199. doc_status=doc_status,
  200. order_by=order_by,
  201. )
  202. # 格式化 entities,只保留必要字段
  203. entities = [
  204. {
  205. "doc_id": row["doc_id"],
  206. "title": row.get("title") or "",
  207. "text": row.get("text") or "",
  208. }
  209. for row in result["entities"]
  210. ]
  211. return jsonify(
  212. {
  213. "status_code": 200,
  214. "detail": "success",
  215. "data": {
  216. "entities": entities,
  217. "total_count": result["total_count"],
  218. "page": result["page"],
  219. "page_size": result["page_size"],
  220. "total_pages": result["total_pages"],
  221. },
  222. }
  223. )
  224. async def query_search(
  225. query_text,
  226. filters=None,
  227. search_type="",
  228. anns_field="vector_text",
  229. search_params=BASE_MILVUS_SEARCH_PARAMS,
  230. _source=False,
  231. es_size=10000,
  232. sort_by=None,
  233. milvus_size=20,
  234. limit=10,
  235. ):
  236. if filters is None:
  237. filters = {}
  238. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  239. resource = get_resource_manager()
  240. search_engine = HybridSearch(
  241. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  242. )
  243. try:
  244. match search_type:
  245. case "base":
  246. response = await search_engine.base_vector_search(
  247. query_vec=query_vector,
  248. anns_field=anns_field,
  249. search_params=search_params,
  250. limit=limit,
  251. )
  252. return response
  253. case "hybrid":
  254. response = await search_engine.hybrid_search(
  255. filters=filters,
  256. query_vec=query_vector,
  257. anns_field=anns_field,
  258. search_params=search_params,
  259. es_size=es_size,
  260. sort_by=sort_by,
  261. milvus_size=milvus_size,
  262. )
  263. return response
  264. case "strategy":
  265. return None
  266. case _:
  267. return None
  268. except Exception as e:
  269. return None
  270. @server_bp.route("/query", methods=["GET"])
  271. async def query():
  272. query_text = request.args.get("query")
  273. dataset_ids = request.args.get("datasetIds").split(",")
  274. search_type = request.args.get("search_type", "hybrid")
  275. query_results = await query_search(
  276. query_text=query_text,
  277. filters={"dataset_id": dataset_ids},
  278. search_type=search_type,
  279. )
  280. resource = get_resource_manager()
  281. content_chunk_mapper = ContentChunks(resource.mysql_client)
  282. dataset_mapper = Dataset(resource.mysql_client)
  283. res = []
  284. for result in query_results["results"]:
  285. content_chunks = await content_chunk_mapper.select_chunk_content(
  286. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  287. )
  288. if not content_chunks:
  289. return jsonify(
  290. {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  291. )
  292. content_chunk = content_chunks[0]
  293. datasets = await dataset_mapper.select_dataset_by_id(
  294. content_chunk["dataset_id"]
  295. )
  296. if not datasets:
  297. return jsonify(
  298. {"status_code": 500, "detail": "dataset not found", "data": {}}
  299. )
  300. dataset = datasets[0]
  301. dataset_name = None
  302. if dataset:
  303. dataset_name = dataset["name"]
  304. res.append(
  305. {
  306. "docId": content_chunk["doc_id"],
  307. "content": content_chunk["text"],
  308. "contentSummary": content_chunk["summary"],
  309. "score": result["score"],
  310. "datasetName": dataset_name,
  311. }
  312. )
  313. data = {"results": res}
  314. return jsonify({"status_code": 200, "detail": "success", "data": data})
  315. @server_bp.route("/chat", methods=["GET"])
  316. async def chat():
  317. query_text = request.args.get("query")
  318. dataset_ids = request.args.get("datasetIds").split(",")
  319. search_type = request.args.get("search_type", "hybrid")
  320. query_results = await query_search(
  321. query_text=query_text,
  322. filters={"dataset_id": dataset_ids},
  323. search_type=search_type,
  324. )
  325. resource = get_resource_manager()
  326. content_chunk_mapper = ContentChunks(resource.mysql_client)
  327. dataset_mapper = Dataset(resource.mysql_client)
  328. res = []
  329. for result in query_results["results"]:
  330. content_chunks = await content_chunk_mapper.select_chunk_content(
  331. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  332. )
  333. if not content_chunks:
  334. return jsonify(
  335. {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  336. )
  337. content_chunk = content_chunks[0]
  338. datasets = await dataset_mapper.select_dataset_by_id(
  339. content_chunk["dataset_id"]
  340. )
  341. if not datasets:
  342. return jsonify(
  343. {"status_code": 500, "detail": "dataset not found", "data": {}}
  344. )
  345. dataset = datasets[0]
  346. dataset_name = None
  347. if dataset:
  348. dataset_name = dataset["name"]
  349. res.append(
  350. {
  351. "docId": content_chunk["doc_id"],
  352. "content": content_chunk["text"],
  353. "contentSummary": content_chunk["summary"],
  354. "score": result["score"],
  355. "datasetName": dataset_name,
  356. }
  357. )
  358. chat_classifier = ChatClassifier()
  359. chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
  360. data = {"results": res, "chat_res": chat_res}
  361. return jsonify({"status_code": 200, "detail": "success", "data": data})