# buleprint.py -- Quart blueprint exposing the /api/embed, /api/chunk and /api/search routes.
  1. import traceback
  2. import uuid
  3. from quart import Blueprint, jsonify, request
  4. from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig, WEIGHT_MAP
  5. from applications.api import get_basic_embedding
  6. from applications.async_task import ChunkEmbeddingTask
  7. from applications.utils.milvus import MilvusSearch
  8. server_bp = Blueprint("api", __name__, url_prefix="/api")
  9. def server_routes(mysql_db, vector_db):
  10. @server_bp.route("/embed", methods=["POST"])
  11. async def embed():
  12. body = await request.get_json()
  13. text = body.get("text")
  14. model_name = body.get("model", DEFAULT_MODEL)
  15. if not LOCAL_MODEL_CONFIG.get(model_name):
  16. return jsonify({"error": "error model"})
  17. embedding = await get_basic_embedding(text, model_name)
  18. return jsonify({"embedding": embedding})
  19. @server_bp.route("/chunk", methods=["POST"])
  20. async def chunk():
  21. body = await request.get_json()
  22. text = body.get("text", "")
  23. text = text.strip()
  24. if not text:
  25. return jsonify({"error": "error text"})
  26. doc_id = f"doc-{uuid.uuid4()}"
  27. chunk_task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id)
  28. doc_id = await chunk_task.deal(body)
  29. return jsonify({"doc_id": doc_id})
  30. @server_bp.route("/search", methods=["POST"])
  31. async def search():
  32. body = await request.get_json()
  33. search_type = body.get("search_type")
  34. if not search_type:
  35. return jsonify({"error": "missing search_type"}), 400
  36. searcher = MilvusSearch(vector_db)
  37. try:
  38. # 统一参数
  39. expr = body.get("expr")
  40. search_params = body.get("search_params") or {"metric_type": "COSINE", "params": {"ef": 64}}
  41. limit = body.get("limit", 50)
  42. query = body.get("query")
  43. async def by_vector():
  44. if not query:
  45. return {"error": "missing query"}
  46. field = body.get("field", "vector_text")
  47. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  48. return await searcher.vector_search(
  49. query_vec=query_vec,
  50. anns_field=field,
  51. expr=expr,
  52. search_params=search_params,
  53. limit=limit,
  54. )
  55. async def hybrid():
  56. if not query:
  57. return {"error": "missing query"}
  58. field = body.get("field", "vector_text")
  59. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  60. return await searcher.hybrid_search(
  61. query_vec=query_vec,
  62. anns_field=field,
  63. filters=body.get("filter_map"),
  64. limit=limit,
  65. )
  66. async def strategy():
  67. if not query:
  68. return {"error": "missing query"}
  69. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  70. return await searcher.search_by_strategy(
  71. query_vec=query_vec,
  72. weight_map=body.get("weight_map", WEIGHT_MAP),
  73. expr=expr,
  74. limit=limit,
  75. )
  76. # dispatch table
  77. handlers = {
  78. "by_vector": by_vector,
  79. "hybrid": hybrid,
  80. "strategy": strategy,
  81. }
  82. if search_type not in handlers:
  83. return jsonify({"error": "invalid search_type"}), 400
  84. result = await handlers[search_type]()
  85. return jsonify(result)
  86. except Exception as e:
  87. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  88. return server_bp