buleprint.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
import logging
import traceback
import uuid

from quart import Blueprint, jsonify, request

from applications.api import get_basic_embedding
from applications.async_task import ChunkEmbeddingTask
from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    WEIGHT_MAP,
    ChunkerConfig,
)
from applications.utils.milvus import MilvusSearch
  13. server_bp = Blueprint("api", __name__, url_prefix="/api")
  14. def server_routes(mysql_db, vector_db):
  15. @server_bp.route("/embed", methods=["POST"])
  16. async def embed():
  17. body = await request.get_json()
  18. text = body.get("text")
  19. model_name = body.get("model", DEFAULT_MODEL)
  20. if not LOCAL_MODEL_CONFIG.get(model_name):
  21. return jsonify({"error": "error model"})
  22. embedding = await get_basic_embedding(text, model_name)
  23. return jsonify({"embedding": embedding})
  24. @server_bp.route("/chunk", methods=["POST"])
  25. async def chunk():
  26. body = await request.get_json()
  27. text = body.get("text", "")
  28. text = text.strip()
  29. if not text:
  30. return jsonify({"error": "error text"})
  31. doc_id = f"doc-{uuid.uuid4()}"
  32. chunk_task = ChunkEmbeddingTask(
  33. mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id
  34. )
  35. doc_id = await chunk_task.deal(body)
  36. return jsonify({"doc_id": doc_id})
  37. @server_bp.route("/search", methods=["POST"])
  38. async def search():
  39. body = await request.get_json()
  40. search_type = body.get("search_type")
  41. if not search_type:
  42. return jsonify({"error": "missing search_type"}), 400
  43. searcher = MilvusSearch(vector_db)
  44. try:
  45. # 统一参数
  46. expr = body.get("expr")
  47. search_params = body.get("search_params") or {
  48. "metric_type": "COSINE",
  49. "params": {"ef": 64},
  50. }
  51. limit = body.get("limit", 50)
  52. query = body.get("query")
  53. async def by_vector():
  54. if not query:
  55. return {"error": "missing query"}
  56. field = body.get("field", "vector_text")
  57. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  58. return await searcher.vector_search(
  59. query_vec=query_vec,
  60. anns_field=field,
  61. expr=expr,
  62. search_params=search_params,
  63. limit=limit,
  64. )
  65. async def hybrid():
  66. if not query:
  67. return {"error": "missing query"}
  68. field = body.get("field", "vector_text")
  69. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  70. return await searcher.hybrid_search(
  71. query_vec=query_vec,
  72. anns_field=field,
  73. filters=body.get("filter_map"),
  74. limit=limit,
  75. )
  76. async def strategy():
  77. if not query:
  78. return {"error": "missing query"}
  79. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  80. return await searcher.search_by_strategy(
  81. query_vec=query_vec,
  82. weight_map=body.get("weight_map", WEIGHT_MAP),
  83. expr=expr,
  84. limit=limit,
  85. )
  86. # dispatch table
  87. handlers = {
  88. "by_vector": by_vector,
  89. "hybrid": hybrid,
  90. "strategy": strategy,
  91. }
  92. if search_type not in handlers:
  93. return jsonify({"error": "invalid search_type"}), 400
  94. result = await handlers[search_type]()
  95. return jsonify(result)
  96. except Exception as e:
  97. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  98. return server_bp