buleprint.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import traceback
  2. import uuid
  3. from quart import Blueprint, jsonify, request
  4. from applications.config import (
  5. DEFAULT_MODEL,
  6. LOCAL_MODEL_CONFIG,
  7. ChunkerConfig,
  8. WEIGHT_MAP,
  9. )
  10. from applications.api import get_basic_embedding
  11. from applications.api import get_img_embedding
  12. from applications.async_task import ChunkEmbeddingTask
  13. from applications.utils.milvus import MilvusSearch
  14. server_bp = Blueprint("api", __name__, url_prefix="/api")
  15. def server_routes(mysql_db, vector_db):
  16. @server_bp.route("/embed", methods=["POST"])
  17. async def embed():
  18. body = await request.get_json()
  19. text = body.get("text")
  20. model_name = body.get("model", DEFAULT_MODEL)
  21. if not LOCAL_MODEL_CONFIG.get(model_name):
  22. return jsonify({"error": "error model"})
  23. embedding = await get_basic_embedding(text, model_name)
  24. return jsonify({"embedding": embedding})
  25. @server_bp.route("/img_embed", methods=["POST"])
  26. async def img_embed():
  27. body = await request.get_json()
  28. url_list = body.get("url_list")
  29. if not url_list:
  30. return jsonify({"error": "error url_list"})
  31. embedding = await get_img_embedding(url_list)
  32. return jsonify(embedding)
  33. @server_bp.route("/chunk", methods=["POST"])
  34. async def chunk():
  35. body = await request.get_json()
  36. text = body.get("text", "")
  37. text = text.strip()
  38. if not text:
  39. return jsonify({"error": "error text"})
  40. doc_id = f"doc-{uuid.uuid4()}"
  41. chunk_task = ChunkEmbeddingTask(
  42. mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id
  43. )
  44. doc_id = await chunk_task.deal(body)
  45. return jsonify({"doc_id": doc_id})
  46. @server_bp.route("/search", methods=["POST"])
  47. async def search():
  48. body = await request.get_json()
  49. search_type = body.get("search_type")
  50. if not search_type:
  51. return jsonify({"error": "missing search_type"}), 400
  52. searcher = MilvusSearch(vector_db)
  53. try:
  54. # 统一参数
  55. expr = body.get("expr")
  56. search_params = body.get("search_params") or {
  57. "metric_type": "COSINE",
  58. "params": {"ef": 64},
  59. }
  60. limit = body.get("limit", 50)
  61. query = body.get("query")
  62. async def by_vector():
  63. if not query:
  64. return {"error": "missing query"}
  65. field = body.get("field", "vector_text")
  66. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  67. return await searcher.vector_search(
  68. query_vec=query_vec,
  69. anns_field=field,
  70. expr=expr,
  71. search_params=search_params,
  72. limit=limit,
  73. )
  74. async def hybrid():
  75. if not query:
  76. return {"error": "missing query"}
  77. field = body.get("field", "vector_text")
  78. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  79. return await searcher.hybrid_search(
  80. query_vec=query_vec,
  81. anns_field=field,
  82. filters=body.get("filter_map"),
  83. limit=limit,
  84. )
  85. async def strategy():
  86. if not query:
  87. return {"error": "missing query"}
  88. query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
  89. return await searcher.search_by_strategy(
  90. query_vec=query_vec,
  91. weight_map=body.get("weight_map", WEIGHT_MAP),
  92. expr=expr,
  93. limit=limit,
  94. )
  95. # dispatch table
  96. handlers = {
  97. "by_vector": by_vector,
  98. "hybrid": hybrid,
  99. "strategy": strategy,
  100. }
  101. if search_type not in handlers:
  102. return jsonify({"error": "invalid search_type"}), 400
  103. result = await handlers[search_type]()
  104. return jsonify(result)
  105. except Exception as e:
  106. return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
  107. return server_bp