# blueprint.py (4.5 KB)
import traceback
import uuid
from typing import Any, Dict, Optional

from quart import Blueprint, jsonify, request

from applications.api import get_basic_embedding
from applications.api import get_img_embedding
from applications.async_task import ChunkEmbeddingTask, DeleteTask
from applications.config import (
    BASE_MILVUS_SEARCH_PARAMS,
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    ChunkerConfig,
)
from applications.resource import get_resource_manager
from applications.search import HybridSearch
  16. server_bp = Blueprint("api", __name__, url_prefix="/api")
  17. @server_bp.route("/embed", methods=["POST"])
  18. async def embed():
  19. body = await request.get_json()
  20. text = body.get("text")
  21. model_name = body.get("model", DEFAULT_MODEL)
  22. if not LOCAL_MODEL_CONFIG.get(model_name):
  23. return jsonify({"error": "error model"})
  24. embedding = await get_basic_embedding(text, model_name)
  25. return jsonify({"embedding": embedding})
  26. @server_bp.route("/img_embed", methods=["POST"])
  27. async def img_embed():
  28. body = await request.get_json()
  29. url_list = body.get("url_list")
  30. if not url_list:
  31. return jsonify({"error": "error url_list"})
  32. embedding = await get_img_embedding(url_list)
  33. return jsonify(embedding)
  34. @server_bp.route("/delete", methods=["POST"])
  35. async def delete():
  36. body = await request.get_json()
  37. level = body.get("level")
  38. params = body.get("params")
  39. if not level or not params:
  40. return jsonify({"error": "error level or params"})
  41. resource = get_resource_manager()
  42. delete_task = DeleteTask(resource)
  43. response = await delete_task.deal(level, params)
  44. return jsonify(response)
  45. @server_bp.route("/chunk", methods=["POST"])
  46. async def chunk():
  47. body = await request.get_json()
  48. text = body.get("text", "")
  49. text = text.strip()
  50. if not text:
  51. return jsonify({"error": "error text"})
  52. resource = get_resource_manager()
  53. doc_id = f"doc-{uuid.uuid4()}"
  54. chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
  55. doc_id = await chunk_task.deal(body)
  56. return jsonify({"doc_id": doc_id})
  57. @server_bp.route("/search", methods=["POST"])
  58. async def search():
  59. """
  60. filters: Dict[str, Any], # 条件过滤
  61. query_vec: List[float], # query 的向量
  62. anns_field: str = "vector_text", # query指定的向量空间
  63. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  64. query_text: str = None, #是否通过 topic 倒排
  65. _source=False, # 是否返回元数据
  66. es_size: int = 10000, #es 第一层过滤数量
  67. sort_by: str = None, # 排序
  68. milvus_size: int = 10 # milvus粗排返回数量
  69. :return:
  70. """
  71. body = await request.get_json()
  72. # 解析数据
  73. search_type: str = body.get("search_type")
  74. filters: Dict[str, Any] = body.get("filters", {})
  75. anns_field: str = body.get("anns_field", "vector_text")
  76. search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
  77. query_text: str = body.get("query_text")
  78. _source: bool = body.get("_source", False)
  79. es_size: int = body.get("es_size", 10000)
  80. sort_by: str = body.get("sort_by")
  81. milvus_size: int = body.get("milvus", 20)
  82. limit: int = body.get("limit", 10)
  83. if not query_text:
  84. return jsonify({"error": "error query_text"})
  85. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  86. resource = get_resource_manager()
  87. search_engine = HybridSearch(
  88. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  89. )
  90. try:
  91. match search_type:
  92. case "base":
  93. response = await search_engine.base_vector_search(
  94. query_vec=query_vector,
  95. anns_field=anns_field,
  96. search_params=search_params,
  97. limit=limit,
  98. )
  99. return jsonify(response), 200
  100. case "hybrid":
  101. response = await search_engine.hybrid_search(
  102. filters=filters,
  103. query_vec=query_vector,
  104. anns_field=anns_field,
  105. search_params=search_params,
  106. es_size=es_size,
  107. sort_by=sort_by,
  108. milvus_size=milvus_size,
  109. )
  110. return jsonify(response), 200
  111. case "strategy":
  112. return jsonify({"error": "strategy not implemented"}), 405
  113. case _:
  114. return jsonify({"error": "error search_type"}), 200
  115. except Exception as e:
  116. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500