import json
import re
from typing import Any, Dict, List, TypedDict

import httpx
from langgraph.graph import StateGraph, END
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import HumanMessage, SystemMessage

from ..tools.prompts import (
    STRUCTURED_TOOL_DEMAND_PROMPT,
    CLASSIFICATION_PROMPT,
    QUERY_CLASSIFICATION_PROMPT,
    WHAT_CLASSIFICATION_PROMPT,
    PATTERN_CLASSIFICATION_PROMPT
)
from ..database.models import QueryTaskDAO, QueryTaskStatus, logger


class AgentState(TypedDict):
    """State carried between nodes of the query-generation graph."""
    question: str                        # original user question
    task_id: int                         # DB task id; 0 means "no task tracking"
    need_store: int                      # 1 -> POST generated queries to the external API
    initial_queries: List[str]           # raw aggregated queries
    refined_queries: List[str]           # deduplicated queries actually used
    result_queries: List[Dict[str, Any]] # final items: query + knowledgeType + task_id (int)
    knowledgeType: str                   # "工具知识" (tool) or "内容知识" (content)
    query_type: str                      # content dimension: How / What / Pattern


class QueryGenerationAgent:
    """Agent that expands a user question into search queries via a LangGraph workflow."""

    def __init__(self, gemini_api_key: str, model_name: str = "gemini-1.5-pro"):
        """
        Initialize the agent.

        Args:
            gemini_api_key: Gemini API key.
            model_name: name of the Gemini model to use.
        """
        self.llm = ChatGoogleGenerativeAI(
            google_api_key=gemini_api_key,
            model=model_name,
            temperature=0.7
        )
        self.task_dao = QueryTaskDAO()
        # Build the LangGraph state machine once at construction time.
        self.graph = self._create_graph()

    def _normalize_query_type(self, query_type: str) -> str:
        """Normalize query_type to title case (How/What/Pattern); unknown values pass through unchanged."""
        return {"how": "How", "what": "What", "pattern": "Pattern"}.get(
            query_type.strip().lower(), query_type
        )

    @staticmethod
    def _dedupe_preserve_order(queries: List[str]) -> List[str]:
        """Remove duplicate queries while keeping first-seen order."""
        return list(dict.fromkeys(queries))

    def _create_graph(self) -> StateGraph:
        """Build the workflow: classify -> (tool generation | content classify -> expand) -> save."""
        workflow = StateGraph(AgentState)

        workflow.add_node("classify_question", self._classify_question)
        workflow.add_node("generate_tool_queries", self._generate_tool_queries)        # tool-type query generation
        workflow.add_node("classify_content_dimension", self._classify_content_dimension)  # content dimension classification
        workflow.add_node("expand_content_queries", self._expand_content_queries)      # content query expansion
        workflow.add_node("save_queries", self._save_queries)

        workflow.set_entry_point("classify_question")

        # Conditional routing: tool knowledge vs content knowledge. Fall back to a
        # plain edge if conditional edges are unavailable in this langgraph version.
        try:
            workflow.add_conditional_edges(
                "classify_question",
                self._route_after_classify,
                {
                    "TOOL": "generate_tool_queries",
                    "CONTENT": "classify_content_dimension"
                }
            )
        except Exception:
            workflow.add_edge("classify_question", "generate_tool_queries")

        # Tool path: generate -> save -> end.
        workflow.add_edge("generate_tool_queries", "save_queries")

        # Content path: classify dimension -> conditional route.
        try:
            workflow.add_conditional_edges(
                "classify_content_dimension",
                self._route_after_content_classify,
                {
                    "EXPAND": "expand_content_queries",
                    "UNSUPPORTED": END
                }
            )
        except Exception:
            workflow.add_edge("classify_content_dimension", "expand_content_queries")

        # Content path: expand -> save -> end.
        workflow.add_edge("expand_content_queries", "save_queries")
        workflow.add_edge("save_queries", END)

        return workflow.compile()

    def _classify_question(self, state: AgentState) -> AgentState:
        """Classify the question as tool knowledge vs content knowledge via the LLM."""
        question = state.get("question", "")
        instruction = (
            "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
            "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
            "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
            "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
        )
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=instruction),
            HumanMessage(content=question)
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"问题类型判断结果: {text}")
            state["knowledgeType"] = "工具知识" if "工具" in text else "内容知识"
        except Exception as e:
            # On failure, default to content knowledge to avoid spuriously entering the tool flow.
            logger.warning(f"问题类型判断失败: {e}")
            state["knowledgeType"] = "内容知识"
        return state

    def _route_after_classify(self, state: AgentState) -> str:
        """Route by classification result: tool -> TOOL; content -> CONTENT."""
        return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"

    def _generate_tool_queries(self, state: AgentState) -> AgentState:
        """Generate tool-type queries by aggregating three keyword buckets from structured JSON."""
        question = state["question"]
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=STRUCTURED_TOOL_DEMAND_PROMPT),
            HumanMessage(content=question)
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            # Expect a strict JSON array; fall back to extracting one from free text.
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_array_from_text(text)
            logger.info(f"需求分析结果: {data}")

            aggregated: List[str] = []
            for item in data:
                ek = (item or {}).get("expanded_keywords", {})
                for bucket in ("general_discovery_queries",
                               "themed_function_queries",
                               "how_to_use_queries"):
                    for q in ek.get(bucket, []) or []:
                        q_str = str(q).strip()
                        if q_str:
                            aggregated.append(q_str)

            deduped = self._dedupe_preserve_order(aggregated)
            state["initial_queries"] = deduped
            state["refined_queries"] = deduped
        except Exception as e:
            # Degrade gracefully: fall back to the original question as the only query.
            logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
            state["initial_queries"] = [question]
            state["refined_queries"] = [question]
        return state

    def _classify_content_dimension(self, state: AgentState) -> AgentState:
        """Classify a content question into a dimension (How/What/Pattern) using CLASSIFICATION_PROMPT."""
        question = state["question"]
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=CLASSIFICATION_PROMPT),
            HumanMessage(content=question)
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"内容维度分类结果: {text}")
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_from_text(text)
            dimension = data.get("所属维度", "").strip()
            # Normalize to title case (How/What/Pattern) for downstream routing.
            dimension = self._normalize_query_type(dimension)
            state["query_type"] = dimension
            logger.info(f"问题类型设置为: {dimension}")
        except Exception as e:
            logger.error(f"内容维度分类失败: {e}")
            if state.get("task_id", 0) > 0:
                self.task_dao.mark_task_failed(state["task_id"], f"分类失败: {str(e)}")
            state["result_queries"] = []
        return state

    def _route_after_content_classify(self, state: AgentState) -> str:
        """Route after content classification; all three known dimensions support expansion."""
        query_type = state.get("query_type", "")
        if query_type in ["How", "What", "Pattern"]:
            return "EXPAND"
        # Unrecognized dimension: not supported, terminate the content path.
        logger.warning(f"未识别的问题类型: {query_type}")
        return "UNSUPPORTED"

    def _expand_content_queries(self, state: AgentState) -> AgentState:
        """Expand content queries with the prompt matching the classified dimension."""
        question = state["question"]
        query_type = state.get("query_type", "How")

        # query_type was normalized to How/What/Pattern during classification.
        prompt_by_type = {
            "How": QUERY_CLASSIFICATION_PROMPT,
            "What": WHAT_CLASSIFICATION_PROMPT,
            "Pattern": PATTERN_CLASSIFICATION_PROMPT,
        }
        classification_prompt = prompt_by_type.get(query_type)
        if classification_prompt is None:
            classification_prompt = QUERY_CLASSIFICATION_PROMPT
            logger.warning(f"未识别的问题类型 {query_type},使用默认How类型PROMPT")

        logger.info(f"使用{query_type}类型的PROMPT进行查询扩展")

        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=classification_prompt),
            HumanMessage(content=question)
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"查询扩展结果: {text}")
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_from_text(text)

            expanded = data.get("expanded_queries", {})
            aggregated: List[str] = []
            # Phrases that signal the model judged the question outside the content-creation domain.
            invalid_keywords = ["无关", "超出", "不相关", "不属于", "无法生成"]

            # Collect coarse-grained queries, aborting if the model flagged the question as out of scope.
            for item in expanded.get("coarse_grained", []) or []:
                q = str(item.get("query", "")).strip()
                reason = str(item.get("reason", "")).strip()
                if q and any(keyword in q for keyword in invalid_keywords):
                    error_msg = q if len(q) <= 100 else reason[:100] if reason else "问题不符合内容创作领域"
                    logger.info(f"检测到不符合创作领域的问题: {error_msg}")
                    if state.get("task_id", 0) > 0:
                        self.task_dao.mark_task_failed(state["task_id"], error_msg)
                    state["result_queries"] = []
                    state["initial_queries"] = []
                    state["refined_queries"] = []
                    return state
                if q:
                    aggregated.append(q)

            # Collect fine-grained queries.
            for item in expanded.get("fine_grained", []) or []:
                q = str(item.get("query", "")).strip()
                if q:
                    aggregated.append(q)

            # Collect complementary / differentiated queries.
            for item in expanded.get("complementary_or_differentiated", []) or []:
                q = str(item.get("query", "")).strip()
                if q:
                    aggregated.append(q)

            # All empty: the model produced no usable queries — mark the task failed.
            if not aggregated:
                error_msg = "无法生成有效的内容创作查询词"
                logger.info(error_msg)
                if state.get("task_id", 0) > 0:
                    self.task_dao.mark_task_failed(state["task_id"], error_msg)
                state["result_queries"] = []
                state["initial_queries"] = []
                state["refined_queries"] = []
                return state

            deduped = self._dedupe_preserve_order(aggregated)
            state["initial_queries"] = deduped
            state["refined_queries"] = deduped
        except Exception as e:
            # Degrade gracefully: fall back to the original question as the only query.
            logger.warning(f"查询扩展失败,降级为原始问题: {e}")
            state["initial_queries"] = [question]
            state["refined_queries"] = [question]
        return state

    def _save_queries(self, state: AgentState) -> AgentState:
        """Persist the generated queries to the external knowledge-workflow endpoint."""
        refined_queries = state.get("refined_queries", [])
        question = state.get("question", "")
        knowledge_type = state.get("knowledgeType", "") or "内容知识"

        if not refined_queries:
            logger.warning("没有查询词需要保存")
            return state

        # Combine each query with the knowledge type and task_id into submission items.
        result_items: List[Dict[str, Any]] = [
            {"query": q, "knowledgeType": knowledge_type, "task_id": state.get("task_id", 0)}
            for q in refined_queries
        ]
        state["result_queries"] = result_items

        # Only persist when need_store == 1.
        if state.get("need_store", 1) == 1:
            try:
                url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
                headers = {"Content-Type": "application/json"}
                with httpx.Client() as client:
                    data_content = result_items
                    logger.info(f"查询词保存数据: {data_content}")
                    resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
                    resp1.raise_for_status()
                    logger.info(f"查询词保存结果: {resp1.text}")
                    logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
            except httpx.HTTPError as e:
                logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
            except Exception as e:
                logger.error(f"保存查询词时发生错误: {str(e)}")
        return state

    def _infer_knowledge_type(self, query: str) -> str:
        """Heuristically infer knowledge type (tool vs content) from keywords in the query."""
        tool_keywords = [
            "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令",
            "版本", "错误", "异常", "调试", "部署", "集成", "调用", "参数",
            "示例", "代码", "CLI", "tool", "library", "framework"
        ]
        lower_q = query.lower()
        for kw in tool_keywords:
            if kw.lower() in lower_q:
                return "工具知识"
        return "内容知识"

    def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
        """Classify each query as content knowledge vs tool knowledge via the LLM.

        Returns items shaped like
        [{"query": q, "knowledgeType": "内容知识"|"工具知识"}, ...].
        On parse failure, degrades to marking every query as content knowledge
        (no keyword heuristics).
        """
        if not queries:
            return []
        instruction = (
            "你是一名分类助手。请将下面的查询词逐一分类为'内容知识'或'工具知识'。\n"
            "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
            "不要输出任何解释或多余文本。"
        )
        payload = "\n".join(queries)
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=instruction),
            HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"LLM分类结果: {text}")
            # Parse as a JSON array; fall back to extracting one from code fences / free text.
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_array_from_text(text)
            result: List[Dict[str, str]] = []
            for item in data:
                q = str(item.get("query", "")).strip()
                kt = str(item.get("knowledgeType", "")).strip()
                if q and kt in ("内容知识", "工具知识"):
                    result.append({"query": q, "knowledgeType": kt})
            # Guarantee output is aligned 1:1 with the input order and covers every query.
            if len(result) != len(queries):
                mapped = {it["query"]: it["knowledgeType"] for it in result}
                return [
                    {"query": q, "knowledgeType": mapped.get(q, "内容知识")}
                    for q in queries
                ]
            return result
        except Exception as e:
            # Degrade: mark everything as content knowledge (no keyword matching).
            logger.warning(f"LLM分类失败,使用降级策略: {e}")
            return [{"query": q, "knowledgeType": "内容知识"} for q in queries]

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Strip a surrounding ``` / ```json fence from model output, if present."""
        s = (text or "").strip()
        if s.startswith("```"):
            # Drop the first line containing ``` or ```json.
            first_newline = s.find('\n')
            if first_newline != -1:
                s = s[first_newline + 1:]
            if s.endswith("```"):
                s = s[:-3]
            s = s.strip()
        return s

    def _extract_json_from_text(self, text: str) -> Dict[str, Any]:
        """Extract the first JSON object from model output that may contain fences or extra text.

        Raises:
            ValueError: if no JSON object fragment is found or it is not a dict.
        """
        s = self._strip_code_fences(text)
        match = re.search(r"\{[\s\S]*\}", s)
        if not match:
            raise ValueError("未找到JSON对象片段")
        data = json.loads(match.group(0))
        if not isinstance(data, dict):
            raise ValueError("提取内容不是JSON对象")
        return data

    def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
        """Best-effort extraction of a JSON array from model output with fences or extra text.

        Raises:
            ValueError: if no JSON array fragment is found or it is not a list.
        """
        s = self._strip_code_fences(text)
        match = re.search(r"\[[\s\S]*\]", s)
        if not match:
            raise ValueError("未找到JSON数组片段")
        data = json.loads(match.group(0))
        if not isinstance(data, list):
            raise ValueError("提取内容不是JSON数组")
        return data

    async def generate_queries(self, question: str, need_store: int = 1,
                               task_id: int = 0, knowledge_type: str = ""
                               ) -> tuple[List[Dict[str, Any]], str, str]:
        """
        Main entry point for query generation.

        Args:
            question: user question.
            need_store: 1 to persist generated queries via the external API.
            task_id: task id (0 = no task tracking).
            knowledge_type: knowledge type (optional, kept for compatibility).

        Returns:
            Tuple of (result query items, knowledge type, query type).
        """
        initial_state = {
            "question": question,
            "task_id": task_id,
            "need_store": need_store,
            "initial_queries": [],
            "refined_queries": [],
            "result_queries": [],
            "knowledgeType": "",
            "query_type": ""
        }
        try:
            result = await self.graph.ainvoke(initial_state)
            return result["result_queries"], result["knowledgeType"], result["query_type"]
        except Exception as e:
            logger.error(f"生成查询词失败: {e}")
            if task_id > 0:
                self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
            # BUG FIX: the fallback previously returned a 2-tuple of List[str] and str,
            # while the success path (and the annotation) returns a 3-tuple whose first
            # element is a list of query dicts. Return a matching 3-tuple so callers
            # that unpack three values do not crash on the degraded path.
            fallback = [{"query": question, "knowledgeType": "内容知识", "task_id": task_id}]
            return fallback, "内容知识", "How"

    def is_tool_question(self, question: str) -> bool:
        """Synchronously check whether the question is tool-knowledge type."""
        instruction = (
            "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
            "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
            "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
            "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
        )
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=instruction),
            HumanMessage(content=question)
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            return "工具" in text
        except Exception:
            return False