|
@@ -1,4 +1,4 @@
|
|
-from typing import List, Dict, Any, TypedDict, Annotated
|
|
|
|
|
|
+from typing import List, Dict, Any, TypedDict
|
|
from langgraph.graph import StateGraph, END
|
|
from langgraph.graph import StateGraph, END
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
from langchain.prompts import ChatPromptTemplate
|
|
from langchain.prompts import ChatPromptTemplate
|
|
@@ -6,8 +6,7 @@ from langchain.schema import HumanMessage, SystemMessage
|
|
import httpx
|
|
import httpx
|
|
import json
|
|
import json
|
|
|
|
|
|
-from ..tools.query_tool import SuggestQueryTool
|
|
|
|
-from ..tools.prompts import QUERY_GENERATION_PROMPT, QUERY_REFINEMENT_PROMPT
|
|
|
|
|
|
+from ..tools.prompts import STRUCTURED_TOOL_DEMAND_PROMPT
|
|
from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
|
|
from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
|
|
|
|
|
|
|
|
|
|
@@ -18,8 +17,7 @@ class AgentState(TypedDict):
|
|
initial_queries: List[str]
|
|
initial_queries: List[str]
|
|
refined_queries: List[str]
|
|
refined_queries: List[str]
|
|
result_queries: List[Dict[str, str]]
|
|
result_queries: List[Dict[str, str]]
|
|
- context: str
|
|
|
|
- iteration_count: int
|
|
|
|
|
|
+ knowledgeType: str
|
|
|
|
|
|
|
|
|
|
class QueryGenerationAgent:
|
|
class QueryGenerationAgent:
|
|
@@ -39,7 +37,6 @@ class QueryGenerationAgent:
|
|
temperature=0.7
|
|
temperature=0.7
|
|
)
|
|
)
|
|
|
|
|
|
- self.query_tool = SuggestQueryTool()
|
|
|
|
self.task_dao = QueryTaskDAO()
|
|
self.task_dao = QueryTaskDAO()
|
|
|
|
|
|
# 创建状态图
|
|
# 创建状态图
|
|
@@ -49,166 +46,89 @@ class QueryGenerationAgent:
|
|
"""创建LangGraph状态图"""
|
|
"""创建LangGraph状态图"""
|
|
workflow = StateGraph(AgentState)
|
|
workflow = StateGraph(AgentState)
|
|
|
|
|
|
- # 添加节点
|
|
|
|
- workflow.add_node("analyze_question", self._analyze_question)
|
|
|
|
|
|
+ # 添加节点(仅保留 生成 与 保存)
|
|
workflow.add_node("generate_initial_queries", self._generate_initial_queries)
|
|
workflow.add_node("generate_initial_queries", self._generate_initial_queries)
|
|
- workflow.add_node("refine_queries", self._refine_queries)
|
|
|
|
- workflow.add_node("validate_queries", self._validate_queries)
|
|
|
|
- workflow.add_node("classify_queries", self._classify_queries)
|
|
|
|
workflow.add_node("save_queries", self._save_queries)
|
|
workflow.add_node("save_queries", self._save_queries)
|
|
|
|
|
|
# 设置入口点
|
|
# 设置入口点
|
|
- workflow.set_entry_point("analyze_question")
|
|
|
|
|
|
+ workflow.set_entry_point("generate_initial_queries")
|
|
|
|
|
|
# 添加边
|
|
# 添加边
|
|
- workflow.add_edge("analyze_question", "generate_initial_queries")
|
|
|
|
- workflow.add_edge("generate_initial_queries", "refine_queries")
|
|
|
|
- workflow.add_edge("refine_queries", "validate_queries")
|
|
|
|
- workflow.add_edge("validate_queries", "classify_queries")
|
|
|
|
- workflow.add_edge("classify_queries", "save_queries")
|
|
|
|
|
|
+ workflow.add_edge("generate_initial_queries", "save_queries")
|
|
workflow.add_edge("save_queries", END)
|
|
workflow.add_edge("save_queries", END)
|
|
|
|
|
|
return workflow.compile()
|
|
return workflow.compile()
|
|
|
|
|
|
- def _analyze_question(self, state: AgentState) -> AgentState:
|
|
|
|
- """分析问题节点"""
|
|
|
|
- question = state["question"]
|
|
|
|
-
|
|
|
|
- # 分析问题的复杂度和类型
|
|
|
|
- analysis_prompt = ChatPromptTemplate.from_messages([
|
|
|
|
- SystemMessage(content="你是一个问题分析专家。请分析用户问题的类型和复杂度。"),
|
|
|
|
- HumanMessage(content=f"请分析这个问题:{question}\n\n分析要点:1.问题类型 2.复杂度 3.关键词 4.需要的查询角度")
|
|
|
|
- ])
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- response = self.llm.invoke(analysis_prompt.format_messages())
|
|
|
|
- logger.info(f"问题分析结果: {response.content}")
|
|
|
|
- context = response.content
|
|
|
|
- except Exception as e:
|
|
|
|
- context = f"问题分析失败: {str(e)}"
|
|
|
|
-
|
|
|
|
- state["context"] = context
|
|
|
|
- state["iteration_count"] = 0
|
|
|
|
-
|
|
|
|
- return state
|
|
|
|
-
|
|
|
|
def _generate_initial_queries(self, state: AgentState) -> AgentState:
|
|
def _generate_initial_queries(self, state: AgentState) -> AgentState:
|
|
- """生成初始查询词节点"""
|
|
|
|
- question = state["question"]
|
|
|
|
- task_id = state["task_id"]
|
|
|
|
- context = state.get("context", "")
|
|
|
|
-
|
|
|
|
- # 使用工具生成查询词
|
|
|
|
- try:
|
|
|
|
- initial_queries = self.query_tool._run(question, context, task_id)
|
|
|
|
- except Exception as e:
|
|
|
|
- # 如果工具失败,使用LLM生成
|
|
|
|
- prompt = ChatPromptTemplate.from_messages([
|
|
|
|
- SystemMessage(content=QUERY_GENERATION_PROMPT),
|
|
|
|
- HumanMessage(content=question)
|
|
|
|
- ])
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- response = self.llm.invoke(prompt.format_messages())
|
|
|
|
- queries_text = response.content
|
|
|
|
- initial_queries = [q.strip() for q in queries_text.split('\n') if q.strip()]
|
|
|
|
- except Exception:
|
|
|
|
- initial_queries = [question] # 降级处理
|
|
|
|
-
|
|
|
|
- state["initial_queries"] = initial_queries
|
|
|
|
- return state
|
|
|
|
-
|
|
|
|
- def _refine_queries(self, state: AgentState) -> AgentState:
|
|
|
|
- """优化查询词节点"""
|
|
|
|
|
|
+ """生成 refined_queries(从结构化JSON中聚合三类关键词)"""
|
|
question = state["question"]
|
|
question = state["question"]
|
|
- initial_queries = state["initial_queries"]
|
|
|
|
-
|
|
|
|
- if not initial_queries:
|
|
|
|
- state["refined_queries"] = [question]
|
|
|
|
- return state
|
|
|
|
-
|
|
|
|
- # 使用LLM优化查询词
|
|
|
|
- queries_text = '\n'.join(initial_queries)
|
|
|
|
|
|
+ # 使用新的结构化系统提示
|
|
prompt = ChatPromptTemplate.from_messages([
|
|
prompt = ChatPromptTemplate.from_messages([
|
|
- SystemMessage(content=QUERY_REFINEMENT_PROMPT),
|
|
|
|
- HumanMessage(content=f"问题:{question}\n查询词:{queries_text}")
|
|
|
|
|
|
+ SystemMessage(content=STRUCTURED_TOOL_DEMAND_PROMPT),
|
|
|
|
+ HumanMessage(content=question)
|
|
])
|
|
])
|
|
-
|
|
|
|
try:
|
|
try:
|
|
response = self.llm.invoke(prompt.format_messages())
|
|
response = self.llm.invoke(prompt.format_messages())
|
|
- logger.info(f"查询词优化结果: {response.content}")
|
|
|
|
- refined_text = response.content
|
|
|
|
- refined_queries = [q.strip() for q in refined_text.split('\n') if q.strip()]
|
|
|
|
|
|
+ text = (response.content or "").strip()
|
|
|
|
+ # 解析严格的JSON数组;若失败,尝试从文本中提取
|
|
|
|
+ try:
|
|
|
|
+ data = json.loads(text)
|
|
|
|
+ except Exception:
|
|
|
|
+ data = self._extract_json_array_from_text(text)
|
|
|
|
+ logger.info(f"需求分析结果: {data}")
|
|
|
|
+ aggregated: List[str] = []
|
|
|
|
+ for item in data:
|
|
|
|
+ ek = (item or {}).get("expanded_keywords", {})
|
|
|
|
+ g = ek.get("general_discovery_queries", []) or []
|
|
|
|
+ t = ek.get("themed_function_queries", []) or []
|
|
|
|
+ h = ek.get("how_to_use_queries", []) or []
|
|
|
|
+ for q in [*g, *t, *h]:
|
|
|
|
+ q_str = str(q).strip()
|
|
|
|
+ if q_str:
|
|
|
|
+ aggregated.append(q_str)
|
|
|
|
+ # 去重,保持顺序
|
|
|
|
+ seen = set()
|
|
|
|
+ deduped: List[str] = []
|
|
|
|
+ for q in aggregated:
|
|
|
|
+ if q not in seen:
|
|
|
|
+ seen.add(q)
|
|
|
|
+ deduped.append(q)
|
|
|
|
+ state["initial_queries"] = deduped
|
|
|
|
+ state["refined_queries"] = deduped
|
|
except Exception as e:
|
|
except Exception as e:
|
|
- # 如果优化失败,使用原始查询词
|
|
|
|
- refined_queries = initial_queries
|
|
|
|
-
|
|
|
|
- state["refined_queries"] = refined_queries
|
|
|
|
|
|
+ logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
|
|
|
|
+ state["initial_queries"] = [question]
|
|
|
|
+ state["refined_queries"] = [question]
|
|
return state
|
|
return state
|
|
|
|
|
|
- def _validate_queries(self, state: AgentState) -> AgentState:
|
|
|
|
- """验证查询词节点"""
|
|
|
|
- refined_queries = state["refined_queries"]
|
|
|
|
-
|
|
|
|
- # 基本验证:去重、过滤空字符串、限制长度
|
|
|
|
- validated_queries = []
|
|
|
|
- seen = set()
|
|
|
|
-
|
|
|
|
- for query in refined_queries:
|
|
|
|
- if query and len(query.strip()) > 0 and len(query.strip()) < 100:
|
|
|
|
- if query.strip() not in seen:
|
|
|
|
- validated_queries.append(query.strip())
|
|
|
|
- seen.add(query.strip())
|
|
|
|
-
|
|
|
|
- # 限制最终数量
|
|
|
|
- if len(validated_queries) > 10:
|
|
|
|
- validated_queries = validated_queries[:10]
|
|
|
|
-
|
|
|
|
- # 确保至少有一个查询词
|
|
|
|
- if not validated_queries:
|
|
|
|
- validated_queries = [state["question"]]
|
|
|
|
- logger.info(f"查询词验证结果: {validated_queries}")
|
|
|
|
- state["refined_queries"] = validated_queries
|
|
|
|
- return state
|
|
|
|
-
|
|
|
|
- def _classify_queries(self, state: AgentState) -> AgentState:
|
|
|
|
- """推测每个查询词的知识类型并写入result_queries"""
|
|
|
|
- refined_queries = state.get("refined_queries", [])
|
|
|
|
- # 使用大模型进行分类
|
|
|
|
- result_items: List[Dict[str, str]] = self._classify_with_llm(refined_queries)
|
|
|
|
- state["result_queries"] = result_items
|
|
|
|
- return state
|
|
|
|
|
|
+ # 删除 refine/validate/classify 节点
|
|
|
|
|
|
def _save_queries(self, state: AgentState) -> AgentState:
|
|
def _save_queries(self, state: AgentState) -> AgentState:
|
|
"""保存查询词到外部接口节点"""
|
|
"""保存查询词到外部接口节点"""
|
|
- refined_queries = state["refined_queries"]
|
|
|
|
- question = state["question"]
|
|
|
|
|
|
+ refined_queries = state.get("refined_queries", [])
|
|
|
|
+ question = state.get("question", "")
|
|
|
|
+ knowledge_type = state.get("knowledgeType", "") or "内容知识"
|
|
|
|
|
|
if not refined_queries:
|
|
if not refined_queries:
|
|
logger.warning("没有查询词需要保存")
|
|
logger.warning("没有查询词需要保存")
|
|
return state
|
|
return state
|
|
|
|
|
|
- # 调用外部接口保存查询词(按类型分组)
|
|
|
|
|
|
+ # 合并 knowledgeType 与每个查询词,形成提交数据
|
|
|
|
+ result_items: List[Dict[str, str]] = [
|
|
|
|
+ {"query": q, "knowledgeType": knowledge_type} for q in refined_queries
|
|
|
|
+ ]
|
|
|
|
+ state["result_queries"] = result_items
|
|
|
|
+
|
|
try:
|
|
try:
|
|
url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
|
|
url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
|
|
headers = {"Content-Type": "application/json"}
|
|
headers = {"Content-Type": "application/json"}
|
|
-
|
|
|
|
- # 仅使用前一步的分类结果,不做即时分类
|
|
|
|
- result_items: List[Dict[str, str]] = state.get("result_queries", [])
|
|
|
|
- if not result_items:
|
|
|
|
- logger.warning("缺少分类结果result_queries,跳过外部提交")
|
|
|
|
- return state
|
|
|
|
-
|
|
|
|
- if result_items:
|
|
|
|
- with httpx.Client() as client:
|
|
|
|
- data_content = result_items
|
|
|
|
- logger.info(f"查询词保存数据: {data_content}")
|
|
|
|
- resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
|
|
|
|
- resp1.raise_for_status()
|
|
|
|
- logger.info(f"查询词保存结果: {resp1.text}")
|
|
|
|
-
|
|
|
|
|
|
+ with httpx.Client() as client:
|
|
|
|
+ data_content = result_items
|
|
|
|
+ logger.info(f"查询词保存数据: {data_content}")
|
|
|
|
+ # resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
|
|
|
|
+ # resp1.raise_for_status()
|
|
|
|
+ # logger.info(f"查询词保存结果: {resp1.text}")
|
|
logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
|
|
logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
|
|
-
|
|
|
|
except httpx.HTTPError as e:
|
|
except httpx.HTTPError as e:
|
|
logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
|
|
logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -303,7 +223,7 @@ class QueryGenerationAgent:
|
|
raise ValueError("提取内容不是JSON数组")
|
|
raise ValueError("提取内容不是JSON数组")
|
|
return data
|
|
return data
|
|
|
|
|
|
- async def generate_queries(self, question: str, task_id: int = 0) -> List[str]:
|
|
|
|
|
|
+ async def generate_queries(self, question: str, task_id: int = 0, knowledge_type: str = "") -> List[str]:
|
|
"""
|
|
"""
|
|
生成查询词的主入口
|
|
生成查询词的主入口
|
|
|
|
|
|
@@ -319,8 +239,7 @@ class QueryGenerationAgent:
|
|
"initial_queries": [],
|
|
"initial_queries": [],
|
|
"refined_queries": [],
|
|
"refined_queries": [],
|
|
"result_queries": [],
|
|
"result_queries": [],
|
|
- "context": "",
|
|
|
|
- "iteration_count": 0
|
|
|
|
|
|
+ "knowledgeType": knowledge_type or "内容知识"
|
|
}
|
|
}
|
|
|
|
|
|
try:
|
|
try:
|
|
@@ -332,3 +251,4 @@ class QueryGenerationAgent:
|
|
self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
|
|
self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
|
|
# 降级处理:返回原始问题
|
|
# 降级处理:返回原始问题
|
|
return [question]
|
|
return [question]
|
|
|
|
+
|