Bladeren bron

dataset_id

jihuaqiang 2 weken geleden
bovenliggende
commit
a6a76c5db6
1 gewijzigde bestanden met toevoegingen van 21 en 4 verwijderingen
  1. 21 4
      agents/store_agent/agent.py

+ 21 - 4
agents/store_agent/agent.py

@@ -15,7 +15,6 @@ logger = get_logger('StoreAgent')
 
 
 CHUNK_API_URL = "http://61.48.133.26:8001/api/chunk"
-DATASET_ID = 14
 SCORE_THRESHOLD = 70
 
 
@@ -63,10 +62,27 @@ def _fetch_channel_by_content_id(content_id: str) -> Optional[str]:
     return None
 
 
-def _upload_chunk(text: str, query: str, channel: str = "", max_retries: int = 3, backoff_sec: float = 1.0) -> bool:
+def _resolve_dataset_id(request_id: str) -> int:
+    """根据 knowledge_query.knowledge_type 解析 dataset_id"""
+    try:
+        sql = "SELECT knowledge_type FROM knowledge_query WHERE request_id = %s ORDER BY id DESC LIMIT 1"
+        rows = MysqlHelper.get_values(sql, (request_id,))
+        if rows:
+            knowledge_type = rows[0][0] or ""
+            if knowledge_type == "工具知识":
+                return 12
+            if knowledge_type == "内容知识":
+                return 11
+    except Exception as e:
+        logger.warning(f"解析dataset_id失败,使用默认: requestId={request_id}, error={e}")
+    # 默认兜底
+    return 12
+
+
+def _upload_chunk(text: str, query: str, channel: str = "", dataset_id: int = 12, max_retries: int = 3, backoff_sec: float = 1.0) -> bool:
     # ext 需要是字符串 JSON
     payload = {
-        "dataset_id": DATASET_ID,
+        "dataset_id": dataset_id,
         "title": "",
         "text": text,
         "ext": json.dumps({"query": query, "channel": channel or ""}, ensure_ascii=False),
@@ -106,6 +122,7 @@ def execute_store_agent(request_id: str) -> Tuple[int, int]:
     try:
         query = _fetch_query(request_id)
         data_list = _fetch_extraction_data(request_id)
+        dataset_id = _resolve_dataset_id(request_id)
 
         total = len(data_list)
         success = 0
@@ -130,7 +147,7 @@ def execute_store_agent(request_id: str) -> Tuple[int, int]:
             except Exception as e:
                 logger.warning(f"获取channel失败: parsing_id={parsing_id}, error={e}")
 
-            ok = _upload_chunk(text, query, channel)
+            ok = _upload_chunk(text, query, channel, dataset_id)
             success += 1 if ok else 0
 
         if success == total: