# query_agent.py
  1. from typing import List, Dict, Any, TypedDict
  2. from langgraph.graph import StateGraph, END
  3. from langchain_google_genai import ChatGoogleGenerativeAI
  4. from langchain.prompts import ChatPromptTemplate
  5. from langchain.schema import HumanMessage, SystemMessage
  6. import httpx
  7. import json
  8. from ..tools.prompts import STRUCTURED_TOOL_DEMAND_PROMPT
  9. from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
class AgentState(TypedDict):
    """State dict carried through the LangGraph workflow nodes."""

    question: str                         # the user's original question
    task_id: int                          # task id for status tracking (0 = untracked)
    initial_queries: List[str]            # aggregated queries parsed from the LLM output
    refined_queries: List[str]            # de-duplicated queries to be submitted
    result_queries: List[Dict[str, str]]  # [{"query": ..., "knowledgeType": ...}] sent to the API
    knowledgeType: str                    # knowledge-type label attached to every query
  18. class QueryGenerationAgent:
  19. """查询词生成Agent"""
  20. def __init__(self, gemini_api_key: str, model_name: str = "gemini-1.5-pro"):
  21. """
  22. 初始化Agent
  23. Args:
  24. gemini_api_key: Gemini API密钥
  25. model_name: 使用的模型名称
  26. """
  27. self.llm = ChatGoogleGenerativeAI(
  28. google_api_key=gemini_api_key,
  29. model=model_name,
  30. temperature=0.7
  31. )
  32. self.task_dao = QueryTaskDAO()
  33. # 创建状态图
  34. self.graph = self._create_graph()
  35. def _create_graph(self) -> StateGraph:
  36. """创建LangGraph状态图"""
  37. workflow = StateGraph(AgentState)
  38. # 添加节点(仅保留 生成 与 保存)
  39. workflow.add_node("generate_initial_queries", self._generate_initial_queries)
  40. workflow.add_node("save_queries", self._save_queries)
  41. # 设置入口点
  42. workflow.set_entry_point("generate_initial_queries")
  43. # 添加边
  44. workflow.add_edge("generate_initial_queries", "save_queries")
  45. workflow.add_edge("save_queries", END)
  46. return workflow.compile()
  47. def _generate_initial_queries(self, state: AgentState) -> AgentState:
  48. """生成 refined_queries(从结构化JSON中聚合三类关键词)"""
  49. question = state["question"]
  50. # 使用新的结构化系统提示
  51. prompt = ChatPromptTemplate.from_messages([
  52. SystemMessage(content=STRUCTURED_TOOL_DEMAND_PROMPT),
  53. HumanMessage(content=question)
  54. ])
  55. try:
  56. response = self.llm.invoke(prompt.format_messages())
  57. text = (response.content or "").strip()
  58. # 解析严格的JSON数组;若失败,尝试从文本中提取
  59. try:
  60. data = json.loads(text)
  61. except Exception:
  62. data = self._extract_json_array_from_text(text)
  63. logger.info(f"需求分析结果: {data}")
  64. aggregated: List[str] = []
  65. for item in data:
  66. ek = (item or {}).get("expanded_keywords", {})
  67. g = ek.get("general_discovery_queries", []) or []
  68. t = ek.get("themed_function_queries", []) or []
  69. h = ek.get("how_to_use_queries", []) or []
  70. for q in [*g, *t, *h]:
  71. q_str = str(q).strip()
  72. if q_str:
  73. aggregated.append(q_str)
  74. # 去重,保持顺序
  75. seen = set()
  76. deduped: List[str] = []
  77. for q in aggregated:
  78. if q not in seen:
  79. seen.add(q)
  80. deduped.append(q)
  81. state["initial_queries"] = deduped
  82. state["refined_queries"] = deduped
  83. except Exception as e:
  84. logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
  85. state["initial_queries"] = [question]
  86. state["refined_queries"] = [question]
  87. return state
  88. # 删除 refine/validate/classify 节点
  89. def _save_queries(self, state: AgentState) -> AgentState:
  90. """保存查询词到外部接口节点"""
  91. refined_queries = state.get("refined_queries", [])
  92. question = state.get("question", "")
  93. knowledge_type = state.get("knowledgeType", "") or "内容知识"
  94. if not refined_queries:
  95. logger.warning("没有查询词需要保存")
  96. return state
  97. # 合并 knowledgeType 与每个查询词,形成提交数据
  98. result_items: List[Dict[str, str]] = [
  99. {"query": q, "knowledgeType": knowledge_type} for q in refined_queries
  100. ]
  101. state["result_queries"] = result_items
  102. try:
  103. url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
  104. headers = {"Content-Type": "application/json"}
  105. with httpx.Client() as client:
  106. data_content = result_items
  107. logger.info(f"查询词保存数据: {data_content}")
  108. resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
  109. resp1.raise_for_status()
  110. logger.info(f"查询词保存结果: {resp1.text}")
  111. logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
  112. except httpx.HTTPError as e:
  113. logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
  114. except Exception as e:
  115. logger.error(f"保存查询词时发生错误: {str(e)}")
  116. return state
  117. def _infer_knowledge_type(self, query: str) -> str:
  118. """根据查询词简单推断知识类型(内容知识/工具知识)"""
  119. tool_keywords = [
  120. "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令", "版本",
  121. "错误", "异常", "调试", "部署", "集成", "调用", "参数", "示例", "代码",
  122. "CLI", "tool", "library", "framework"
  123. ]
  124. lower_q = query.lower()
  125. for kw in tool_keywords:
  126. if kw.lower() in lower_q:
  127. return "工具知识"
  128. return "内容知识"
  129. def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
  130. """调用LLM将查询词分类为 内容知识 / 工具知识。
  131. 返回形如 [{"query": q, "knowledgeType": "内容知识"|"工具知识"}, ...]
  132. 若解析失败,降级为将所有查询标记为 内容知识(不使用关键词启发)。
  133. """
  134. if not queries:
  135. return []
  136. instruction = (
  137. "你是一名分类助手。请将下面的查询词逐一分类为‘内容知识’或‘工具知识’。\n"
  138. "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
  139. "不要输出任何解释或多余文本。"
  140. )
  141. payload = "\n".join(queries)
  142. prompt = ChatPromptTemplate.from_messages([
  143. SystemMessage(content=instruction),
  144. HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
  145. ])
  146. try:
  147. response = self.llm.invoke(prompt.format_messages())
  148. text = (response.content or "").strip()
  149. logger.info(f"LLM分类结果: {text}")
  150. # 尝试解析为JSON数组;若失败,尝试从代码块或文本中提取
  151. try:
  152. data = json.loads(text)
  153. except Exception:
  154. data = self._extract_json_array_from_text(text)
  155. result: List[Dict[str, str]] = []
  156. for item in data:
  157. q = str(item.get("query", "")).strip()
  158. kt = str(item.get("knowledgeType", "")).strip()
  159. if q and kt in ("内容知识", "工具知识"):
  160. result.append({"query": q, "knowledgeType": kt})
  161. # 保证顺序与输入一致,且都包含
  162. if len(result) != len(queries):
  163. # 尝试基于输入进行对齐
  164. mapped = {it["query"]: it["knowledgeType"] for it in result}
  165. aligned: List[Dict[str, str]] = []
  166. for q in queries:
  167. kt = mapped.get(q, "内容知识")
  168. aligned.append({"query": q, "knowledgeType": kt})
  169. return aligned
  170. return result
  171. except Exception as e:
  172. # 降级:全部标注为内容知识(不做关键词匹配)
  173. logger.warning(f"LLM分类失败,使用降级策略: {e}")
  174. return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
  175. def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
  176. """尽力从模型输出(可能包含```json代码块或多余文本)中提取JSON数组。"""
  177. s = (text or "").strip()
  178. # 去除三引号包裹的代码块
  179. if s.startswith("```"):
  180. # 去掉第一行的 ``` 或 ```json
  181. first_newline = s.find('\n')
  182. if first_newline != -1:
  183. s = s[first_newline + 1:]
  184. if s.endswith("```"):
  185. s = s[:-3]
  186. s = s.strip()
  187. # 在文本中查找首个JSON数组
  188. import re
  189. match = re.search(r"\[[\s\S]*\]", s)
  190. if not match:
  191. raise ValueError("未找到JSON数组片段")
  192. json_str = match.group(0)
  193. data = json.loads(json_str)
  194. if not isinstance(data, list):
  195. raise ValueError("提取内容不是JSON数组")
  196. return data
  197. async def generate_queries(self, question: str, task_id: int = 0, knowledge_type: str = "") -> List[str]:
  198. """
  199. 生成查询词的主入口
  200. Args:
  201. question: 用户问题
  202. task_id: 任务ID
  203. Returns:
  204. 生成的查询词列表
  205. """
  206. initial_state = {
  207. "question": question,
  208. "task_id": task_id,
  209. "initial_queries": [],
  210. "refined_queries": [],
  211. "result_queries": [],
  212. "knowledgeType": knowledge_type or "内容知识"
  213. }
  214. try:
  215. result = await self.graph.ainvoke(initial_state)
  216. return result["result_queries"]
  217. except Exception as e:
  218. # 更新任务状态为失败
  219. if task_id > 0:
  220. self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
  221. # 降级处理:返回原始问题
  222. return [question]