Kaynağa Gözat

流程调整

jihuaqiang 1 hafta önce
ebeveyn
işleme
e9ea9d1439

+ 58 - 138
src/agent/query_agent.py

@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, TypedDict, Annotated
+from typing import List, Dict, Any, TypedDict
 from langgraph.graph import StateGraph, END
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain.prompts import ChatPromptTemplate
@@ -6,8 +6,7 @@ from langchain.schema import HumanMessage, SystemMessage
 import httpx
 import json
 
-from ..tools.query_tool import SuggestQueryTool
-from ..tools.prompts import QUERY_GENERATION_PROMPT, QUERY_REFINEMENT_PROMPT
+from ..tools.prompts import STRUCTURED_TOOL_DEMAND_PROMPT
 from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
 
 
@@ -18,8 +17,7 @@ class AgentState(TypedDict):
     initial_queries: List[str]
     refined_queries: List[str]
     result_queries: List[Dict[str, str]]
-    context: str
-    iteration_count: int
+    knowledgeType: str
 
 
 class QueryGenerationAgent:
@@ -39,7 +37,6 @@ class QueryGenerationAgent:
             temperature=0.7
         )
         
-        self.query_tool = SuggestQueryTool()
         self.task_dao = QueryTaskDAO()
         
         # 创建状态图
@@ -49,166 +46,89 @@ class QueryGenerationAgent:
         """创建LangGraph状态图"""
         workflow = StateGraph(AgentState)
         
-        # 添加节点
-        workflow.add_node("analyze_question", self._analyze_question)
+        # 添加节点(仅保留 生成 与 保存)
         workflow.add_node("generate_initial_queries", self._generate_initial_queries)
-        workflow.add_node("refine_queries", self._refine_queries)
-        workflow.add_node("validate_queries", self._validate_queries)
-        workflow.add_node("classify_queries", self._classify_queries)
         workflow.add_node("save_queries", self._save_queries)
         
         # 设置入口点
-        workflow.set_entry_point("analyze_question")
+        workflow.set_entry_point("generate_initial_queries")
         
         # 添加边
-        workflow.add_edge("analyze_question", "generate_initial_queries")
-        workflow.add_edge("generate_initial_queries", "refine_queries")
-        workflow.add_edge("refine_queries", "validate_queries")
-        workflow.add_edge("validate_queries", "classify_queries")
-        workflow.add_edge("classify_queries", "save_queries")
+        workflow.add_edge("generate_initial_queries", "save_queries")
         workflow.add_edge("save_queries", END)
         
         return workflow.compile()
     
-    def _analyze_question(self, state: AgentState) -> AgentState:
-        """分析问题节点"""
-        question = state["question"]
-        
-        # 分析问题的复杂度和类型
-        analysis_prompt = ChatPromptTemplate.from_messages([
-            SystemMessage(content="你是一个问题分析专家。请分析用户问题的类型和复杂度。"),
-            HumanMessage(content=f"请分析这个问题:{question}\n\n分析要点:1.问题类型 2.复杂度 3.关键词 4.需要的查询角度")
-        ])
-        
-        try:
-            response = self.llm.invoke(analysis_prompt.format_messages())
-            logger.info(f"问题分析结果: {response.content}")
-            context = response.content
-        except Exception as e:
-            context = f"问题分析失败: {str(e)}"
-        
-        state["context"] = context
-        state["iteration_count"] = 0
-        
-        return state
-    
     def _generate_initial_queries(self, state: AgentState) -> AgentState:
-        """生成初始查询词节点"""
-        question = state["question"]
-        task_id = state["task_id"]
-        context = state.get("context", "")
-        
-        # 使用工具生成查询词
-        try:
-            initial_queries = self.query_tool._run(question, context, task_id)
-        except Exception as e:
-            # 如果工具失败,使用LLM生成
-            prompt = ChatPromptTemplate.from_messages([
-                SystemMessage(content=QUERY_GENERATION_PROMPT),
-                HumanMessage(content=question)
-            ])
-            
-            try:
-                response = self.llm.invoke(prompt.format_messages())
-                queries_text = response.content
-                initial_queries = [q.strip() for q in queries_text.split('\n') if q.strip()]
-            except Exception:
-                initial_queries = [question]  # 降级处理
-        
-        state["initial_queries"] = initial_queries
-        return state
-    
-    def _refine_queries(self, state: AgentState) -> AgentState:
-        """优化查询词节点"""
+        """生成 refined_queries(从结构化JSON中聚合三类关键词)"""
         question = state["question"]
