# query_agent.py
  1. from typing import List, Dict, Any, TypedDict
  2. from langgraph.graph import StateGraph, END
  3. from langchain_google_genai import ChatGoogleGenerativeAI
  4. from langchain.prompts import ChatPromptTemplate
  5. from langchain.schema import HumanMessage, SystemMessage
  6. import httpx
  7. import json
  8. from ..tools.prompts import (
  9. STRUCTURED_TOOL_DEMAND_PROMPT,
  10. CLASSIFICATION_PROMPT,
  11. QUERY_CLASSIFICATION_PROMPT
  12. )
  13. from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
class AgentState(TypedDict):
    """State carried between the LangGraph workflow nodes."""
    question: str                         # original user question
    task_id: int                          # DB task id; 0 means "no task to update"
    need_store: int                       # 1 -> POST generated queries to the external service
    initial_queries: List[str]            # queries as first generated
    refined_queries: List[str]            # queries after dedup (currently identical to initial_queries)
    result_queries: List[Dict[str, str]]  # final payload items (query + knowledgeType + task_id)
    knowledgeType: str                    # "工具知识" (tool) or "内容知识" (content)
    content_dimension: str                # content dimension: How / What / Pattern
    is_query_type: bool                   # whether the question is a processable query type
class QueryGenerationAgent:
    """Agent that expands a user question into search queries via an LLM workflow."""

    def __init__(self, gemini_api_key: str, model_name: str = "gemini-1.5-pro"):
        """
        Initialize the agent.

        Args:
            gemini_api_key: Gemini API key.
            model_name: Name of the model to use.
        """
        self.llm = ChatGoogleGenerativeAI(
            google_api_key=gemini_api_key,
            model=model_name,
            temperature=0.7  # moderate randomness for varied query expansion
        )
        self.task_dao = QueryTaskDAO()
        # Build and compile the LangGraph state machine once, up front.
        self.graph = self._create_graph()
    def _create_graph(self) -> StateGraph:
        """Build and compile the LangGraph workflow.

        Flow: classify_question -> (TOOL) generate_tool_queries -> save_queries,
        or -> (CONTENT) classify_content_dimension -> (EXPAND) expand_content_queries
        -> save_queries; UNSUPPORTED content ends the graph immediately.

        NOTE(review): the annotation says StateGraph, but compile() returns a
        compiled graph object — consider correcting the annotation.
        """
        workflow = StateGraph(AgentState)
        # Register nodes
        workflow.add_node("classify_question", self._classify_question)
        workflow.add_node("generate_tool_queries", self._generate_tool_queries)  # tool-type query generation
        workflow.add_node("classify_content_dimension", self._classify_content_dimension)  # content dimension classification
        workflow.add_node("expand_content_queries", self._expand_content_queries)  # content query expansion
        workflow.add_node("save_queries", self._save_queries)
        # Entry point
        workflow.set_entry_point("classify_question")
        # Conditional routing: tool knowledge vs content knowledge.
        # NOTE(review): the try/except looks like a defensive fallback for
        # environments where add_conditional_edges fails — confirm it is needed.
        try:
            workflow.add_conditional_edges(
                "classify_question",
                self._route_after_classify,
                {
                    "TOOL": "generate_tool_queries",
                    "CONTENT": "classify_content_dimension"
                }
            )
        except Exception:
            workflow.add_edge("classify_question", "generate_tool_queries")
        # Tool path: generate -> save -> end
        workflow.add_edge("generate_tool_queries", "save_queries")
        # Content path: classify dimension -> conditional routing
        try:
            workflow.add_conditional_edges(
                "classify_content_dimension",
                self._route_after_content_classify,
                {
                    "EXPAND": "expand_content_queries",
                    "UNSUPPORTED": END
                }
            )
        except Exception:
            workflow.add_edge("classify_content_dimension", "expand_content_queries")
        # Content expansion: expand -> save -> end
        workflow.add_edge("expand_content_queries", "save_queries")
        workflow.add_edge("save_queries", END)
        return workflow.compile()
  83. def _classify_question(self, state: AgentState) -> AgentState:
  84. """判断问题知识类型:工具知识 / 内容知识"""
  85. question = state.get("question", "")
  86. instruction = (
  87. "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
  88. "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
  89. "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
  90. "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
  91. )
  92. prompt = ChatPromptTemplate.from_messages([
  93. SystemMessage(content=instruction),
  94. HumanMessage(content=question)
  95. ])
  96. try:
  97. response = self.llm.invoke(prompt.format_messages())
  98. text = (response.content or "").strip()
  99. logger.info(f"问题类型判断结果: {text}")
  100. kt = "工具知识" if "工具" in text else "内容知识"
  101. state["knowledgeType"] = kt
  102. except Exception as e:
  103. # 失败默认判为内容知识以避免误触发
  104. logger.warning(f"问题类型判断失败: {e}")
  105. state["knowledgeType"] = "内容知识"
  106. return state
  107. def _route_after_classify(self, state: AgentState) -> str:
  108. """根据分类结果路由:工具 -> TOOL;内容 -> CONTENT"""
  109. return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"
  110. def _generate_tool_queries(self, state: AgentState) -> AgentState:
  111. """生成工具类型的查询词(从结构化JSON中聚合三类关键词)"""
  112. question = state["question"]
  113. # 使用新的结构化系统提示
  114. prompt = ChatPromptTemplate.from_messages([
  115. SystemMessage(content=STRUCTURED_TOOL_DEMAND_PROMPT),
  116. HumanMessage(content=question)
  117. ])
  118. try:
  119. response = self.llm.invoke(prompt.format_messages())
  120. text = (response.content or "").strip()
  121. # 解析严格的JSON数组;若失败,尝试从文本中提取
  122. try:
  123. data = json.loads(text)
  124. except Exception:
  125. data = self._extract_json_array_from_text(text)
  126. logger.info(f"需求分析结果: {data}")
  127. aggregated: List[str] = []
  128. for item in data:
  129. ek = (item or {}).get("expanded_keywords", {})
  130. g = ek.get("general_discovery_queries", []) or []
  131. t = ek.get("themed_function_queries", []) or []
  132. h = ek.get("how_to_use_queries", []) or []
  133. for q in [*g, *t, *h]:
  134. q_str = str(q).strip()
  135. if q_str:
  136. aggregated.append(q_str)
  137. # 去重,保持顺序
  138. seen = set()
  139. deduped: List[str] = []
  140. for q in aggregated:
  141. if q not in seen:
  142. seen.add(q)
  143. deduped.append(q)
  144. state["initial_queries"] = deduped
  145. state["refined_queries"] = deduped
  146. except Exception as e:
  147. logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
  148. state["initial_queries"] = [question]
  149. state["refined_queries"] = [question]
  150. return state
  151. def _classify_content_dimension(self, state: AgentState) -> AgentState:
  152. """使用CLASSIFICATION_PROMPT对内容类型问题进行维度分类(How/What/Pattern)"""
  153. question = state["question"]
  154. prompt = ChatPromptTemplate.from_messages([
  155. SystemMessage(content=CLASSIFICATION_PROMPT),
  156. HumanMessage(content=question)
  157. ])
  158. try:
  159. response = self.llm.invoke(prompt.format_messages())
  160. text = (response.content or "").strip()
  161. logger.info(f"内容维度分类结果: {text}")
  162. # 解析JSON结果
  163. try:
  164. data = json.loads(text)
  165. except Exception:
  166. data = self._extract_json_from_text(text)
  167. dimension = data.get("所属维度", "").strip()
  168. state["content_dimension"] = dimension
  169. # 判断是否为可处理的查询类型(目前仅支持How类型)
  170. state["is_query_type"] = dimension == "How"
  171. if not state["is_query_type"]:
  172. # 不支持的类型,标记任务失败
  173. error_msg = f"暂不支持{dimension}类型的内容问题,当前仅支持How类型"
  174. logger.info(error_msg)
  175. if state.get("task_id", 0) > 0:
  176. self.task_dao.mark_task_failed(state["task_id"], error_msg)
  177. state["result_queries"] = []
  178. except Exception as e:
  179. logger.error(f"内容维度分类失败: {e}")
  180. state["is_query_type"] = False
  181. if state.get("task_id", 0) > 0:
  182. self.task_dao.mark_task_failed(state["task_id"], f"分类失败: {str(e)}")
  183. state["result_queries"] = []
  184. return state
  185. def _route_after_content_classify(self, state: AgentState) -> str:
  186. """根据内容分类结果路由:支持的类型 -> EXPAND;不支持 -> UNSUPPORTED"""
  187. return "EXPAND" if state.get("is_query_type", False) else "UNSUPPORTED"
    def _expand_content_queries(self, state: AgentState) -> AgentState:
        """Expand a content-type question into queries using QUERY_CLASSIFICATION_PROMPT.

        Aggregates coarse-grained, fine-grained and complementary/differentiated
        queries from the LLM's JSON answer. If the coarse-grained results flag
        the question as outside the content-creation domain, or no queries can
        be produced at all, the task is marked failed and all query lists are
        cleared. On any other error the original question becomes the fallback query.
        """
        question = state["question"]
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=QUERY_CLASSIFICATION_PROMPT),
            HumanMessage(content=question)
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"查询扩展结果: {text}")
            # Parse the JSON result; fall back to extracting it from free text.
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_from_text(text)
            # Collect all expanded queries from the three groups.
            expanded = data.get("expanded_queries", {})
            aggregated: List[str] = []
            # Phrases that signal the question falls outside the creation domain.
            invalid_keywords = ["无关", "超出", "不相关", "不属于", "无法生成"]
            # Coarse-grained queries; also detect "out of domain" markers.
            for item in expanded.get("coarse_grained", []) or []:
                q = str(item.get("query", "")).strip()
                reason = str(item.get("reason", "")).strip()
                # A marker inside q means the whole question is unsupported.
                if q and any(keyword in q for keyword in invalid_keywords):
                    # NOTE(review): when q exceeds 100 chars the message switches
                    # to reason/default instead of truncating q — confirm intended.
                    error_msg = q if len(q) <= 100 else reason[:100] if reason else "问题不符合内容创作领域"
                    logger.info(f"检测到不符合创作领域的问题: {error_msg}")
                    if state.get("task_id", 0) > 0:
                        self.task_dao.mark_task_failed(state["task_id"], error_msg)
                    state["result_queries"] = []
                    state["initial_queries"] = []
                    state["refined_queries"] = []
                    return state
                if q:
                    aggregated.append(q)
            # Fine-grained queries.
            for item in expanded.get("fine_grained", []) or []:
                q = str(item.get("query", "")).strip()
                if q:
                    aggregated.append(q)
            # Complementary or differentiated queries.
            for item in expanded.get("complementary_or_differentiated", []) or []:
                q = str(item.get("query", "")).strip()
                if q:
                    aggregated.append(q)
            # Every query empty -> no valid queries could be generated.
            if not aggregated:
                error_msg = "无法生成有效的内容创作查询词"
                logger.info(error_msg)
                if state.get("task_id", 0) > 0:
                    self.task_dao.mark_task_failed(state["task_id"], error_msg)
                state["result_queries"] = []
                state["initial_queries"] = []
                state["refined_queries"] = []
                return state
            # Deduplicate while preserving order.
            seen = set()
            deduped: List[str] = []
            for q in aggregated:
                if q not in seen:
                    seen.add(q)
                    deduped.append(q)
            state["initial_queries"] = deduped
            state["refined_queries"] = deduped
        except Exception as e:
            logger.warning(f"查询扩展失败,降级为原始问题: {e}")
            state["initial_queries"] = [question]
            state["refined_queries"] = [question]
        return state
    def _save_queries(self, state: AgentState) -> AgentState:
        """Build result_queries and, when need_store == 1, POST them to the
        external knowledge-workflow endpoint.

        HTTP/other failures are logged and swallowed (best-effort persistence);
        result_queries is populated regardless of whether the POST succeeds.
        """
        refined_queries = state.get("refined_queries", [])
        question = state.get("question", "")
        knowledge_type = state.get("knowledgeType", "") or "内容知识"
        if not refined_queries:
            logger.warning("没有查询词需要保存")
            return state
        # Combine knowledgeType with each query and attach task_id to build the payload.
        # (task_id is an int, hence Dict[str, Any] rather than Dict[str, str].)
        result_items: List[Dict[str, Any]] = [
            {"query": q, "knowledgeType": knowledge_type, "task_id": state.get("task_id", 0)} for q in refined_queries
        ]
        state["result_queries"] = result_items
        # need_store == 1 (the default) -> persist the queries remotely.
        if state.get("need_store", 1) == 1:
            try:
                # NOTE(review): hard-coded test-environment URL — consider making it configurable.
                url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
                headers = {"Content-Type": "application/json"}
                with httpx.Client() as client:
                    data_content = result_items
                    logger.info(f"查询词保存数据: {data_content}")
                    resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
                    resp1.raise_for_status()
                    logger.info(f"查询词保存结果: {resp1.text}")
                    logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
            except httpx.HTTPError as e:
                logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
            except Exception as e:
                logger.error(f"保存查询词时发生错误: {str(e)}")
        return state
  288. def _infer_knowledge_type(self, query: str) -> str:
  289. """根据查询词简单推断知识类型(内容知识/工具知识)"""
  290. tool_keywords = [
  291. "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令", "版本",
  292. "错误", "异常", "调试", "部署", "集成", "调用", "参数", "示例", "代码",
  293. "CLI", "tool", "library", "framework"
  294. ]
  295. lower_q = query.lower()
  296. for kw in tool_keywords:
  297. if kw.lower() in lower_q:
  298. return "工具知识"
  299. return "内容知识"
    def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
        """Classify each query as 内容知识 (content) or 工具知识 (tool) via the LLM.

        Returns:
            Items of the form
            [{"query": q, "knowledgeType": "内容知识" | "工具知识"}, ...].
            On parse/LLM failure every query is labelled 内容知识
            (no keyword heuristics are used).
        """
        if not queries:
            return []
        instruction = (
            "你是一名分类助手。请将下面的查询词逐一分类为‘内容知识’或‘工具知识’。\n"
            "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
            "不要输出任何解释或多余文本。"
        )
        payload = "\n".join(queries)
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=instruction),
            HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"LLM分类结果: {text}")
            # Try strict JSON first; otherwise extract the array from code fences/text.
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_array_from_text(text)
            result: List[Dict[str, str]] = []
            for item in data:
                q = str(item.get("query", "")).strip()
                kt = str(item.get("knowledgeType", "")).strip()
                if q and kt in ("内容知识", "工具知识"):
                    result.append({"query": q, "knowledgeType": kt})
            # Guarantee output order matches the input and that every query is present.
            if len(result) != len(queries):
                # Re-align against the input; missing entries default to 内容知识.
                mapped = {it["query"]: it["knowledgeType"] for it in result}
                aligned: List[Dict[str, str]] = []
                for q in queries:
                    kt = mapped.get(q, "内容知识")
                    aligned.append({"query": q, "knowledgeType": kt})
                return aligned
            return result
        except Exception as e:
            # Degraded mode: label everything 内容知识 (no keyword matching).
            logger.warning(f"LLM分类失败,使用降级策略: {e}")
            return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
  346. def _extract_json_from_text(self, text: str) -> Dict[str, Any]:
  347. """从模型输出中提取JSON对象(可能包含```json代码块或多余文本)"""
  348. s = (text or "").strip()
  349. # 去除三引号包裹的代码块
  350. if s.startswith("```"):
  351. # 去掉第一行的 ``` 或 ```json
  352. first_newline = s.find('\n')
  353. if first_newline != -1:
  354. s = s[first_newline + 1:]
  355. if s.endswith("```"):
  356. s = s[:-3]
  357. s = s.strip()
  358. # 在文本中查找首个JSON对象
  359. import re
  360. match = re.search(r"\{[\s\S]*\}", s)
  361. if not match:
  362. raise ValueError("未找到JSON对象片段")
  363. json_str = match.group(0)
  364. data = json.loads(json_str)
  365. if not isinstance(data, dict):
  366. raise ValueError("提取内容不是JSON对象")
  367. return data
  368. def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
  369. """尽力从模型输出(可能包含```json代码块或多余文本)中提取JSON数组。"""
  370. s = (text or "").strip()
  371. # 去除三引号包裹的代码块
  372. if s.startswith("```"):
  373. # 去掉第一行的 ``` 或 ```json
  374. first_newline = s.find('\n')
  375. if first_newline != -1:
  376. s = s[first_newline + 1:]
  377. if s.endswith("```"):
  378. s = s[:-3]
  379. s = s.strip()
  380. # 在文本中查找首个JSON数组
  381. import re
  382. match = re.search(r"\[[\s\S]*\]", s)
  383. if not match:
  384. raise ValueError("未找到JSON数组片段")
  385. json_str = match.group(0)
  386. data = json.loads(json_str)
  387. if not isinstance(data, list):
  388. raise ValueError("提取内容不是JSON数组")
  389. return data
    async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> List[Any]:
        """
        Main entry point for query generation.

        Args:
            question: The user question.
            need_store: 1 (default) to POST the generated queries to the external store.
            task_id: Task id; 0 means there is no DB task to update on failure.
            knowledge_type: Knowledge type (optional, kept for compatibility; unused here).
        Returns:
            On success, the workflow's result_queries (dict items with
            query/knowledgeType/task_id); on failure, the original question in a list.
            NOTE(review): the two return shapes differ — callers must handle both,
            hence List[Any] rather than List[str].
        """
        initial_state = {
            "question": question,
            "task_id": task_id,
            "need_store": need_store,
            "initial_queries": [],
            "refined_queries": [],
            "result_queries": [],
            "knowledgeType": "",
            "content_dimension": "",
            "is_query_type": False
        }
        try:
            result = await self.graph.ainvoke(initial_state)
            return result["result_queries"]
        except Exception as e:
            logger.error(f"生成查询词失败: {e}")
            # Mark the task as failed in the DB.
            if task_id > 0:
                self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
            # Degraded fallback: return the original question.
            return [question]
  421. def is_tool_question(self, question: str) -> bool:
  422. """同步判断问题是否为工具知识类型。"""
  423. instruction = (
  424. "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
  425. "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
  426. "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
  427. "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
  428. )
  429. prompt = ChatPromptTemplate.from_messages([
  430. SystemMessage(content=instruction),
  431. HumanMessage(content=question)
  432. ])
  433. try:
  434. response = self.llm.invoke(prompt.format_messages())
  435. text = (response.content or "").strip()
  436. return "工具" in text
  437. except Exception:
  438. return False