buleprint.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import traceback
  2. import uuid
  3. from typing import Dict, Any
  4. from quart import Blueprint, jsonify, request
  5. from applications.config import (
  6. DEFAULT_MODEL,
  7. LOCAL_MODEL_CONFIG,
  8. ChunkerConfig,
  9. BASE_MILVUS_SEARCH_PARAMS,
  10. )
  11. from applications.resource import get_resource_manager
  12. from applications.api import get_basic_embedding
  13. from applications.api import get_img_embedding
  14. from applications.async_task import ChunkEmbeddingTask
  15. from applications.search import HybridSearch
  16. server_bp = Blueprint("api", __name__, url_prefix="/api")
  17. @server_bp.route("/embed", methods=["POST"])
  18. async def embed():
  19. body = await request.get_json()
  20. text = body.get("text")
  21. model_name = body.get("model", DEFAULT_MODEL)
  22. if not LOCAL_MODEL_CONFIG.get(model_name):
  23. return jsonify({"error": "error model"})
  24. embedding = await get_basic_embedding(text, model_name)
  25. return jsonify({"embedding": embedding})
  26. @server_bp.route("/img_embed", methods=["POST"])
  27. async def img_embed():
  28. body = await request.get_json()
  29. url_list = body.get("url_list")
  30. if not url_list:
  31. return jsonify({"error": "error url_list"})
  32. embedding = await get_img_embedding(url_list)
  33. return jsonify(embedding)
  34. @server_bp.route("/chunk", methods=["POST"])
  35. async def chunk():
  36. body = await request.get_json()
  37. text = body.get("text", "")
  38. text = text.strip()
  39. if not text:
  40. return jsonify({"error": "error text"})
  41. resource = get_resource_manager()
  42. doc_id = f"doc-{uuid.uuid4()}"
  43. chunk_task = ChunkEmbeddingTask(
  44. resource.mysql_client,
  45. resource.milvus_client,
  46. cfg=ChunkerConfig(),
  47. doc_id=doc_id,
  48. es_pool=resource.es_client,
  49. )
  50. doc_id = await chunk_task.deal(body)
  51. return jsonify({"doc_id": doc_id})
  52. @server_bp.route("/search", methods=["POST"])
  53. async def search():
  54. """
  55. filters: Dict[str, Any], # 条件过滤
  56. query_vec: List[float], # query 的向量
  57. anns_field: str = "vector_text", # query指定的向量空间
  58. search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
  59. query_text: str = None, #是否通过 topic 倒排
  60. _source=False, # 是否返回元数据
  61. es_size: int = 10000, #es 第一层过滤数量
  62. sort_by: str = None, # 排序
  63. milvus_size: int = 10 # milvus粗排返回数量
  64. :return:
  65. """
  66. body = await request.get_json()
  67. # 解析数据
  68. search_type: str = body.get("search_type")
  69. filters: Dict[str, Any] = body.get("filters", {})
  70. anns_field: str = body.get("anns_field", "vector_text")
  71. search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
  72. query_text: str = body.get("query_text")
  73. _source: bool = body.get("_source", False)
  74. es_size: int = body.get("es_size", 10000)
  75. sort_by: str = body.get("sort_by")
  76. milvus_size: int = body.get("milvus", 20)
  77. limit: int = body.get("limit", 10)
  78. if not query_text:
  79. return jsonify({"error": "error query_text"})
  80. query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
  81. resource = get_resource_manager()
  82. search_engine = HybridSearch(
  83. milvus_pool=resource.milvus_client, es_pool=resource.es_client
  84. )
  85. try:
  86. match search_type:
  87. case "base":
  88. response = await search_engine.base_vector_search(
  89. query_vec=query_vector,
  90. anns_field=anns_field,
  91. search_params=search_params,
  92. limit=limit,
  93. )
  94. return jsonify(response), 200
  95. case "hybrid":
  96. response = await search_engine.hybrid_search(
  97. filters=filters,
  98. query_vec=query_vector,
  99. anns_field=anns_field,
  100. search_params=search_params,
  101. es_size=es_size,
  102. sort_by=sort_by,
  103. milvus_size=milvus_size,
  104. )
  105. return jsonify(response), 200
  106. case "strategy":
  107. return jsonify({"error": "strategy not implemented"}), 405
  108. case _:
  109. return jsonify({"error": "error search_type"}), 200
  110. except Exception as e:
  111. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500