-        initial_queries = state["initial_queries"]
-        
-        if not initial_queries:
-            state["refined_queries"] = [question]
-            return state
-        
-        # 使用LLM优化查询词
-        queries_text = '\n'.join(initial_queries)
+        # 使用新的结构化系统提示
         prompt = ChatPromptTemplate.from_messages([
-            SystemMessage(content=QUERY_REFINEMENT_PROMPT),
-            HumanMessage(content=f"问题:{question}\n查询词:{queries_text}")
+            SystemMessage(content=STRUCTURED_TOOL_DEMAND_PROMPT),
+            HumanMessage(content=question)
         ])
-        
         try:
             response = self.llm.invoke(prompt.format_messages())
-            logger.info(f"查询词优化结果: {response.content}")
-            refined_text = response.content
-            refined_queries = [q.strip() for q in refined_text.split('\n') if q.strip()]
+            text = (response.content or "").strip()
+            # 解析严格的JSON数组;若失败,尝试从文本中提取
+            try:
+                data = json.loads(text)
+            except Exception:
+                data = self._extract_json_array_from_text(text)
+            logger.info(f"需求分析结果: {data}")
+            aggregated: List[str] = []
+            for item in data:
+                ek = (item or {}).get("expanded_keywords", {})
+                g = ek.get("general_discovery_queries", []) or []
+                t = ek.get("themed_function_queries", []) or []
+                h = ek.get("how_to_use_queries", []) or []
+                for q in [*g, *t, *h]:
+                    q_str = str(q).strip()
+                    if q_str:
+                        aggregated.append(q_str)
+            # 去重,保持顺序
+            seen = set()
+            deduped: List[str] = []
+            for q in aggregated:
+                if q not in seen:
+                    seen.add(q)
+                    deduped.append(q)
+            state["initial_queries"] = deduped
+            state["refined_queries"] = deduped
         except Exception as e:
-            # 如果优化失败,使用原始查询词
-            refined_queries = initial_queries
-        
-        state["refined_queries"] = refined_queries
+            logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
+            state["initial_queries"] = [question]
+            state["refined_queries"] = [question]
         return state
     
-    def _validate_queries(self, state: AgentState) -> AgentState:
-        """验证查询词节点"""
-        refined_queries = state["refined_queries"]
-        
-        # 基本验证:去重、过滤空字符串、限制长度
-        validated_queries = []
-        seen = set()
-        
-        for query in refined_queries:
-            if query and len(query.strip()) > 0 and len(query.strip()) < 100:
-                if query.strip() not in seen:
-                    validated_queries.append(query.strip())
-                    seen.add(query.strip())
-        
-        # 限制最终数量
-        if len(validated_queries) > 10:
-            validated_queries = validated_queries[:10]
-        
-        # 确保至少有一个查询词
-        if not validated_queries:
-            validated_queries = [state["question"]]
-        logger.info(f"查询词验证结果: {validated_queries}")
-        state["refined_queries"] = validated_queries
-        return state
-
-    def _classify_queries(self, state: AgentState) -> AgentState:
-        """推测每个查询词的知识类型并写入result_queries"""
-        refined_queries = state.get("refined_queries", [])
-        # 使用大模型进行分类
-        result_items: List[Dict[str, str]] = self._classify_with_llm(refined_queries)
-        state["result_queries"] = result_items
-        return state
+    # 删除 refine/validate/classify 节点
     
     def _save_queries(self, state: AgentState) -> AgentState:
         """保存查询词到外部接口节点"""
-        refined_queries = state["refined_queries"]
-        question = state["question"]
+        refined_queries = state.get("refined_queries", [])
+        question = state.get("question", "")
+        knowledge_type = state.get("knowledgeType", "") or "内容知识"
         
         if not refined_queries:
             logger.warning("没有查询词需要保存")
             return state
         
-        # 调用外部接口保存查询词(按类型分组)
+        # 合并 knowledgeType 与每个查询词,形成提交数据
+        result_items: List[Dict[str, str]] = [
+            {"query": q, "knowledgeType": knowledge_type} for q in refined_queries
+        ]
+        state["result_queries"] = result_items
+        
         try:
             url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
             headers = {"Content-Type": "application/json"}
