| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523 |
- from typing import List, Dict, Any, TypedDict
- from langgraph.graph import StateGraph, END
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain.prompts import ChatPromptTemplate
- from langchain.schema import HumanMessage, SystemMessage
- import httpx
- import json
- from ..tools.prompts import (
- STRUCTURED_TOOL_DEMAND_PROMPT,
- CLASSIFICATION_PROMPT,
- QUERY_CLASSIFICATION_PROMPT,
- WHAT_CLASSIFICATION_PROMPT,
- PATTERN_CLASSIFICATION_PROMPT
- )
- from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
class AgentState(TypedDict):
    """State shared by all nodes of the query-generation LangGraph workflow."""
    question: str  # the user's original question
    task_id: int  # task id used for status tracking (0 = no task)
    need_store: int  # 1 = persist generated queries via the external API
    initial_queries: List[str]  # raw queries produced by the LLM step
    refined_queries: List[str]  # deduplicated queries ready for saving
    result_queries: List[Dict[str, str]]  # final items submitted to storage
    knowledgeType: str  # "工具知识" (tool) or "内容知识" (content)
    query_type: str  # question type: How / What / Pattern
- class QueryGenerationAgent:
- """查询词生成Agent"""
-
- def __init__(self, gemini_api_key: str, model_name: str = "gemini-1.5-pro"):
- """
- 初始化Agent
-
- Args:
- gemini_api_key: Gemini API密钥
- model_name: 使用的模型名称
- """
- self.llm = ChatGoogleGenerativeAI(
- google_api_key=gemini_api_key,
- model=model_name,
- temperature=0.7
- )
-
- self.task_dao = QueryTaskDAO()
-
- # 创建状态图
- self.graph = self._create_graph()
-
- def _normalize_query_type(self, query_type: str) -> str:
- """统一规范化query_type为首字母大写格式(How/What/Pattern)"""
- query_type_lower = query_type.strip().lower()
- if query_type_lower == "how":
- return "How"
- elif query_type_lower == "what":
- return "What"
- elif query_type_lower == "pattern":
- return "Pattern"
- else:
- return query_type # 返回原值
-
    def _create_graph(self) -> StateGraph:
        """Build and compile the LangGraph workflow.

        Topology:
            classify_question --TOOL--> generate_tool_queries -> save_queries -> END
            classify_question --CONTENT--> classify_content_dimension
                --EXPAND--> expand_content_queries -> save_queries -> END
                --UNSUPPORTED--> END

        Returns:
            The compiled graph, ready for ``ainvoke``.
        """
        workflow = StateGraph(AgentState)

        # Register processing nodes.
        workflow.add_node("classify_question", self._classify_question)
        workflow.add_node("generate_tool_queries", self._generate_tool_queries)  # tool-type query generation
        workflow.add_node("classify_content_dimension", self._classify_content_dimension)  # content dimension classification
        workflow.add_node("expand_content_queries", self._expand_content_queries)  # content query expansion
        workflow.add_node("save_queries", self._save_queries)

        # Entry point.
        workflow.set_entry_point("classify_question")

        # Conditional routing: tool knowledge vs. content knowledge.
        try:
            workflow.add_conditional_edges(
                "classify_question",
                self._route_after_classify,
                {
                    "TOOL": "generate_tool_queries",
                    "CONTENT": "classify_content_dimension"
                }
            )
        except Exception:
            # Fallback when conditional edges are unavailable: route everything
            # to tool-query generation unconditionally.
            workflow.add_edge("classify_question", "generate_tool_queries")

        # Tool path: generate -> save -> end.
        workflow.add_edge("generate_tool_queries", "save_queries")

        # Content path: classify dimension -> conditional routing.
        try:
            workflow.add_conditional_edges(
                "classify_content_dimension",
                self._route_after_content_classify,
                {
                    "EXPAND": "expand_content_queries",
                    "UNSUPPORTED": END
                }
            )
        except Exception:
            # Fallback: always expand when conditional edges are unavailable.
            workflow.add_edge("classify_content_dimension", "expand_content_queries")

        # Content expansion: expand -> save -> end.
        workflow.add_edge("expand_content_queries", "save_queries")
        workflow.add_edge("save_queries", END)

        return workflow.compile()
