from typing import List, Dict, Any, TypedDict

import httpx
import json
import re

from langgraph.graph import StateGraph, END
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import HumanMessage, SystemMessage

from ..tools.query_tool import SuggestQueryTool
from ..tools.prompts import QUERY_GENERATION_PROMPT, QUERY_REFINEMENT_PROMPT
from ..database.models import QueryTaskDAO, QueryTaskStatus, logger

# Note: LLM prompt literals and the classification labels 内容知识 (content
# knowledge) / 工具知识 (tool knowledge) are kept in Chinese on purpose: the
# model is asked to answer with these exact labels and the external API
# expects them verbatim.


class AgentState(TypedDict):
    """State shared across the agent's graph nodes."""
    question: str
    task_id: int
    initial_queries: List[str]
    refined_queries: List[str]
    result_queries: List[Dict[str, str]]
    context: str
    iteration_count: int


class QueryGenerationAgent:
    """Agent that generates search queries for a user question."""

    def __init__(self, gemini_api_key: str, model_name: str = "gemini-1.5-pro"):
        """
        Initialize the agent.

        Args:
            gemini_api_key: Gemini API key
            model_name: name of the model to use
        """
        self.llm = ChatGoogleGenerativeAI(
            google_api_key=gemini_api_key,
            model=model_name,
            temperature=0.7
        )
        self.query_tool = SuggestQueryTool()
        self.task_dao = QueryTaskDAO()

        # Build the state graph
        self.graph = self._create_graph()

    def _create_graph(self) -> StateGraph:
        """Create the LangGraph state graph: a linear pipeline
        analyze -> generate -> refine -> validate -> classify -> save."""
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("analyze_question", self._analyze_question)
        workflow.add_node("generate_initial_queries", self._generate_initial_queries)
        workflow.add_node("refine_queries", self._refine_queries)
        workflow.add_node("validate_queries", self._validate_queries)
        workflow.add_node("classify_queries", self._classify_queries)
        workflow.add_node("save_queries", self._save_queries)

        # Set the entry point
        workflow.set_entry_point("analyze_question")

        # Wire the nodes in order
        workflow.add_edge("analyze_question", "generate_initial_queries")
        workflow.add_edge("generate_initial_queries", "refine_queries")
        workflow.add_edge("refine_queries", "validate_queries")
        workflow.add_edge("validate_queries", "classify_queries")
        workflow.add_edge("classify_queries", "save_queries")
        workflow.add_edge("save_queries", END)

        return workflow.compile()

    def _analyze_question(self, state: AgentState) -> AgentState:
        """Analyze the question's type and complexity; store the result in context."""
        question = state["question"]

        analysis_prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="你是一个问题分析专家。请分析用户问题的类型和复杂度。"),
            HumanMessage(content=f"请分析这个问题:{question}\n\n分析要点:1.问题类型 2.复杂度 3.关键词 4.需要的查询角度")
        ])

        try:
            response = self.llm.invoke(analysis_prompt.format_messages())
            logger.info(f"Question analysis result: {response.content}")
            context = response.content
        except Exception as e:
            context = f"Question analysis failed: {str(e)}"

        state["context"] = context
        state["iteration_count"] = 0
        return state

    def _generate_initial_queries(self, state: AgentState) -> AgentState:
        """Generate the initial query list, falling back to the LLM if the tool fails."""
        question = state["question"]
        task_id = state["task_id"]
        context = state.get("context", "")

        # Prefer the suggestion tool
        try:
            initial_queries = self.query_tool._run(question, context, task_id)
        except Exception as e:
            # Tool failed; fall back to generating queries with the LLM
            logger.warning(f"Query tool failed, falling back to LLM: {e}")
            prompt = ChatPromptTemplate.from_messages([
                SystemMessage(content=QUERY_GENERATION_PROMPT),
                HumanMessage(content=question)
            ])
            try:
                response = self.llm.invoke(prompt.format_messages())
                queries_text = response.content
                initial_queries = [q.strip() for q in queries_text.split('\n') if q.strip()]
            except Exception:
                initial_queries = [question]  # last-resort fallback

        state["initial_queries"] = initial_queries
        return state
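    # Shape of the state after the two generation nodes, for reference
    # (the values shown are hypothetical, not produced by this code):
    #   {"question": "如何部署FastAPI应用?", "task_id": 1,
    #    "initial_queries": ["FastAPI 部署教程", "uvicorn 配置方法"],
    #    "refined_queries": [], "result_queries": [],
    #    "context": "<analysis from _analyze_question>", "iteration_count": 0}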
    def _refine_queries(self, state: AgentState) -> AgentState:
        """Refine the initial queries with the LLM."""
        question = state["question"]
        initial_queries = state["initial_queries"]

        if not initial_queries:
            state["refined_queries"] = [question]
            return state

        # Ask the LLM to refine the query list
        queries_text = '\n'.join(initial_queries)
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=QUERY_REFINEMENT_PROMPT),
            HumanMessage(content=f"问题:{question}\n查询词:{queries_text}")
        ])

        try:
            response = self.llm.invoke(prompt.format_messages())
            logger.info(f"Query refinement result: {response.content}")
            refined_text = response.content
            refined_queries = [q.strip() for q in refined_text.split('\n') if q.strip()]
        except Exception:
            # Refinement failed; keep the original queries
            refined_queries = initial_queries

        state["refined_queries"] = refined_queries
        return state

    def _validate_queries(self, state: AgentState) -> AgentState:
        """Validate the queries: deduplicate, drop empties, cap length and count."""
        refined_queries = state["refined_queries"]

        validated_queries = []
        seen = set()
        for query in refined_queries:
            q = query.strip() if query else ""
            # Keep non-empty queries under 100 characters, skipping duplicates
            if 0 < len(q) < 100 and q not in seen:
                validated_queries.append(q)
                seen.add(q)

        # Cap the final count at 10
        validated_queries = validated_queries[:10]

        # Ensure at least one query remains
        if not validated_queries:
            validated_queries = [state["question"]]

        logger.info(f"Query validation result: {validated_queries}")
        state["refined_queries"] = validated_queries
        return state

    def _classify_queries(self, state: AgentState) -> AgentState:
        """Infer each query's knowledge type and store it in result_queries."""
        refined_queries = state.get("refined_queries", [])

        # Classify with the LLM
        result_items: List[Dict[str, str]] = self._classify_with_llm(refined_queries)

        state["result_queries"] = result_items
        return state

    def _save_queries(self, state: AgentState) -> AgentState:
        """Post the classified queries to the external endpoint."""
        refined_queries = state["refined_queries"]
        question = state["question"]

        if not refined_queries:
            logger.warning("No queries to save")
            return state

        # Submit the classified queries to the external API
        try:
            url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
            headers = {"Content-Type": "application/json"}

            # Use only the classification produced upstream; do not re-classify here
            result_items: List[Dict[str, str]] = state.get("result_queries", [])
            if not result_items:
                logger.warning("Missing classification results in result_queries; skipping external submission")
                return state

            with httpx.Client() as client:
                body = json.dumps({"queryWords": result_items}).encode('utf-8')
                logger.info(f"Query save payload: {body}")
                resp = client.post(url, headers=headers, content=body, timeout=30)
                resp.raise_for_status()  # fail before trying to parse the body
                logger.info(f"Query save response: {resp.json()}")

            logger.info(f"Queries saved: question={question}, count={len(result_items)}")
        except httpx.HTTPError as e:
            logger.error(f"HTTP error while saving queries: {str(e)}")
        except Exception as e:
            logger.error(f"Error while saving queries: {str(e)}")

        return state

    def _infer_knowledge_type(self, query: str) -> str:
        """Keyword heuristic mapping a query to 工具知识 (tool knowledge) or
        内容知识 (content knowledge).

        Kept as a non-LLM alternative to _classify_with_llm; the keywords stay
        in Chinese because they are matched against Chinese queries.
        """
        tool_keywords = [
            "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令",
            "版本", "错误", "异常", "调试", "部署", "集成", "调用", "参数",
            "示例", "代码", "CLI", "tool", "library", "framework"
        ]
        lower_q = query.lower()
        for kw in tool_keywords:
            if kw.lower() in lower_q:
                return "工具知识"
        return "内容知识"
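    # For reference, the LLM output that _classify_with_llm expects is a bare
    # JSON array (the example values here are hypothetical):
    #   [{"query": "Python装饰器原理", "knowledgeType": "内容知识"},
    #    {"query": "pip安装langgraph报错", "knowledgeType": "工具知识"}]
    # When the model wraps the array in a ```json fence or extra prose,
    # _extract_json_array_from_text below recovers it.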
    def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
        """Ask the LLM to classify each query as 内容知识 or 工具知识.

        Returns a list shaped like
        [{"query": q, "knowledgeType": "内容知识" | "工具知识"}, ...].
        If parsing fails, degrades to labelling every query 内容知识
        (no keyword heuristic is applied).
        """
        if not queries:
            return []

        instruction = (
            "你是一名分类助手。请将下面的查询词逐一分类为‘内容知识’或‘工具知识’。\n"
            "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
            "不要输出任何解释或多余文本。"
        )
        payload = "\n".join(queries)
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=instruction),
            HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
        ])

        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"LLM classification result: {text}")

            # Try to parse as a JSON array; on failure, extract it from a
            # code block or surrounding text
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_array_from_text(text)

            result: List[Dict[str, str]] = []
            for item in data:
                q = str(item.get("query", "")).strip()
                kt = str(item.get("knowledgeType", "")).strip()
                if q and kt in ("内容知识", "工具知识"):
                    result.append({"query": q, "knowledgeType": kt})

            # Keep the output aligned with the input: same order, every query present
            if len(result) != len(queries):
                mapped = {it["query"]: it["knowledgeType"] for it in result}
                return [
                    {"query": q, "knowledgeType": mapped.get(q, "内容知识")}
                    for q in queries
                ]
            return result
        except Exception as e:
            # Degrade: label everything 内容知识 (no keyword matching)
            logger.warning(f"LLM classification failed, using fallback: {e}")
            return [{"query": q, "knowledgeType": "内容知识"} for q in queries]

    def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
        """Best-effort extraction of a JSON array from model output that may
        contain ```json code fences or extra prose."""
        s = (text or "").strip()

        # Strip a surrounding code fence, including the leading ``` or ```json line
        if s.startswith("```"):
            first_newline = s.find('\n')
            if first_newline != -1:
                s = s[first_newline + 1:]
            if s.endswith("```"):
                s = s[:-3]
            s = s.strip()

        # Find the first JSON array in the text
        match = re.search(r"\[[\s\S]*\]", s)
        if not match:
            raise ValueError("No JSON array found in text")
        data = json.loads(match.group(0))
        if not isinstance(data, list):
            raise ValueError("Extracted content is not a JSON array")
        return data

    async def generate_queries(self, question: str, task_id: int = 0) -> List[Dict[str, str]]:
        """
        Main entry point for query generation.

        Args:
            question: the user question
            task_id: task ID

        Returns:
            the classified queries, as [{"query": ..., "knowledgeType": ...}, ...]
        """
        initial_state = {
            "question": question,
            "task_id": task_id,
            "initial_queries": [],
            "refined_queries": [],
            "result_queries": [],
            "context": "",
            "iteration_count": 0
        }

        try:
            result = await self.graph.ainvoke(initial_state)
            return result["result_queries"]
        except Exception as e:
            logger.error(f"Query generation failed: {e}")
            # Mark the task as failed
            if task_id > 0:
                self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
            # Degrade: return the original question, labelled as content knowledge
            return [{"query": question, "knowledgeType": "内容知识"}]
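
# Minimal usage sketch, assuming a valid Gemini API key in the GEMINI_API_KEY
# environment variable and a hypothetical example question. Because this
# module uses relative imports, run it as part of its package
# (e.g. `python -m <package>.<module>`), not as a standalone script.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        agent = QueryGenerationAgent(gemini_api_key=os.environ["GEMINI_API_KEY"])
        # task_id=0 skips the task-status bookkeeping in the failure path
        queries = await agent.generate_queries("如何使用LangGraph构建Agent?")
        for item in queries:
            print(item["query"], "->", item["knowledgeType"])

    asyncio.run(_demo())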