jihuaqiang 4 месяцев назад
Родитель
Сommit
bb003476a8
2 измененных файлов с 298 добавлено и 33 удалено
  1. 167 32
      src/agent/query_agent.py
  2. 131 1
      src/tools/prompts.py

+ 167 - 32
src/agent/query_agent.py

@@ -6,7 +6,11 @@ from langchain.schema import HumanMessage, SystemMessage
 import httpx
 import json
 
-from ..tools.prompts import STRUCTURED_TOOL_DEMAND_PROMPT
+from ..tools.prompts import (
+    STRUCTURED_TOOL_DEMAND_PROMPT,
+    CLASSIFICATION_PROMPT,
+    QUERY_CLASSIFICATION_PROMPT
+)
 from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
 
 
@@ -18,6 +22,8 @@ class AgentState(TypedDict):
     refined_queries: List[str]
     result_queries: List[Dict[str, str]]
     knowledgeType: str
+    content_dimension: str  # 内容类型的维度: How / What / Pattern
+    is_query_type: bool  # 是否为可处理的查询类型
 
 
 class QueryGenerationAgent:
@@ -46,29 +52,47 @@ class QueryGenerationAgent:
         """创建LangGraph状态图"""
         workflow = StateGraph(AgentState)
         
-        # 添加节点:分类 -> 生成 -> 保存
+        # 添加节点
         workflow.add_node("classify_question", self._classify_question)
-        workflow.add_node("generate_initial_queries", self._generate_initial_queries)
+        workflow.add_node("generate_tool_queries", self._generate_tool_queries)  # 工具类型查询生成
+        workflow.add_node("classify_content_dimension", self._classify_content_dimension)  # 内容维度分类
+        workflow.add_node("expand_content_queries", self._expand_content_queries)  # 内容查询扩展
         workflow.add_node("save_queries", self._save_queries)
         
         # 设置入口点
         workflow.set_entry_point("classify_question")
         
-        # 添加条件边:内容直接结束,工具进入生成
+        # 条件路由:工具知识 vs 内容知识
         try:
-            # 优先使用条件路由(若LangGraph版本支持)
             workflow.add_conditional_edges(
                 "classify_question",
                 self._route_after_classify,
                 {
-                    "TOOL": "generate_initial_queries",
-                    "CONTENT": END
+                    "TOOL": "generate_tool_queries",
+                    "CONTENT": "classify_content_dimension"
                 }
             )
         except Exception:
-            # 兼容:不支持条件边时,继续到生成节点,由生成节点自行拦截
-            workflow.add_edge("classify_question", "generate_initial_queries")
-        workflow.add_edge("generate_initial_queries", "save_queries")
+            workflow.add_edge("classify_question", "generate_tool_queries")
+        
+        # 工具类型:生成 -> 保存 -> 结束
+        workflow.add_edge("generate_tool_queries", "save_queries")
+        
+        # 内容类型:分类维度 -> 条件路由
+        try:
+            workflow.add_conditional_edges(
+                "classify_content_dimension",
+                self._route_after_content_classify,
+                {
+                    "EXPAND": "expand_content_queries",
+                    "UNSUPPORTED": END
+                }
+            )
+        except Exception:
+            workflow.add_edge("classify_content_dimension", "expand_content_queries")
+        
+        # 内容扩展:扩展 -> 保存 -> 结束
+        workflow.add_edge("expand_content_queries", "save_queries")
         workflow.add_edge("save_queries", END)
         
         return workflow.compile()
@@ -92,34 +116,18 @@ class QueryGenerationAgent:
             logger.info(f"问题类型判断结果: {text}")
             kt = "工具知识" if "工具" in text else "内容知识"
             state["knowledgeType"] = kt
-            # 若为内容,直接将任务标记为失败并准备结束
-            if kt != "工具知识":
-                try:
-                    if state.get("task_id", 0) > 0:
-                        self.task_dao.mark_task_failed(state["task_id"], "暂不支持非工具内容")
-                except Exception:
-                    pass
-                state["result_queries"] = []
-        except Exception:
+        except Exception as e:
             # 失败默认判为内容知识以避免误触发
+            logger.warning(f"问题类型判断失败: {e}")
             state["knowledgeType"] = "内容知识"
-            try:
-                if state.get("task_id", 0) > 0:
-                    self.task_dao.mark_task_failed(state["task_id"], "暂不支持非工具内容")
-            except Exception:
-                pass
-            state["result_queries"] = []
         return state
 
     def _route_after_classify(self, state: AgentState) -> str:
         """根据分类结果路由:工具 -> TOOL;内容 -> CONTENT"""
         return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"
     