-
-            # 仅使用前一步的分类结果,不做即时分类
-            result_items: List[Dict[str, str]] = state.get("result_queries", [])
-            if not result_items:
-                logger.warning("缺少分类结果result_queries,跳过外部提交")
-                return state
-
-            if result_items:
-                with httpx.Client() as client:
-                    data_content = result_items
-                    logger.info(f"查询词保存数据: {data_content}")
-                    resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
-                    resp1.raise_for_status()
-                    logger.info(f"查询词保存结果: {resp1.text}")
-
+            with httpx.Client() as client:
+                data_content = result_items
+                logger.info(f"查询词保存数据: {data_content}")
+                # resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
+                # resp1.raise_for_status()
+                # logger.info(f"查询词保存结果: {resp1.text}")
             logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
-
         except httpx.HTTPError as e:
             logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
         except Exception as e:
@@ -303,7 +223,7 @@ class QueryGenerationAgent:
             raise ValueError("提取内容不是JSON数组")
         return data
 
-    async def generate_queries(self, question: str, task_id: int = 0) -> List[str]:
+    async def generate_queries(self, question: str, task_id: int = 0, knowledge_type: str = "") -> List[str]:
         """
         生成查询词的主入口
         
@@ -319,8 +239,7 @@ class QueryGenerationAgent:
             "initial_queries": [],
             "refined_queries": [],
             "result_queries": [],
-            "context": "",
-            "iteration_count": 0
+            "knowledgeType": knowledge_type or "内容知识"
         }
         
         try:
@@ -332,3 +251,4 @@ class QueryGenerationAgent:
                 self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
             # 降级处理:返回原始问题
             return [question]
+

+ 1 - 1
src/api/main.py

@@ -144,7 +144,7 @@ async def generate_queries(request: QuestionRequest):
         task_id = int(time.time() * 1000)
         
         # 创建任务记录,状态设置为0(待执行)
-        task_dao.create_task(task_id, request.question)
+        task_dao.create_task(task_id, request.question, knowledge_type=request.knowledgeType or "内容知识")
         logger.info(f"创建任务: {task_id},状态: 待执行")
         
         # 立即返回待执行状态

+ 6 - 5
src/database/models.py

@@ -69,7 +69,7 @@ class QueryTaskDAO:
     def __init__(self):
         self.db_manager = get_db_manager()
     
-    def create_task(self, task_id: int, question: str) -> bool:
+    def create_task(self, task_id: int, question: str, knowledge_type: str = "") -> bool:
         """
         创建新的查询任务
         
@@ -83,14 +83,15 @@ class QueryTaskDAO:
         try:
             with self.db_manager.get_cursor() as cursor:
                 sql = """
-                INSERT INTO knowledge_suggest_query (task_id, question, status)
-                VALUES (%s, %s, %s)
+                INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType)
+                VALUES (%s, %s, %s, %s)
                 ON DUPLICATE KEY UPDATE
                 question = VALUES(question),
                 status = VALUES(status),
-                querys = NULL
+                querys = NULL,
+                knowledgeType = VALUES(knowledgeType)
                 """
-                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING))
+                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type or "内容知识"))
                 return True
         except Exception as e:
             logger.error(f"创建任务失败: {e}")

+ 1 - 0
src/models/schemas.py

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
 class QuestionRequest(BaseModel):
     """问题请求模型"""
     question: str = Field(..., description="用户提出的问题", min_length=1, max_length=1000)
+    knowledgeType: Optional[str] = Field(default="内容知识", description="知识类型:内容知识/工具知识")
 
 
 class QueryResponse(BaseModel):

+ 41 - 44
src/tools/prompts.py

@@ -1,47 +1,44 @@
 """Prompt模板定义"""
 
