jihuaqiang vor 1 Woche
Ursprung
Commit
68d9a425a1
5 geänderte Dateien mit 54 neuen und 11 gelöschten Zeilen
  1. 45 4
      src/agent/query_agent.py
  2. 1 1
      src/api/main.py
  3. 6 5
      src/database/models.py
  4. 1 0
      src/models/schemas.py
  5. 1 1
      src/tools/scheduler.py

+ 45 - 4
src/agent/query_agent.py

@@ -3,6 +3,8 @@ from langgraph.graph import StateGraph, END
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema import HumanMessage, SystemMessage
+import httpx
+import json
 
 from ..tools.query_tool import SuggestQueryTool
 from ..tools.prompts import QUERY_GENERATION_PROMPT, QUERY_REFINEMENT_PROMPT
@@ -51,6 +53,7 @@ class QueryGenerationAgent:
         workflow.add_node("generate_initial_queries", self._generate_initial_queries)
         workflow.add_node("refine_queries", self._refine_queries)
         workflow.add_node("validate_queries", self._validate_queries)
+        workflow.add_node("save_queries", self._save_queries)
         
         # 设置入口点
         workflow.set_entry_point("analyze_question")
@@ -59,7 +62,8 @@ class QueryGenerationAgent:
         workflow.add_edge("analyze_question", "generate_initial_queries")
         workflow.add_edge("generate_initial_queries", "refine_queries")
         workflow.add_edge("refine_queries", "validate_queries")
-        workflow.add_edge("validate_queries", END)
+        workflow.add_edge("validate_queries", "save_queries")
+        workflow.add_edge("save_queries", END)
         
         return workflow.compile()
     
@@ -164,14 +168,50 @@ class QueryGenerationAgent:
         state["refined_queries"] = validated_queries
         return state
     
-    async def generate_queries(self, question: str, task_id: int = 0) -> List[str]:
+    def _save_queries(self, state: AgentState) -> AgentState:
+        """保存查询词到外部接口节点"""
+        refined_queries = state["refined_queries"]
+        question = state["question"]
+        
+        if not refined_queries:
+            logger.warning("没有查询词需要保存")
+            return state
+        
+        # 调用外部接口保存查询词
+        try:
+            url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
+            headers = {"Content-Type": "application/json"}
+            
+            # 根据问题内容判断知识类型,这里可以根据实际需求调整逻辑
+            knowledge_type = state["knowledgeType"]  # 默认类型,可以根据问题内容动态判断
+            
+            data = {
+                "knowledgeType": knowledge_type,
+                "queryWords": refined_queries
+            }
+            
+            # 使用httpx发送请求
+            with httpx.Client() as client:
+                response = client.post(url, headers=headers, json=data, timeout=30)
+                response.raise_for_status()
+                
+            logger.info(f"查询词保存成功: {refined_queries}")
+            
+        except httpx.HTTPError as e:
+            logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
+        except Exception as e:
+            logger.error(f"保存查询词时发生错误: {str(e)}")
+        
+        return state
+    
+    async def generate_queries(self, question: str, task_id: int = 0, knowledgeType: str = "") -> List[str]:
         """
         生成查询词的主入口
         
         Args:
             question: 用户问题
             task_id: 任务ID
-            
+            knowledgeType: 知识类型
         Returns:
             生成的查询词列表
         """
@@ -181,7 +221,8 @@ class QueryGenerationAgent:
             "initial_queries": [],
             "refined_queries": [],
             "context": "",
-            "iteration_count": 0
+            "iteration_count": 0,
+            "knowledgeType": knowledgeType
         }
         
         try:

+ 1 - 1
src/api/main.py

@@ -144,7 +144,7 @@ async def generate_queries(request: QuestionRequest):
         task_id = int(time.time() * 1000)
         
         # 创建任务记录,状态设置为0(待执行)
-        task_dao.create_task(task_id, request.question)
+        task_dao.create_task(task_id, request.question, request.knowledgeType)
         logger.info(f"创建任务: {task_id},状态: 待执行")
         
         # 立即返回待执行状态

+ 6 - 5
src/database/models.py

@@ -66,7 +66,7 @@ class QueryTaskDAO:
     def __init__(self):
         self.db_manager = get_db_manager()
     
-    def create_task(self, task_id: int, question: str) -> bool:
+    def create_task(self, task_id: int, question: str, knowledgeType: str) -> bool:
         """
         创建新的查询任务
         
@@ -80,14 +80,15 @@ class QueryTaskDAO:
         try:
             with self.db_manager.get_cursor() as cursor:
                 sql = """
-                INSERT INTO knowledge_suggest_query (task_id, question, status)
-                VALUES (%s, %s, %s)
+                INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType)
+                VALUES (%s, %s, %s, %s)
                 ON DUPLICATE KEY UPDATE
                 question = VALUES(question),
                 status = VALUES(status),
-                querys = NULL
+                querys = NULL,
+                knowledgeType = VALUES(knowledgeType)
                 """
-                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING))
+                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledgeType))
                 return True
         except Exception as e:
             logger.error(f"创建任务失败: {e}")

+ 1 - 0
src/models/schemas.py

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
 class QuestionRequest(BaseModel):
     """问题请求模型"""
     question: str = Field(..., description="用户提出的问题", min_length=1, max_length=1000)
+    knowledgeType: str = Field(..., description="知识类型")
 
 
 class QueryResponse(BaseModel):

+ 1 - 1
src/tools/scheduler.py

@@ -81,7 +81,7 @@ class TaskScheduler:
             
             try:
                 # 使用Agent生成查询词
-                queries = await self.agent.generate_queries(task.question, task.task_id)
+                queries = await self.agent.generate_queries(task.question, task.task_id, task.knowledgeType)
                 
                 # 更新任务结果
                 success = self.task_dao.update_task_results(task.task_id, queries, QueryTaskStatus.SUCCESS)