-    def _generate_initial_queries(self, state: AgentState) -> AgentState:
-        """生成 refined_queries(从结构化JSON中聚合三类关键词)"""
-        # 若为内容类型,直接抛出以在上层处理
-        if state.get("knowledgeType") != "工具知识":
-            raise ValueError("不支持内容类型问题")
+    def _generate_tool_queries(self, state: AgentState) -> AgentState:
+        """生成工具类型的查询词(从结构化JSON中聚合三类关键词)"""
         question = state["question"]
         # 使用新的结构化系统提示
         prompt = ChatPromptTemplate.from_messages([
@@ -160,7 +168,107 @@ class QueryGenerationAgent:
             state["refined_queries"] = [question]
         return state
     
-    # 删除 refine/validate/classify 节点
+    def _classify_content_dimension(self, state: AgentState) -> AgentState:
+        """使用CLASSIFICATION_PROMPT对内容类型问题进行维度分类(How/What/Pattern)"""
+        question = state["question"]
+        prompt = ChatPromptTemplate.from_messages([
+            SystemMessage(content=CLASSIFICATION_PROMPT),
+            HumanMessage(content=question)
+        ])
+        try:
+            response = self.llm.invoke(prompt.format_messages())
+            text = (response.content or "").strip()
+            logger.info(f"内容维度分类结果: {text}")
+            
+            # 解析JSON结果
+            try:
+                data = json.loads(text)
+            except Exception:
+                data = self._extract_json_from_text(text)
+            
+            dimension = data.get("所属维度", "").strip()
+            state["content_dimension"] = dimension
+            
+            # 判断是否为可处理的查询类型(目前仅支持How类型)
+            state["is_query_type"] = dimension == "How"
+            
+            if not state["is_query_type"]:
+                # 不支持的类型,标记任务失败
+                error_msg = f"暂不支持{dimension}类型的内容问题,当前仅支持How类型"
+                logger.info(error_msg)
+                if state.get("task_id", 0) > 0:
+                    self.task_dao.mark_task_failed(state["task_id"], error_msg)
+                state["result_queries"] = []
+        except Exception as e:
+            logger.error(f"内容维度分类失败: {e}")
+            state["is_query_type"] = False
+            if state.get("task_id", 0) > 0:
+                self.task_dao.mark_task_failed(state["task_id"], f"分类失败: {str(e)}")
+            state["result_queries"] = []
+        
+        return state
+    
+    def _route_after_content_classify(self, state: AgentState) -> str:
+        """根据内容分类结果路由:支持的类型 -> EXPAND;不支持 -> UNSUPPORTED"""
+        return "EXPAND" if state.get("is_query_type", False) else "UNSUPPORTED"
+    
+    def _expand_content_queries(self, state: AgentState) -> AgentState:
+        """使用QUERY_CLASSIFICATION_PROMPT扩展内容类型的查询词"""
+        question = state["question"]
+        prompt = ChatPromptTemplate.from_messages([
+            SystemMessage(content=QUERY_CLASSIFICATION_PROMPT),
+            HumanMessage(content=question)
+        ])
+        try:
+            response = self.llm.invoke(prompt.format_messages())
+            text = (response.content or "").strip()
+            logger.info(f"查询扩展结果: {text}")
+            
+            # 解析JSON结果
+            try:
+                data = json.loads(text)
+            except Exception:
+                data = self._extract_json_from_text(text)
+            
+            # 提取所有扩展的查询词
+            expanded = data.get("expanded_queries", {})
+            aggregated: List[str] = []
+            
+            # 收集粗颗粒度查询
+            for item in expanded.get("coarse_grained", []) or []:
+                q = str(item.get("query", "")).strip()
+                if q:
+                    aggregated.append(q)
+            
+            # 收集细颗粒度查询
+            for item in expanded.get("fine_grained", []) or []:
+                q = str(item.get("query", "")).strip()
+                if q:
+                    aggregated.append(q)
+            
+            # 收集互补或差异化查询
+            for item in expanded.get("complementary_or_differentiated", []) or []:
+                q = str(item.get("query", "")).strip()
+                if q:
+                    aggregated.append(q)
+            
+            # 去重,保持顺序
+            seen = set()
+            deduped: List[str] = []
+            for q in aggregated:
+                if q not in seen:
+                    seen.add(q)
+                    deduped.append(q)
+            
+            state["initial_queries"] = deduped
+            state["refined_queries"] = deduped
+            
+        except Exception as e:
+            logger.warning(f"查询扩展失败,降级为原始问题: {e}")
+            state["initial_queries"] = [question]
+            state["refined_queries"] = [question]
+        
+        return state
     
     def _save_queries(self, state: AgentState) -> AgentState:
         """保存查询词到外部接口节点"""
@@ -259,6 +367,29 @@ class QueryGenerationAgent:
             logger.warning(f"LLM分类失败,使用降级策略: {e}")
             return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
 
+    def _extract_json_from_text(self, text: str) -> Dict[str, Any]:
+        """从模型输出中提取JSON对象(可能包含```json代码块或多余文本)"""
+        s = (text or "").strip()
+        # 去除三引号包裹的代码块
+        if s.startswith("```"):
+            # 去掉第一行的 ``` 或 ```json
+            first_newline = s.find('\n')
+            if first_newline != -1:
+                s = s[first_newline + 1:]
+            if s.endswith("```"):
+                s = s[:-3]
+            s = s.strip()
+        # 在文本中查找首个JSON对象
+        import re
+        match = re.search(r"\{[\s\S]*\}", s)
+        if not match:
+            raise ValueError("未找到JSON对象片段")
+        json_str = match.group(0)
+        data = json.loads(json_str)
+        if not isinstance(data, dict):
+            raise ValueError("提取内容不是JSON对象")
+        return data
+
     def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
         """尽力从模型输出(可能包含```json代码块或多余文本)中提取JSON数组。"""
         s = (text or "").strip()
@@ -289,6 +420,7 @@ class QueryGenerationAgent:
         Args:
             question: 用户问题
             task_id: 任务ID
+            knowledge_type: 知识类型(可选,用于兼容)
         Returns:
             生成的查询词列表
         """
@@ -298,13 +430,16 @@ class QueryGenerationAgent:
             "initial_queries": [],
             "refined_queries": [],
             "result_queries": [],
-            "knowledgeType": "工具知识"
+            "knowledgeType": "",
+            "content_dimension": "",
+            "is_query_type": False
         }
         
         try:
             result = await self.graph.ainvoke(initial_state)
             return result["result_queries"]
         except Exception as e:
+            logger.error(f"生成查询词失败: {e}")
             # 更新任务状态为失败
             if task_id > 0:
                 self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)

+ 131 - 1
src/tools/prompts.py

@@ -41,4 +41,134 @@ how_to_use_queries: 操作/解决问题类关键词列表。
 
 请严格只输出JSON数组,不要包含任何额外文本或解释。
 """
- 
+
+CLASSIFICATION_PROMPT = """
+# 任务说明
+你是一个“内容创作拆解的 Query 分类器”。  
+你的任务是:根据输入的 query 词,结合定义文本,对 query 词进行精确分类,输出其在拆解标准中的归属。
+
+# 定义
+【How(怎么做 / 流程步骤)】
+- 创作从0到1及从1到爆款的具体动作和步骤。
+  - 前期:选题、趋势调研、受众画像、目标设定。
+  - 构思:脚本大纲、故事线、核心信息点、钩子设计、情绪调动。
+  - 制作:拍摄、录音、剪辑、配图、配文、排版、视觉风格、配乐。
+  - 发布:平台选择、发布时间、标题与封面、话题标签、互动策略。
+  - 优化:数据监测、用户反馈、A/B 测试、二次传播、规模化复制。
+
+【What(创作要素 / 内容里有什么)】
+- 名词:独立的主体或元素,如角色、场景、物件、工具、抽象概念。
+- 动词:动作或过程,如展示、转折、对比、讲解、互动。
+- 形容词:状态或属性,如美丽、梦幻、真实、幽默、夸张。
+- 副词:程度或方式,如非常、快速、缓慢、极端、对比地。
+- 补充要素:钩子、冲突、故事结构、节奏感、视觉风格、声音与配乐、符号、品牌调性、传播元素(标题、封面、标签)、感知要素(共鸣感、差异感、信任感)。
+
+【内容pattern(内容模式 / 内容范式)】
+- 定义:内容 Pattern 是指在内容创作、传播与消费全链路中,基于用户认知习惯与内容目标形成的、可复用的规律性结构 / 逻辑框架。
+    - 其核心属性包括:
+    - 1. 规律性(贴合用户信息接收逻辑,如对冲突、故事的天然敏感度,非随机设计);
+    - 2. 目的性(服务明确内容目标,如知识传递、情感共鸣、传播裂变等);
+    - 3. 可复制性(提供 “框架骨架”,允许填充差异化细节,实现 “形同质异”),本质是对高效果内容经验的提炼,用于降低创作试错成本、提升内容与目标受众的匹配效率。
+
+
+# Loop 机制
+1. 初步尝试分类:将 query 词放入上述的 How / What / pattern 任一环节。  
+2. 如果无法直接分类 → 启动 Loop:结合目标(短视频 & 图文创作的全流程拆解,从0到1到爆款,让小白能理解所有题材的方法),重新分析 query 词的语义,再次判断最合适的分类。  
+3. 必须给出最终分类,不允许保留“未分类”或模糊标签。
+
+# 输出格式
+仅输出以下 JSON 格式(严格保持):
+
+{
+  "query": "输入的 query 词",
+  "所属维度": "How / What / pattern",
+  "分类说明": "简要说明分类依据或理由(如语义倾向、关键词特征、关联动作等)"
+}
+"""
+
+QUERY_CLASSIFICATION_PROMPT = """
+# 系统角色与目标
+你是一个“内容创作领域的Query扩展专家”。  
+你的目标是:针对输入的query问题,生成一组高质量的扩展query词,用于查找与内容创作相关的有效知识,并服务于整体目标——构建内容创作知识库,帮助小白用户理解和应用。
+
+---
+
+# 背景与知识框架
+内容创作知识主要分为三类:How、What、Pattern。  
+本次任务重点是 **How(怎么做 / 流程步骤)**,即创作从0到1及从1到爆款的具体动作和步骤。  
+- 包括:前期选题、趋势调研、受众画像、目标设定;构思脚本、故事线、钩子设计;制作拍摄、剪辑、配图、配文、排版、视觉风格、配乐;发布平台选择、标题封面、标签、互动策略;优化数据监测、用户反馈、A/B测试、二次传播、规模化复制。
+
+---
+
+# 任务说明
+1. 输入:用户提供的一个query问题(与内容创作相关)。  
+2. 输出:  
+   1)原问题;  
+   2)扩展query词(**可分多级**,从粗颗粒度到细颗粒度,视问题复杂性而定);  
+   3)每个扩展query词的扩展原因(说明生成方法与逻辑,方便后续迭代分析);
+   4)每组扩展的query词最多保留**最精品的1-3组query问题**。  
+
+---
+
+# 操作步骤
+1. **理解输入query**  
+   - 分析query的意图和目标;  
+   - 判断query属于内容创作的哪类知识(本次为How类);  
+
+2. **多级query扩展**  
+   - **粗颗粒度**:从方法论角度概括原问题,生成泛化query词;  
+   - **细颗粒度**:结合具体场景、工具、步骤等细化query;  
+   - 可根据需要生成多种分级query词,确保覆盖不同细化程度;  
+   - 参考<示例>:  
+     <示例>
+     {原问题:萌宠喵咪的表情包的选题来源思路是怎么样的?
+     粗颗粒度:选题来源 / 选题思路
+     细颗粒度:制作萌宠类表情包的选题思路 / 喵咪表情包如何做选题}
+     </示例>
+
+3. **Loop审视**  
+   - 检查初步生成的query是否与原问题主题高度一致;  
+   - 分析query是否存在互补或差异化角度,也可实现相同目标;  
+   - 保留高质量query,去除无关或重复的query;  
+
+4. **扩展原因说明**  
+   - 每个query词都需附上生成方法或逻辑,如“由粗颗粒度抽象而来”“结合具体场景细化而来”“作为互补角度拓展而来”等。
+
+---
+
+# 输出格式
+请严格按照以下格式输出:
+
+    "instruction": "请严格按照以下JSON格式进行输出,不要在JSON代码块前后添加任何额外的解释或说明文字。",
+    "example_structure": {
+      "original_query": "在此处填写用户输入的原始问题",
+      "expanded_queries": {
+        "coarse_grained": [
+          {
+            "query": "<粗颗粒度扩展出的query词1>",
+            "reason": "<扩展原因与方法1,例如:从方法论层面进行抽象概括>"
+          },
+          {
+            "query": "<粗颗粒度扩展出的query词2>",
+            "reason": "<扩展原因与方法2>"
+          }
+        ],
+        "fine_grained": [
+          {
+            "query": "<细颗粒度扩展出的query词1>",
+            "reason": "<扩展原因与方法1,例如:结合具体场景'...'进行细化>"
+          },
+          {
+            "query": "<细颗粒度扩展出的query词2>",
+            "reason": "<扩展原因与方法2>"
+          }
+        ],
+        "complementary_or_differentiated": [
+          {
+            "query": "<互补或差异化角度的query词1>",
+            "reason": "<扩展原因与方法1,例如:提供一个互补角度'...'的思路>"
+          }
+        ]
+      }
+    }
+"""