# query_agent.py

from typing import Any, Dict, List, TypedDict

import httpx
import json
import re

from langchain.prompts import ChatPromptTemplate
from langchain.schema import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import END, StateGraph

from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
from ..tools.prompts import QUERY_GENERATION_PROMPT, QUERY_REFINEMENT_PROMPT
from ..tools.query_tool import SuggestQueryTool


class AgentState(TypedDict):
    """State shared by the agent's graph nodes."""
    question: str
    task_id: int
    initial_queries: List[str]
    refined_queries: List[str]
    result_queries: List[Dict[str, str]]
    context: str
    iteration_count: int


class QueryGenerationAgent:
    """Agent that generates search query words for a user question."""

    def __init__(self, gemini_api_key: str, model_name: str = "gemini-1.5-pro"):
        """
        Initialize the agent.

        Args:
            gemini_api_key: Gemini API key.
            model_name: Name of the model to use.
        """
        self.llm = ChatGoogleGenerativeAI(
            google_api_key=gemini_api_key,
            model=model_name,
            temperature=0.7
        )
        self.query_tool = SuggestQueryTool()
        self.task_dao = QueryTaskDAO()
        # Build the state graph
        self.graph = self._create_graph()

    def _create_graph(self) -> StateGraph:
        """Create the LangGraph state graph."""
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("analyze_question", self._analyze_question)
        workflow.add_node("generate_initial_queries", self._generate_initial_queries)
        workflow.add_node("refine_queries", self._refine_queries)
        workflow.add_node("validate_queries", self._validate_queries)
        workflow.add_node("classify_queries", self._classify_queries)
        workflow.add_node("save_queries", self._save_queries)

        # Set the entry point
        workflow.set_entry_point("analyze_question")

        # Add edges
        workflow.add_edge("analyze_question", "generate_initial_queries")
        workflow.add_edge("generate_initial_queries", "refine_queries")
        workflow.add_edge("refine_queries", "validate_queries")
        workflow.add_edge("validate_queries", "classify_queries")
        workflow.add_edge("classify_queries", "save_queries")
        workflow.add_edge("save_queries", END)

        return workflow.compile()
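
    # The compiled graph above is a straight pipeline; iteration_count is
    # carried in AgentState but nothing loops back over it yet:
    #
    #   analyze_question -> generate_initial_queries -> refine_queries
    #     -> validate_queries -> classify_queries -> save_queries -> END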

    def _analyze_question(self, state: AgentState) -> AgentState:
        """Analyze the question."""
        question = state["question"]

        # Assess the question's type and complexity
        analysis_prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="You are a question-analysis expert. Analyze the type and complexity of the user's question."),
            HumanMessage(content=f"Analyze this question: {question}\n\nCover: 1. question type 2. complexity 3. keywords 4. query angles needed")
        ])
        try:
            response = self.llm.invoke(analysis_prompt.format_messages())
            logger.info(f"Question analysis result: {response.content}")
            context = response.content
        except Exception as e:
            context = f"Question analysis failed: {str(e)}"

        state["context"] = context
        state["iteration_count"] = 0
        return state

    def _generate_initial_queries(self, state: AgentState) -> AgentState:
        """Generate the initial query words."""
        question = state["question"]
        task_id = state["task_id"]
        context = state.get("context", "")

        # Generate query words with the suggestion tool
        try:
            initial_queries = self.query_tool._run(question, context, task_id)
        except Exception as e:
            # If the tool fails, fall back to the LLM
            logger.warning(f"SuggestQueryTool failed, falling back to LLM: {e}")
            prompt = ChatPromptTemplate.from_messages([
                SystemMessage(content=QUERY_GENERATION_PROMPT),
                HumanMessage(content=question)
            ])
            try:
                response = self.llm.invoke(prompt.format_messages())
                queries_text = response.content
                initial_queries = [q.strip() for q in queries_text.split('\n') if q.strip()]
            except Exception:
                initial_queries = [question]  # last-resort fallback

        state["initial_queries"] = initial_queries
        return state
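
    # Note the fallback chain above: SuggestQueryTool -> LLM with
    # QUERY_GENERATION_PROMPT -> the raw question itself.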

    def _refine_queries(self, state: AgentState) -> AgentState:
        """Refine the query words."""
        question = state["question"]
        initial_queries = state["initial_queries"]

        if not initial_queries:
            state["refined_queries"] = [question]
            return state

        # Refine the query words with the LLM
        queries_text = '\n'.join(initial_queries)
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=QUERY_REFINEMENT_PROMPT),
            HumanMessage(content=f"Question: {question}\nQuery words: {queries_text}")
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            logger.info(f"Query refinement result: {response.content}")
            refined_text = response.content
            refined_queries = [q.strip() for q in refined_text.split('\n') if q.strip()]
        except Exception as e:
            # If refinement fails, keep the initial query words
            logger.warning(f"Query refinement failed, keeping initial queries: {e}")
            refined_queries = initial_queries

        state["refined_queries"] = refined_queries
        return state

    def _validate_queries(self, state: AgentState) -> AgentState:
        """Validate the query words: deduplicate, drop empties, cap length and count."""
        refined_queries = state["refined_queries"]

        validated_queries = []
        seen = set()
        for query in refined_queries:
            q = query.strip()
            if q and len(q) < 100 and q not in seen:
                validated_queries.append(q)
                seen.add(q)

        # Cap the final count
        validated_queries = validated_queries[:10]

        # Make sure there is at least one query word
        if not validated_queries:
            validated_queries = [state["question"]]

        logger.info(f"Query validation result: {validated_queries}")
        state["refined_queries"] = validated_queries
        return state
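
    # Validation example: ["  RAG 原理 ", "RAG 原理", ""] -> ["RAG 原理"]
    # (whitespace stripped, duplicates and empties dropped, at most 10 kept).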

    def _classify_queries(self, state: AgentState) -> AgentState:
        """Infer each query word's knowledge type and write result_queries."""
        refined_queries = state.get("refined_queries", [])
        # Classify with the LLM
        result_items: List[Dict[str, str]] = self._classify_with_llm(refined_queries)
        state["result_queries"] = result_items
        return state

    def _save_queries(self, state: AgentState) -> AgentState:
        """Save the query words to the external endpoint."""
        refined_queries = state["refined_queries"]
        question = state["question"]

        if not refined_queries:
            logger.warning("No query words to save")
            return state

        # Submit the query words (grouped by type) to the external endpoint,
        # using only the classification produced by the previous node
        try:
            url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
            result_items: List[Dict[str, str]] = state.get("result_queries", [])
            if not result_items:
                logger.warning("No classification results in result_queries, skipping external submission")
                return state

            with httpx.Client() as client:
                payload = {"queryWords": result_items}
                logger.info(f"Saving query words: {payload}")
                # json= serializes the payload and sets the Content-Type header
                resp = client.post(url, json=payload, timeout=30)
                resp.raise_for_status()
                logger.info(f"Query-word save response: {resp.json()}")
            logger.info(f"Query words saved: question={question}, count={len(result_items)}")
        except httpx.HTTPError as e:
            logger.error(f"HTTP error while saving query words: {str(e)}")
        except Exception as e:
            logger.error(f"Error while saving query words: {str(e)}")

        return state
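
    # Body posted to addQuery above, for example:
    #   {"queryWords": [{"query": "...", "knowledgeType": "内容知识"}, ...]}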

    def _infer_knowledge_type(self, query: str) -> str:
        """Keyword heuristic for the knowledge type (内容知识 / 工具知识).

        Kept as a fallback heuristic; the graph itself classifies through
        _classify_with_llm and does not call this method.
        """
        tool_keywords = [
            "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令", "版本",
            "错误", "异常", "调试", "部署", "集成", "调用", "参数", "示例", "代码",
            "CLI", "tool", "library", "framework"
        ]
        lower_q = query.lower()
        for kw in tool_keywords:
            if kw.lower() in lower_q:
                return "工具知识"
        return "内容知识"

    def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
        """Classify query words as 内容知识 (content) or 工具知识 (tool) with the LLM.

        Returns a list shaped like
        [{"query": q, "knowledgeType": "内容知识"|"工具知识"}, ...].
        On parse failure, degrades to marking every query as 内容知识
        (no keyword heuristic).
        """
        if not queries:
            return []

        # The instruction and labels stay in Chinese: the validation below and
        # the external endpoint both expect 内容知识 / 工具知识 verbatim.
        instruction = (
            "你是一名分类助手。请将下面的查询词逐一分类为‘内容知识’或‘工具知识’。\n"
            "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
            "不要输出任何解释或多余文本。"
        )
        payload = "\n".join(queries)
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=instruction),
            HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"LLM classification result: {text}")
            # Try to parse as a JSON array; on failure, extract it from a code
            # fence or surrounding text
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_array_from_text(text)

            result: List[Dict[str, str]] = []
            for item in data:
                q = str(item.get("query", "")).strip()
                kt = str(item.get("knowledgeType", "")).strip()
                if q and kt in ("内容知识", "工具知识"):
                    result.append({"query": q, "knowledgeType": kt})

            # Keep the output aligned with the input: same order, every query present
            if len(result) != len(queries):
                mapped = {it["query"]: it["knowledgeType"] for it in result}
                aligned: List[Dict[str, str]] = []
                for q in queries:
                    kt = mapped.get(q, "内容知识")
                    aligned.append({"query": q, "knowledgeType": kt})
                return aligned
            return result
        except Exception as e:
            # Degrade: mark everything as 内容知识 (no keyword matching)
            logger.warning(f"LLM classification failed, using fallback: {e}")
            return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
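
    # A well-formed reply that _classify_with_llm parses directly, e.g.:
    #   [{"query": "LangGraph 部署", "knowledgeType": "工具知识"},
    #    {"query": "RAG 原理", "knowledgeType": "内容知识"}]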

    def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
        """Best-effort extraction of a JSON array from model output that may
        contain a ```json code fence or extra prose."""
        s = (text or "").strip()

        # Strip a surrounding code fence
        if s.startswith("```"):
            # Drop the opening ``` or ```json line
            first_newline = s.find('\n')
            if first_newline != -1:
                s = s[first_newline + 1:]
            if s.endswith("```"):
                s = s[:-3]
            s = s.strip()

        # Find the first JSON array in the text
        match = re.search(r"\[[\s\S]*\]", s)
        if not match:
            raise ValueError("No JSON array fragment found")
        json_str = match.group(0)
        data = json.loads(json_str)
        if not isinstance(data, list):
            raise ValueError("Extracted content is not a JSON array")
        return data
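
    # Recovery-path example, assuming the model fenced its answer:
    #   _extract_json_array_from_text('```json\n[{"query": "a", "knowledgeType": "内容知识"}]\n```')
    #   -> [{"query": "a", "knowledgeType": "内容知识"}]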

    async def generate_queries(self, question: str, task_id: int = 0) -> List[Dict[str, str]]:
        """
        Main entry point for query generation.

        Args:
            question: The user's question.
            task_id: Task ID.

        Returns:
            The classified query words, shaped like
            [{"query": ..., "knowledgeType": ...}, ...].
        """
        initial_state = {
            "question": question,
            "task_id": task_id,
            "initial_queries": [],
            "refined_queries": [],
            "result_queries": [],
            "context": "",
            "iteration_count": 0
        }
        try:
            result = await self.graph.ainvoke(initial_state)
            return result["result_queries"]
        except Exception as e:
            logger.error(f"Query generation failed: {e}")
            # Mark the task as failed
            if task_id > 0:
                self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
            # Degrade: return the question itself, labeled with the default type
            return [{"query": question, "knowledgeType": "内容知识"}]
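

# Minimal usage sketch: assumes a GEMINI_API_KEY environment variable and
# network access; run it as a module (e.g. `python -m <package>.query_agent`)
# so the relative imports resolve.
if __name__ == "__main__":
    import asyncio
    import os

    agent = QueryGenerationAgent(gemini_api_key=os.environ["GEMINI_API_KEY"])
    # task_id=0 (the default) skips DAO status updates on failure
    items = asyncio.run(agent.generate_queries("如何部署LangGraph应用?"))
    print(items)  # [{"query": ..., "knowledgeType": "内容知识" | "工具知识"}, ...]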