Quellcode durchsuchen

支持仅生成,不入query库

jihuaqiang vor 5 Monaten
Ursprung
Commit
08e8b19959
4 geänderte Dateien mit 25 neuen und 20 gelöschten Zeilen
  1. 19 15
      src/agent/query_agent.py
  2. 4 4
      src/database/models.py
  3. 1 0
      src/models/schemas.py
  4. 1 1
      src/tools/scheduler.py

+ 19 - 15
src/agent/query_agent.py

@@ -18,6 +18,7 @@ class AgentState(TypedDict):
     """Agent状态定义"""
     question: str
     task_id: int
+    need_store: int
     initial_queries: List[str]
     refined_queries: List[str]
     result_queries: List[Dict[str, str]]
@@ -311,20 +312,22 @@ class QueryGenerationAgent:
         ]
         state["result_queries"] = result_items
         
-        try:
-            url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
-            headers = {"Content-Type": "application/json"}
-            with httpx.Client() as client:
-                data_content = result_items
-                logger.info(f"查询词保存数据: {data_content}")
-                resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
-                resp1.raise_for_status()
-                logger.info(f"查询词保存结果: {resp1.text}")
-            logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
-        except httpx.HTTPError as e:
-            logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
-        except Exception as e:
-            logger.error(f"保存查询词时发生错误: {str(e)}")
+        # need_store=1 保存查询词
+        if state.get("need_store", 1) == 1:
+            try:
+                url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
+                headers = {"Content-Type": "application/json"}
+                with httpx.Client() as client:
+                    data_content = result_items
+                    logger.info(f"查询词保存数据: {data_content}")
+                    resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
+                    resp1.raise_for_status()
+                    logger.info(f"查询词保存结果: {resp1.text}")
+                logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
+            except httpx.HTTPError as e:
+                logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
+            except Exception as e:
+                logger.error(f"保存查询词时发生错误: {str(e)}")
         
         return state
     
@@ -438,7 +441,7 @@ class QueryGenerationAgent:
             raise ValueError("提取内容不是JSON数组")
         return data
 
-    async def generate_queries(self, question: str, task_id: int = 0, knowledge_type: str = "") -> List[str]:
+    async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> List[str]:
         """
         生成查询词的主入口
         
@@ -452,6 +455,7 @@ class QueryGenerationAgent:
         initial_state = {
             "question": question,
             "task_id": task_id,
+            "need_store": need_store,
             "initial_queries": [],
             "refined_queries": [],
             "result_queries": [],

+ 4 - 4
src/database/models.py

@@ -72,7 +72,7 @@ class QueryTaskDAO:
     def __init__(self):
         self.db_manager = get_db_manager()
     
-    def create_task(self, task_id: int, question: str, knowledge_type: str = "") -> bool:
+    def create_task(self, task_id: int, question: str, knowledge_type: str = "", need_store: int = 1) -> bool:
         """
         创建新的查询任务
         
@@ -86,8 +86,8 @@ class QueryTaskDAO:
         try:
             with self.db_manager.get_cursor() as cursor:
                 sql = """
-                INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType, err_msg)
-                VALUES (%s, %s, %s, %s, %s)
+                INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType, err_msg, needStore)
+                VALUES (%s, %s, %s, %s, %s, %s)
                 ON DUPLICATE KEY UPDATE
                 question = VALUES(question),
                 status = VALUES(status),
@@ -95,7 +95,7 @@ class QueryTaskDAO:
                 knowledgeType = VALUES(knowledgeType),
                 err_msg = NULL
                 """
-                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type or "内容知识", None))
+                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type or "内容知识", None, need_store))
                 return True
         except Exception as e:
             logger.error(f"创建任务失败: {e}")

+ 1 - 0
src/models/schemas.py

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
 class QuestionRequest(BaseModel):
     """问题请求模型"""
     question: str = Field(..., description="用户提出的问题", min_length=1, max_length=1000)
+    need_store: int = Field(..., description="是否存储查询词", default=1)
 
 
 class QueryResponse(BaseModel):

+ 1 - 1
src/tools/scheduler.py

@@ -81,7 +81,7 @@ class TaskScheduler:
             
             try:
                 # 使用Agent生成查询词(Agent内部会在内容类型时直接失败并返回空)
-                queries = await self.agent.generate_queries(task.question, task.task_id)
+                queries = await self.agent.generate_queries(task.question, task.needStore, task.task_id)
                 
                 # 若为空,视为不支持内容类型,标记失败
                 if not queries: