buleprint.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
import uuid

from quart import Blueprint, current_app, jsonify, request

from applications.api import get_basic_embedding
from applications.async_task import ChunkEmbeddingTask
from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig, WEIGHT_MAP
from applications.utils.milvus import MilvusSearcher
  7. server_bp = Blueprint("api", __name__, url_prefix="/api")
  8. def server_routes(mysql_db, vector_db):
  9. @server_bp.route("/embed", methods=["POST"])
  10. async def embed():
  11. body = await request.get_json()
  12. text = body.get("text")
  13. model_name = body.get("model", DEFAULT_MODEL)
  14. if not LOCAL_MODEL_CONFIG.get(model_name):
  15. return jsonify({"error": "error model"})
  16. embedding = await get_basic_embedding(text, model_name)
  17. return jsonify({"embedding": embedding})
  18. @server_bp.route("/chunk", methods=["POST"])
  19. async def chunk():
  20. body = await request.get_json()
  21. text = body.get("text", "")
  22. text = text.strip()
  23. if not text:
  24. return jsonify({"error": "error text"})
  25. doc_id = f"doc-{uuid.uuid4()}"
  26. chunk_task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id)
  27. doc_id = await chunk_task.deal(body)
  28. return jsonify({"doc_id": doc_id})
  29. @server_bp.route("/search", methods=["POST"])
  30. async def search():
  31. body = await request.get_json()
  32. search_type = body.get("search_type")
  33. if not search_type:
  34. return jsonify({"error": "missing search_type"}), 400
  35. searcher = MilvusSearcher(vector_db)
  36. try:
  37. # 统一参数
  38. expr = body.get("expr")
  39. search_params = body.get("search_params") or {"metric_type": "COSINE", "params": {"ef": 64}}
  40. limit = body.get("limit", 5)
  41. query = body.get("query")
  42. # 定义不同搜索策略
  43. async def by_pk_id():
  44. pk_id = body.get("id")
  45. if not pk_id:
  46. return {"error": "missing id"}
  47. return await searcher.get_by_id(pk_id)
  48. async def by_doc_id():
  49. doc_id, chunk_id = body.get("doc_id"), body.get("chunk_id")
  50. if not doc_id or chunk_id is None:
  51. return {"error": "missing doc_id or chunk_id"}
  52. return await searcher.get_by_doc_and_chunk(doc_id, chunk_id)
  53. async def by_vector():
  54. if not query:
  55. return {"error": "missing query"}
  56. field = body.get("field", "vector_text")
  57. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  58. return await searcher.vector_search(
  59. query_vec=query_vec,
  60. anns_field=field,
  61. expr=expr,
  62. search_params=search_params,
  63. limit=limit,
  64. )
  65. async def by_filter():
  66. filter_map = body.get("filter_map")
  67. if not filter_map:
  68. return {"error": "missing filter_map"}
  69. return await searcher.filter_search(filter_map)
  70. async def hybrid():
  71. if not query:
  72. return {"error": "missing query"}
  73. field = body.get("field", "vector_text")
  74. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  75. return await searcher.hybrid_search(
  76. query_vec=query_vec,
  77. anns_field=field,
  78. filters=body.get("filter_map"),
  79. limit=limit,
  80. )
  81. async def strategy():
  82. if not query:
  83. return {"error": "missing query"}
  84. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  85. return await searcher.search_by_strategy(
  86. query_vec=query_vec,
  87. weight_map=body.get("weight_map", WEIGHT_MAP),
  88. expr=expr,
  89. limit=limit,
  90. )
  91. # dispatch table
  92. handlers = {
  93. "pk_id": by_pk_id,
  94. "by_doc_id": by_doc_id,
  95. "by_vector": by_vector,
  96. "by_filter": by_filter,
  97. "hybrid": hybrid,
  98. "strategy": strategy,
  99. }
  100. if search_type not in handlers:
  101. return jsonify({"error": "invalid search_type"}), 400
  102. result = await handlers[search_type]()
  103. return jsonify(result)
  104. except Exception as e:
  105. return jsonify({"error": str(e)}), 500
  106. return server_bp