- def _classify_question(self, state: AgentState) -> AgentState:
- """判断问题知识类型:工具知识 / 内容知识"""
- question = state.get("question", "")
- instruction = (
- "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
- "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
- "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
- "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
- )
- prompt = ChatPromptTemplate.from_messages([
- SystemMessage(content=instruction),
- HumanMessage(content=question)
- ])
- try:
- response = self.llm.invoke(prompt.format_messages())
- text = (response.content or "").strip()
- logger.info(f"问题类型判断结果: {text}")
- kt = "工具知识" if "工具" in text else "内容知识"
- state["knowledgeType"] = kt
- except Exception as e:
- # 失败默认判为内容知识以避免误触发
- logger.warning(f"问题类型判断失败: {e}")
- state["knowledgeType"] = "内容知识"
- return state
- def _route_after_classify(self, state: AgentState) -> str:
- """根据分类结果路由:工具 -> TOOL;内容 -> CONTENT"""
- return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"
-
- def _generate_tool_queries(self, state: AgentState) -> AgentState:
- """生成工具类型的查询词(从结构化JSON中聚合三类关键词)"""
- question = state["question"]
- # 使用新的结构化系统提示
- prompt = ChatPromptTemplate.from_messages([
- SystemMessage(content=STRUCTURED_TOOL_DEMAND_PROMPT),
- HumanMessage(content=question)
- ])
- try:
- response = self.llm.invoke(prompt.format_messages())
- text = (response.content or "").strip()
- # 解析严格的JSON数组;若失败,尝试从文本中提取
- try:
- data = json.loads(text)
- except Exception:
- data = self._extract_json_array_from_text(text)
- logger.info(f"需求分析结果: {data}")
- aggregated: List[str] = []
- for item in data:
- ek = (item or {}).get("expanded_keywords", {})
- g = ek.get("general_discovery_queries", []) or []
- t = ek.get("themed_function_queries", []) or []
- h = ek.get("how_to_use_queries", []) or []
- for q in [*g, *t, *h]:
- q_str = str(q).strip()
- if q_str:
- aggregated.append(q_str)
- # 去重,保持顺序
- seen = set()
- deduped: List[str] = []
- for q in aggregated:
- if q not in seen:
- seen.add(q)
- deduped.append(q)
- state["initial_queries"] = deduped
- state["refined_queries"] = deduped
- except Exception as e:
- logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
- state["initial_queries"] = [question]
- state["refined_queries"] = [question]
- return state
-
- def _classify_content_dimension(self, state: AgentState) -> AgentState:
- """使用CLASSIFICATION_PROMPT对内容类型问题进行维度分类(How/What/Pattern)"""
- question = state["question"]
- prompt = ChatPromptTemplate.from_messages([
- SystemMessage(content=CLASSIFICATION_PROMPT),
- HumanMessage(content=question)
- ])
- try:
- response = self.llm.invoke(prompt.format_messages())
- text = (response.content or "").strip()
- logger.info(f"内容维度分类结果: {text}")
-
- # 解析JSON结果
- try:
- data = json.loads(text)
- except Exception:
- data = self._extract_json_from_text(text)
-
- dimension = data.get("所属维度", "").strip()
- # 统一为首字母大写格式(How/What/Pattern)
- dimension = self._normalize_query_type(dimension)
- state["query_type"] = dimension
- logger.info(f"问题类型设置为: {dimension}")
-
- except Exception as e:
- logger.error(f"内容维度分类失败: {e}")
- if state.get("task_id", 0) > 0:
- self.task_dao.mark_task_failed(state["task_id"], f"分类失败: {str(e)}")
- state["result_queries"] = []
-
- return state
-
- def _route_after_content_classify(self, state: AgentState) -> str:
- """根据内容分类结果路由:所有类型都支持扩展"""
- query_type = state.get("query_type", "")
- # 支持 How / What / Pattern 三种类型
- if query_type in ["How", "What", "Pattern"]:
- return "EXPAND"
- else:
- # 未识别的类型,不支持
- logger.warning(f"未识别的问题类型: {query_type}")
- return "UNSUPPORTED"
-
    def _expand_content_queries(self, state: AgentState) -> AgentState:
        """Expand content queries using the PROMPT matching the question type.

        Reads state["query_type"] (already normalized to How/What/Pattern),
        asks the LLM for expanded queries, aggregates the coarse-grained,
        fine-grained and complementary groups, deduplicates them, and stores
        the result in state["initial_queries"] / state["refined_queries"].
        If the model signals the question is outside the content-creation
        domain (or yields no queries at all), the task is marked failed and
        all query lists are cleared.
        """
        question = state["question"]
        query_type = state.get("query_type", "How")

        # Pick the PROMPT for this query_type (normalized earlier by the
        # classification step to How/What/Pattern).
        if query_type == "How":
            classification_prompt = QUERY_CLASSIFICATION_PROMPT
        elif query_type == "What":
            classification_prompt = WHAT_CLASSIFICATION_PROMPT
        elif query_type == "Pattern":
            classification_prompt = PATTERN_CLASSIFICATION_PROMPT
        else:
            # Fall back to the How-type PROMPT for unrecognized types.
            classification_prompt = QUERY_CLASSIFICATION_PROMPT
            logger.warning(f"未识别的问题类型 {query_type},使用默认How类型PROMPT")

        logger.info(f"使用{query_type}类型的PROMPT进行查询扩展")

        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=classification_prompt),
            HumanMessage(content=question)
        ])
        try:
            response = self.llm.invoke(prompt.format_messages())
            text = (response.content or "").strip()
            logger.info(f"查询扩展结果: {text}")

            # Parse the JSON result (tolerating code fences / extra text).
            try:
                data = json.loads(text)
            except Exception:
                data = self._extract_json_from_text(text)

            # Collect every expanded query from the three groups.
            expanded = data.get("expanded_queries", {})
            aggregated: List[str] = []
            # Marker words the model uses to flag out-of-domain questions.
            invalid_keywords = ["无关", "超出", "不相关", "不属于", "无法生成"]

            # Coarse-grained queries; also detect out-of-domain signals here.
            for item in expanded.get("coarse_grained", []) or []:
                q = str(item.get("query", "")).strip()
                reason = str(item.get("reason", "")).strip()

                # The model indicated the question does not belong to the
                # content-creation domain: fail the task, clear everything.
                if q and any(keyword in q for keyword in invalid_keywords):
                    error_msg = q if len(q) <= 100 else reason[:100] if reason else "问题不符合内容创作领域"
                    logger.info(f"检测到不符合创作领域的问题: {error_msg}")
                    if state.get("task_id", 0) > 0:
                        self.task_dao.mark_task_failed(state["task_id"], error_msg)
                    state["result_queries"] = []
                    state["initial_queries"] = []
                    state["refined_queries"] = []
                    return state

                if q:
                    aggregated.append(q)

            # Fine-grained queries.
            for item in expanded.get("fine_grained", []) or []:
                q = str(item.get("query", "")).strip()
                if q:
                    aggregated.append(q)

            # Complementary / differentiated queries.
            for item in expanded.get("complementary_or_differentiated", []) or []:
                q = str(item.get("query", "")).strip()
                if q:
                    aggregated.append(q)

            # No usable queries at all: treat as "cannot generate".
            if not aggregated:
                error_msg = "无法生成有效的内容创作查询词"
                logger.info(error_msg)
                if state.get("task_id", 0) > 0:
                    self.task_dao.mark_task_failed(state["task_id"], error_msg)
                state["result_queries"] = []
                state["initial_queries"] = []
                state["refined_queries"] = []
                return state

            # Deduplicate while preserving first-seen order.
            seen = set()
            deduped: List[str] = []
            for q in aggregated:
                if q not in seen:
                    seen.add(q)
                    deduped.append(q)

            state["initial_queries"] = deduped
            state["refined_queries"] = deduped

        except Exception as e:
            # Degrade to the raw question on any expansion failure.
            logger.warning(f"查询扩展失败,降级为原始问题: {e}")
            state["initial_queries"] = [question]
            state["refined_queries"] = [question]

        return state
