# blueprint.py — Quart API blueprint (embedding / chunking / search / dataset endpoints)
  1. import asyncio
  2. import traceback
  3. import uuid
  4. from typing import Dict, Any
  5. from quart import Blueprint, jsonify, request
  6. from quart_cors import cors
  7. from applications.config import (
  8. DEFAULT_MODEL,
  9. LOCAL_MODEL_CONFIG,
  10. ChunkerConfig,
  11. BASE_MILVUS_SEARCH_PARAMS,
  12. )
  13. from applications.resource import get_resource_manager
  14. from applications.api import get_basic_embedding
  15. from applications.api import get_img_embedding
  16. from applications.async_task import ChunkEmbeddingTask, DeleteTask
  17. from applications.search import HybridSearch
  18. from applications.utils.chat import ChatClassifier
  19. from applications.utils.mysql import Dataset, Contents, ContentChunks
  20. server_bp = Blueprint("api", __name__, url_prefix="/api")
  21. server_bp = cors(server_bp, allow_origin="*")
  22. @server_bp.route("/embed", methods=["POST"])
  23. async def embed():
  24. body = await request.get_json()
  25. text = body.get("text")
  26. model_name = body.get("model", DEFAULT_MODEL)
  27. if not LOCAL_MODEL_CONFIG.get(model_name):
  28. return jsonify({"error": "error model"})
  29. embedding = await get_basic_embedding(text, model_name)
  30. return jsonify({"embedding": embedding})
  31. @server_bp.route("/img_embed", methods=["POST"])
  32. async def img_embed():
  33. body = await request.get_json()
  34. url_list = body.get("url_list")
  35. if not url_list:
  36. return jsonify({"error": "error url_list"})
  37. embedding = await get_img_embedding(url_list)
  38. return jsonify(embedding)
  39. @server_bp.route("/delete", methods=["POST"])
  40. async def delete():
  41. body = await request.get_json()
  42. level = body.get("level")
  43. params = body.get("params")
  44. if not level or not params:
  45. return jsonify({"error": "error level or params"})
  46. resource = get_resource_manager()
  47. delete_task = DeleteTask(resource)
  48. response = await delete_task.deal(level, params)
  49. return jsonify(response)
  50. @server_bp.route("/chunk", methods=["POST"])
  51. async def chunk():
  52. body = await request.get_json()
  53. text = body.get("text", "")
  54. text = text.strip()
  55. if not text:
  56. return jsonify({"error": "error text"})
  57. resource = get_resource_manager()
  58. doc_id = f"doc-{uuid.uuid4()}"
  59. chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
  60. doc_id = await chunk_task.deal(body)
  61. return jsonify({"doc_id": doc_id})
  62. @server_bp.route("/search", methods=["POST"])
  63. async def search():
  64. """
  65. filters: Dict[str, Any], # 条件过滤
  66. query_vec: List[float], # query 的向量
  67. anns_field: str = "vector_text", # query指定的向量空间
  68. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  69. query_text: str = None, #是否通过 topic 倒排
  70. _source=False, # 是否返回元数据
  71. es_size: int = 10000, #es 第一层过滤数量
  72. sort_by: str = None, # 排序
  73. milvus_size: int = 10 # milvus粗排返回数量
  74. :return:
  75. """
  76. body = await request.get_json()
  77. # 解析数据
  78. search_type: str = body.get("search_type")
  79. filters: Dict[str, Any] = body.get("filters", {})
  80. anns_field: str = body.get("anns_field", "vector_text")
  81. search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
  82. query_text: str = body.get("query_text")
  83. _source: bool = body.get("_source", False)
  84. es_size: int = body.get("es_size", 10000)
  85. sort_by: str = body.get("sort_by")
  86. milvus_size: int = body.get("milvus", 20)
  87. limit: int = body.get("limit", 10)
  88. if not query_text:
  89. return jsonify({"error": "error query_text"})
  90. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  91. resource = get_resource_manager()
  92. search_engine = HybridSearch(
  93. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  94. )
  95. try:
  96. match search_type:
  97. case "base":
  98. response = await search_engine.base_vector_search(
  99. query_vec=query_vector,
  100. anns_field=anns_field,
  101. search_params=search_params,
  102. limit=limit,
  103. )
  104. return jsonify(response), 200
  105. case "hybrid":
  106. response = await search_engine.hybrid_search(
  107. filters=filters,
  108. query_vec=query_vector,
  109. anns_field=anns_field,
  110. search_params=search_params,
  111. es_size=es_size,
  112. sort_by=sort_by,
  113. milvus_size=milvus_size,
  114. )
  115. return jsonify(response), 200
  116. case "strategy":
  117. return jsonify({"error": "strategy not implemented"}), 405
  118. case _:
  119. return jsonify({"error": "error search_type"}), 200
  120. except Exception as e:
  121. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  122. @server_bp.route("/dataset/list", methods=["GET"])
  123. async def dataset_list():
  124. resource = get_resource_manager()
  125. datasets = await Dataset(resource.mysql_client).select_dataset()
  126. # 创建所有任务
  127. tasks = [
  128. Contents(resource.mysql_client).select_count(dataset["id"])
  129. for dataset in datasets
  130. ]
  131. counts = await asyncio.gather(*tasks)
  132. # 组装数据
  133. data_list = [
  134. {
  135. "dataset_id": dataset["id"],
  136. "name": dataset["name"],
  137. "count": count,
  138. "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
  139. }
  140. for dataset, count in zip(datasets, counts)
  141. ]
  142. return jsonify({
  143. "status_code": 200,
  144. "detail": "success",
  145. "data": data_list
  146. })
  147. @server_bp.route("/dataset/add", methods=["POST"])
  148. async def add_dataset():
  149. resource = get_resource_manager()
  150. dataset = Dataset(resource.mysql_client)
  151. # 从请求体里取参数
  152. body = await request.get_json()
  153. name = body.get("name")
  154. if not name:
  155. return jsonify({
  156. "status_code": 400,
  157. "detail": "name is required"
  158. })
  159. # 执行新增
  160. await dataset.add_dataset(name)
  161. return jsonify({
  162. "status_code": 200,
  163. "detail": "success"
  164. })
  165. @server_bp.route("/content/get", methods=["GET"])
  166. async def get_content():
  167. resource = get_resource_manager()
  168. contents = Contents(resource.mysql_client)
  169. # 获取请求参数
  170. doc_id = request.args.get("docId")
  171. if not doc_id:
  172. return jsonify({
  173. "status_code": 400,
  174. "detail": "doc_id is required",
  175. "data": {}
  176. })
  177. # 查询内容
  178. rows = await contents.select_content_by_doc_id(doc_id)
  179. if not rows:
  180. return jsonify({
  181. "status_code": 404,
  182. "detail": "content not found",
  183. "data": {}
  184. })
  185. row = rows[0]
  186. return jsonify({
  187. "status_code": 200,
  188. "detail": "success",
  189. "data": {
  190. "title": row.get("title", ""),
  191. "text": row.get("text", ""),
  192. "doc_id": row.get("doc_id", "")
  193. }
  194. })
  195. @server_bp.route("/content/list", methods=["GET"])
  196. async def content_list():
  197. resource = get_resource_manager()
  198. contents = Contents(resource.mysql_client)
  199. # 从 URL 查询参数获取分页和过滤参数
  200. page_num = int(request.args.get("page", 1))
  201. page_size = int(request.args.get("pageSize", 10))
  202. dataset_id = request.args.get("datasetId")
  203. doc_status = int(request.args.get("doc_status", 1))
  204. # order_by 可以用 JSON 字符串传递
  205. import json
  206. order_by_str = request.args.get("order_by", '{"id":"desc"}')
  207. try:
  208. order_by = json.loads(order_by_str)
  209. except Exception:
  210. order_by = {"id": "desc"}
  211. # 调用 select_contents,获取分页字典
  212. result = await contents.select_contents(
  213. page_num=page_num,
  214. page_size=page_size,
  215. dataset_id=dataset_id,
  216. doc_status=doc_status,
  217. order_by=order_by,
  218. )
  219. # 格式化 entities,只保留必要字段
  220. entities = [
  221. {
  222. "doc_id": row["doc_id"],
  223. "title": row.get("title") or "",
  224. "text": row.get("text") or "",
  225. }
  226. for row in result["entities"]
  227. ]
  228. return jsonify({
  229. "status_code": 200,
  230. "detail": "success",
  231. "data": {
  232. "entities": entities,
  233. "total_count": result["total_count"],
  234. "page": result["page"],
  235. "page_size": result["page_size"],
  236. "total_pages": result["total_pages"]
  237. }
  238. })
  239. async def query_search(query_text, filters=None, search_type='', anns_field='vector_text',
  240. search_params=BASE_MILVUS_SEARCH_PARAMS, _source=False, es_size=10000, sort_by=None,
  241. milvus_size=20, limit=10):
  242. if filters is None:
  243. filters = {}
  244. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  245. resource = get_resource_manager()
  246. search_engine = HybridSearch(
  247. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  248. )
  249. try:
  250. match search_type:
  251. case "base":
  252. response = await search_engine.base_vector_search(
  253. query_vec=query_vector,
  254. anns_field=anns_field,
  255. search_params=search_params,
  256. limit=limit,
  257. )
  258. return response
  259. case "hybrid":
  260. response = await search_engine.hybrid_search(
  261. filters=filters,
  262. query_vec=query_vector,
  263. anns_field=anns_field,
  264. search_params=search_params,
  265. es_size=es_size,
  266. sort_by=sort_by,
  267. milvus_size=milvus_size,
  268. )
  269. return response
  270. case "strategy":
  271. return None
  272. case _:
  273. return None
  274. except Exception as e:
  275. return None
  276. @server_bp.route("/query", methods=["GET"])
  277. async def query():
  278. query_text = request.args.get("query")
  279. dataset_ids = request.args.get("datasetIds").split(",")
  280. search_type = request.args.get("search_type", "hybrid")
  281. query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
  282. search_type=search_type)
  283. resource = get_resource_manager()
  284. content_chunk_mapper = ContentChunks(resource.mysql_client)
  285. dataset_mapper = Dataset(resource.mysql_client)
  286. res = []
  287. for result in query_results['results']:
  288. content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
  289. chunk_id=result['chunk_id'])
  290. if not content_chunks:
  291. return jsonify({
  292. "status_code": 500,
  293. "detail": "content_chunk not found",
  294. "data": {}
  295. })
  296. content_chunk = content_chunks[0]
  297. datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
  298. if not datasets:
  299. return jsonify({
  300. "status_code": 500,
  301. "detail": "dataset not found",
  302. "data": {}
  303. })
  304. dataset = datasets[0]
  305. dataset_name = None
  306. if dataset:
  307. dataset_name = dataset['name']
  308. res.append(
  309. {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
  310. 'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
  311. data = {'results': res}
  312. return jsonify({'status_code': 200,
  313. 'detail': "success",
  314. 'data': data})
  315. @server_bp.route("/chat", methods=["GET"])
  316. async def chat():
  317. query_text = request.args.get("query")
  318. dataset_ids = request.args.get("datasetIds").split(",")
  319. search_type = request.args.get("search_type", "hybrid")
  320. query_results = await query_search(query_text=query_text, filters={"dataset_id": dataset_ids},
  321. search_type=search_type)
  322. resource = get_resource_manager()
  323. content_chunk_mapper = ContentChunks(resource.mysql_client)
  324. dataset_mapper = Dataset(resource.mysql_client)
  325. res = []
  326. for result in query_results['results']:
  327. content_chunks = await content_chunk_mapper.select_chunk_content(doc_id=result['doc_id'],
  328. chunk_id=result['chunk_id'])
  329. if not content_chunks:
  330. return jsonify({
  331. "status_code": 500,
  332. "detail": "content_chunk not found",
  333. "data": {}
  334. })
  335. content_chunk = content_chunks[0]
  336. datasets = await dataset_mapper.select_dataset_by_id(content_chunk['dataset_id'])
  337. if not datasets:
  338. return jsonify({
  339. "status_code": 500,
  340. "detail": "dataset not found",
  341. "data": {}
  342. })
  343. dataset = datasets[0]
  344. dataset_name = None
  345. if dataset:
  346. dataset_name = dataset['name']
  347. res.append(
  348. {'docId': content_chunk['doc_id'], 'content': content_chunk['text'],
  349. 'contentSummary': content_chunk['summary'], 'score': result['score'], 'datasetName': dataset_name})
  350. chat_classifier = ChatClassifier()
  351. chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
  352. data = {'results': res, 'chat_res': chat_res}
  353. return jsonify({'status_code': 200,
  354. 'detail': "success",
  355. 'data': data})