-QUERY_GENERATION_PROMPT = """
-你是一个专业的搜索查询词生成助手。你的任务是根据用户的问题,生成多个相关的查询词,这些查询词将用于网络爬虫任务来收集相关信息。
-
-用户问题:{question}
-
-请按照以下要求生成查询词:
-1. 生成5-10个相关的查询词
-2. 查询词应该简洁明了,便于搜索
-3. 包含不同的角度和层次(如:基础概念、具体方法、实际应用等)
-4. 避免过于宽泛或过于狭窄的查询词
-5. 考虑同义词和相关术语
-6. 不要直接复述原问题,要生成具体的搜索关键词
-7. 每个查询词应该是独立的搜索词,而不是完整的问题
-
-输出格式:直接输出查询词,每行一个,不要编号或其他格式,也不需要其他额外的说明,仅输出查询词即可。
-
-示例:
-如果问题是"如何学习Python编程"
-输出可能是:
-Python编程入门
-Python基础教程
-Python学习路径
-Python编程实践
-Python语法学习
-Python开发环境搭建
-Python项目实战
-Python算法实现
-Python Web开发
-Python数据分析
-"""
-
-QUERY_REFINEMENT_PROMPT = """
-你是一个查询词优化助手。请对以下查询词进行评估和优化:
-
-原始问题:{question}
-生成的查询词:{queries}
-
-请:
-1. 评估每个查询词的相关性和有效性
-2. 移除重复或过于相似的查询词
-3. 优化表达不清晰的查询词
-4. 确保查询词覆盖问题的不同方面
-
-输出优化后的查询词列表,每行一个,无需其他说明。
+STRUCTURED_TOOL_DEMAND_PROMPT = """
+你是一个高级AI Agent Prompt工程师,精通内容策略、信息检索与结构化。你的任务是处理一份社交媒体内容创作的“工具需求清单”,旨在为后续全网寻找工具和工具知识、并最终训练大模型精准解构爆款内容奠定基础。
+请你遵循以下两个步骤,对提供的每一个工具需求进行处理:
+第一步:细化需求点
+将每个原始的需求句,精确拆解为以下三个核心要素:
+平台 (Platform): 指明该需求所针对的社交媒体平台(例如“小红书平台”)。
+核心任务 (Core Task): 概括性描述该工具在创作流程中要完成的主要功能或解决的主要问题(例如“选题阶段工具”、“素材收集”、“视觉制作”、“文案创作”、“发布管理”、“数据分析”等)。
+具体目标 (Specific Goal): 详细描述该工具在核心任务下要达成的具体效果或处理的具体内容(例如“分析‘打工人’相关话题的热度和趋势”、“获取高质量的猫咪或其他动物表情包”、“进行抠图、背景替换和文字叠加”、“生成幽默、接地气的‘打工人’文案”等)。
+第二步:分层关键词拓展
+基于第一步细化后的“核心任务”和“具体目标”,生成一个结构化的、分层的中文互联网搜索关键词列表。关键词总数应在7-15个之间。请严格按照以下层级和要求生成,并确保所有关键词的后缀都明确指向工具(例如:“工具”、“软件”、“App”、“网站”、“助手”、“生成器”等),除非关键词本身是一个工具名称或专有名词。
+通用发现类 (General Discovery Queries) - 粗粒度 (2-3个):
+目的: 泛指工具类型或大功能,用于初步探索和广撒网。
+规则: 结合platform和core_task,并附加明确的工具类后缀。
+示例: "小红书选题工具", "图片编辑软件", "AI写作App"
+主题功能类 (Themed Function Queries) - 细粒度 (3-5个):
+目的: 紧密结合specific_goal中的特定主题、对象、效果或限制,形成更精准的工具功能搜索词。
+规则: 必须融入specific_goal中的所有关键信息(如“打工人”、“猫咪表情包”、“幽默与推广”),并附加明确的工具类后缀。
+示例: "小红书打工人话题分析工具", "猫咪表情包下载网站", "结合幽默推广的竞品分析软件"
+操作/解决问题类 (How-to/Problem-Solving Queries) - 方法/手段工具导向 (2-4个):
+目的: 模拟用户遇到问题或需要学习具体操作时的查询,强调实现目标所需的“方法/手段”和对应的“工具”。
+规则: 严格遵循“根据【如何】【目的】的【方法/手段】及使用的【工具】是什么”的思考框架来构建,生成明确指向“方法工具”、“操作工具”、“解决方案工具”的查询词。关键词必须附加明确的工具类后缀。
+示例应用此框架:
+【目的】:分析“打工人”相关话题的热度和趋势
+【如何】:通过数据分析、趋势洞察
+【方法/手段】:数据分析方法,趋势洞察策略
+【工具】:是什么工具?
+生成Query: "如何分析小红书打工人话题热度趋势的工具" / "小红书打工人话题趋势分析方法工具"
+其他示例: "小红书图片抠图做表情包工具", "AI工具生成幽默文案方法", "小红书话题标签优化助手"
+输出格式要求:
+请以一个统一的JSON数组格式输出结果。数组中的每个对象代表一个原始工具需求,并包含以下字段:
+demand_id: 一个唯一的标识符,例如 "D_选题_001", "D_素材_001" 等,其中“选题”、“素材”等对应原始清单中的工具类别。
+original_demand: 原始的需求句。
+tool_category: 原始清单中的工具类别(如“选题工具”,“素材工具”等)。
+decomposed_elements: 一个嵌套对象,包含 platform, core_task, specific_goal。
+expanded_keywords: 一个嵌套对象,包含以下三个数组:
+general_discovery_queries: 通用发现类关键词列表。
+themed_function_queries: 主题功能类关键词列表。
+how_to_use_queries: 操作/解决问题类关键词列表。
+
+请严格只输出JSON数组,不要包含任何额外文本或解释。
 """
+ 

+ 0 - 98
src/tools/query_tool.py

@@ -1,98 +0,0 @@
-from typing import List, Dict, Any
-from langchain.tools import BaseTool
-from pydantic import BaseModel, Field
-import logging
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain.prompts import ChatPromptTemplate
-from langchain.schema import HumanMessage, SystemMessage
-
-from ..database.models import QueryTaskDAO, QueryTaskStatus
-from .prompts import QUERY_GENERATION_PROMPT
-
-logger = logging.getLogger(__name__)
-
-
-class QueryToolInput(BaseModel):
-    """查询工具输入模型"""
-    question: str = Field(..., description="用户的问题")
-    context: str = Field(default="", description="额外的上下文信息")
-    task_id: int = Field(..., description="任务ID")
-
-
-class SuggestQueryTool(BaseTool):
-    """建议查询词工具"""
-    
-    name: str = "suggest_query"
-    description: str = "根据用户问题生成多个相关的查询词,用于爬虫任务"
-    args_schema: type = QueryToolInput
-    task_dao: QueryTaskDAO = None
-    llm: ChatGoogleGenerativeAI = None
-    
-    def __init__(self):
-        super().__init__()
-        self.task_dao = QueryTaskDAO()
-        # 初始化 LLM
-        import os
-        self.llm = ChatGoogleGenerativeAI(
-            google_api_key=os.getenv("GEMINI_API_KEY", ""),
-            model=os.getenv("GEMINI_MODEL", "gemini-1.5-pro"),
-            temperature=0.7
-        )
-    
-    def _run(
-        self, 
-        question: str, 
-        context: str = "",
-        task_id: int = 0,
-        run_manager = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        """
-        根据问题生成查询词列表
-        
-        Args:
-            question: 用户问题
-            context: 额外上下文
-            task_id: 任务ID
-            run_manager: 运行管理器
-            
-        Returns:
-            查询词列表
-        """
-        try:
-            # 使用 LLM 和 QUERY_GENERATION_PROMPT 生成查询词
-            prompt = ChatPromptTemplate.from_messages([
-                SystemMessage(content=QUERY_GENERATION_PROMPT),
-                HumanMessage(content=question)
-            ])
-            
-            try:
-                response = self.llm.invoke(prompt.format_messages())
-                queries_text = response.content
-                queries = [q.strip() for q in queries_text.split('\n') if q.strip()]
-                
-                # 去重并限制数量
-                unique_queries = list(dict.fromkeys(queries))[:10]
-                
-                return unique_queries
-                
-            except Exception as e:
-                logger.error(f"LLM 生成查询词失败: {e}")
-                # 降级处理:返回原始问题
-                return [question]
-            
-        except Exception as e:
-            logger.error(f"生成查询词失败: {e}")
-            # 返回降级结果
-            return [question]
-    
-    async def _arun(
-        self, 
-        question: str, 
-        context: str = "",
-        task_id: int = 0,
-        run_manager = None,
-        **kwargs: Any,
-    ) -> List[str]:
-        """异步运行版本"""
-        return self._run(question, context, task_id, run_manager, **kwargs)

+ 2 - 2
src/tools/scheduler.py

@@ -80,8 +80,8 @@ class TaskScheduler:
             self.task_dao.update_task_status(task.task_id, QueryTaskStatus.RUNNING)
             
             try:
-                # 使用Agent生成查询词
-                queries = await self.agent.generate_queries(task.question, task.task_id)
+                # 使用Agent生成查询词,传入knowledgeType
+                queries = await self.agent.generate_queries(task.question, task.task_id, getattr(task, 'knowledgeType', '') or '')
                 
                 # 更新任务结果
                 success = self.task_dao.update_task_results(task.task_id, queries, QueryTaskStatus.SUCCESS)