-
- def _save_queries(self, state: AgentState) -> AgentState:
- """保存查询词到外部接口节点"""
- refined_queries = state.get("refined_queries", [])
- question = state.get("question", "")
- knowledge_type = state.get("knowledgeType", "") or "内容知识"
-
- if not refined_queries:
- logger.warning("没有查询词需要保存")
- return state
-
- # 合并 knowledgeType 与每个查询词,附加 task_id,形成提交数据
- result_items: List[Dict[str, str]] = [
- {"query": q, "knowledgeType": knowledge_type, "task_id": state.get("task_id", 0)} for q in refined_queries
- ]
- state["result_queries"] = result_items
-
- # need_store=1 保存查询词
- if state.get("need_store", 1) == 1:
- try:
- url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
- headers = {"Content-Type": "application/json"}
- with httpx.Client() as client:
- data_content = result_items
- logger.info(f"查询词保存数据: {data_content}")
- resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
- resp1.raise_for_status()
- logger.info(f"查询词保存结果: {resp1.text}")
- logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
- except httpx.HTTPError as e:
- logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
- except Exception as e:
- logger.error(f"保存查询词时发生错误: {str(e)}")
-
- return state
-
- def _infer_knowledge_type(self, query: str) -> str:
- """根据查询词简单推断知识类型(内容知识/工具知识)"""
- tool_keywords = [
- "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令", "版本",
- "错误", "异常", "调试", "部署", "集成", "调用", "参数", "示例", "代码",
- "CLI", "tool", "library", "framework"
- ]
- lower_q = query.lower()
- for kw in tool_keywords:
- if kw.lower() in lower_q:
- return "工具知识"
- return "内容知识"
- def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
- """调用LLM将查询词分类为 内容知识 / 工具知识。
- 返回形如 [{"query": q, "knowledgeType": "内容知识"|"工具知识"}, ...]
- 若解析失败,降级为将所有查询标记为 内容知识(不使用关键词启发)。
- """
- if not queries:
- return []
- instruction = (
- "你是一名分类助手。请将下面的查询词逐一分类为‘内容知识’或‘工具知识’。\n"
- "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
- "不要输出任何解释或多余文本。"
- )
- payload = "\n".join(queries)
- prompt = ChatPromptTemplate.from_messages([
- SystemMessage(content=instruction),
- HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
- ])
- try:
- response = self.llm.invoke(prompt.format_messages())
- text = (response.content or "").strip()
- logger.info(f"LLM分类结果: {text}")
- # 尝试解析为JSON数组;若失败,尝试从代码块或文本中提取
- try:
- data = json.loads(text)
- except Exception:
- data = self._extract_json_array_from_text(text)
- result: List[Dict[str, str]] = []
- for item in data:
- q = str(item.get("query", "")).strip()
- kt = str(item.get("knowledgeType", "")).strip()
- if q and kt in ("内容知识", "工具知识"):
- result.append({"query": q, "knowledgeType": kt})
- # 保证顺序与输入一致,且都包含
- if len(result) != len(queries):
- # 尝试基于输入进行对齐
- mapped = {it["query"]: it["knowledgeType"] for it in result}
- aligned: List[Dict[str, str]] = []
- for q in queries:
- kt = mapped.get(q, "内容知识")
- aligned.append({"query": q, "knowledgeType": kt})
- return aligned
- return result
- except Exception as e:
- # 降级:全部标注为内容知识(不做关键词匹配)
- logger.warning(f"LLM分类失败,使用降级策略: {e}")
- return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
- def _extract_json_from_text(self, text: str) -> Dict[str, Any]:
- """从模型输出中提取JSON对象(可能包含```json代码块或多余文本)"""
- s = (text or "").strip()
- # 去除三引号包裹的代码块
- if s.startswith("```"):
- # 去掉第一行的 ``` 或 ```json
- first_newline = s.find('\n')
- if first_newline != -1:
- s = s[first_newline + 1:]
- if s.endswith("```"):
- s = s[:-3]
- s = s.strip()
- # 在文本中查找首个JSON对象
- import re
- match = re.search(r"\{[\s\S]*\}", s)
- if not match:
- raise ValueError("未找到JSON对象片段")
- json_str = match.group(0)
- data = json.loads(json_str)
- if not isinstance(data, dict):
- raise ValueError("提取内容不是JSON对象")
- return data
- def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
- """尽力从模型输出(可能包含```json代码块或多余文本)中提取JSON数组。"""
- s = (text or "").strip()
- # 去除三引号包裹的代码块
- if s.startswith("```"):
- # 去掉第一行的 ``` 或 ```json
- first_newline = s.find('\n')
- if first_newline != -1:
- s = s[first_newline + 1:]
- if s.endswith("```"):
- s = s[:-3]
- s = s.strip()
- # 在文本中查找首个JSON数组
- import re
- match = re.search(r"\[[\s\S]*\]", s)
- if not match:
- raise ValueError("未找到JSON数组片段")
- json_str = match.group(0)
- data = json.loads(json_str)
- if not isinstance(data, list):
- raise ValueError("提取内容不是JSON数组")
- return data
- async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> tuple[List[str], str, str]:
- """
- 生成查询词的主入口
-
- Args:
- question: 用户问题
- task_id: 任务ID
- knowledge_type: 知识类型(可选,用于兼容)
- Returns:
- 元组:(生成的查询词列表, 问题类型)
- """
- initial_state = {
- "question": question,
- "task_id": task_id,
- "need_store": need_store,
- "initial_queries": [],
- "refined_queries": [],
- "result_queries": [],
- "knowledgeType": "",
- "query_type": ""
- }
-
- try:
- result = await self.graph.ainvoke(initial_state)
- return result["result_queries"], result["knowledgeType"], result["query_type"]
- except Exception as e:
- logger.error(f"生成查询词失败: {e}")
- # 更新任务状态为失败
- if task_id > 0:
- self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
- # 降级处理:返回原始问题
- return [question], "How" # 默认返回How类型
- def is_tool_question(self, question: str) -> bool:
- """同步判断问题是否为工具知识类型。"""
- instruction = (
- "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
- "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
- "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
- "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
- )
- prompt = ChatPromptTemplate.from_messages([
- SystemMessage(content=instruction),
- HumanMessage(content=question)
- ])
- try:
- response = self.llm.invoke(prompt.format_messages())
- text = (response.content or "").strip()
- return "工具" in text
- except Exception:
- return False
|