|
@@ -3,6 +3,8 @@ from langgraph.graph import StateGraph, END
|
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
from langchain.prompts import ChatPromptTemplate
|
|
|
from langchain.schema import HumanMessage, SystemMessage
|
|
|
+import httpx
|
|
|
+import json
|
|
|
|
|
|
from ..tools.query_tool import SuggestQueryTool
|
|
|
from ..tools.prompts import QUERY_GENERATION_PROMPT, QUERY_REFINEMENT_PROMPT
|
|
@@ -51,6 +53,7 @@ class QueryGenerationAgent:
|
|
|
workflow.add_node("generate_initial_queries", self._generate_initial_queries)
|
|
|
workflow.add_node("refine_queries", self._refine_queries)
|
|
|
workflow.add_node("validate_queries", self._validate_queries)
|
|
|
+ workflow.add_node("save_queries", self._save_queries)
|
|
|
|
|
|
# 设置入口点
|
|
|
workflow.set_entry_point("analyze_question")
|
|
@@ -59,7 +62,8 @@ class QueryGenerationAgent:
|
|
|
workflow.add_edge("analyze_question", "generate_initial_queries")
|
|
|
workflow.add_edge("generate_initial_queries", "refine_queries")
|
|
|
workflow.add_edge("refine_queries", "validate_queries")
|
|
|
- workflow.add_edge("validate_queries", END)
|
|
|
+ workflow.add_edge("validate_queries", "save_queries")
|
|
|
+ workflow.add_edge("save_queries", END)
|
|
|
|
|
|
return workflow.compile()
|
|
|
|
|
@@ -164,14 +168,50 @@ class QueryGenerationAgent:
|
|
|
state["refined_queries"] = validated_queries
|
|
|
return state
|
|
|
|
|
|
- async def generate_queries(self, question: str, task_id: int = 0) -> List[str]:
|
|
|
+ def _save_queries(self, state: AgentState) -> AgentState:
|
|
|
+ """保存查询词到外部接口节点"""
|
|
|
+ refined_queries = state["refined_queries"]
|
|
|
+ question = state["question"]
|
|
|
+
|
|
|
+ if not refined_queries:
|
|
|
+ logger.warning("没有查询词需要保存")
|
|
|
+ return state
|
|
|
+
|
|
|
+ # 调用外部接口保存查询词
|
|
|
+ try:
|
|
|
+ url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
|
|
|
+ headers = {"Content-Type": "application/json"}
|
|
|
+
|
|
|
+ # 根据问题内容判断知识类型,这里可以根据实际需求调整逻辑
|
|
|
+ knowledge_type = state["knowledgeType"] # 默认类型,可以根据问题内容动态判断
|
|
|
+
|
|
|
+ data = {
|
|
|
+ "knowledgeType": knowledge_type,
|
|
|
+ "queryWords": refined_queries
|
|
|
+ }
|
|
|
+
|
|
|
+ # 使用httpx发送请求
|
|
|
+ with httpx.Client() as client:
|
|
|
+ response = client.post(url, headers=headers, json=data, timeout=30)
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ logger.info(f"查询词保存成功: {refined_queries}")
|
|
|
+
|
|
|
+ except httpx.HTTPError as e:
|
|
|
+ logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"保存查询词时发生错误: {str(e)}")
|
|
|
+
|
|
|
+ return state
|
|
|
+
|
|
|
+ async def generate_queries(self, question: str, task_id: int = 0, knowledgeType: str = "") -> List[str]:
|
|
|
"""
|
|
|
生成查询词的主入口
|
|
|
|
|
|
Args:
|
|
|
question: 用户问题
|
|
|
task_id: 任务ID
|
|
|
-
|
|
|
+ knowledgeType: 知识类型
|
|
|
Returns:
|
|
|
生成的查询词列表
|
|
|
"""
|
|
@@ -181,7 +221,8 @@ class QueryGenerationAgent:
|
|
|
"initial_queries": [],
|
|
|
"refined_queries": [],
|
|
|
"context": "",
|
|
|
- "iteration_count": 0
|
|
|
+ "iteration_count": 0,
|
|
|
+ "knowledgeType": knowledgeType
|
|
|
}
|
|
|
|
|
|
try:
|