
clean_agent

丁云鹏 committed 1 week ago · parent commit ea7906af8e
7 changed files with 215 additions and 131 deletions:

  1. .env (+5 −1)
  2. agents/clean_agent/agent.py (+67 −49)
  3. agents/clean_agent/tools.py (+114 −63)
  4. database/db.py (+10 −3)
  5. database/models.py (+0 −7)
  6. gemini.py (+17 −8)
  7. requirements.txt (+2 −0)

+ 5 - 1
.env

@@ -14,7 +14,7 @@ COZE_BOT_ID=7537570163895812146
 GEMINI_API_KEY=AIzaSyAkt1l9Kw1CQgHFzTpla0vgt0OE53fr-BI
 
 # Proxy
-DYNAMIC_HTTP_PROXY=http://t10952018781111:1ap37oc3@d844.kdltps.com:15818
+DYNAMIC_HTTP_PROXY=http://127.0.0.1:7890
 
 # GRPC
 CONTAINER_GRPC_HOST=192.168.203.112
@@ -34,3 +34,7 @@ LANGCHAIN_API_KEY=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
 # Project name (optional)
 LANGCHAIN_PROJECT=knowledge-agent
 
+
+OPENAI_API_KEY=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
+TAVILY_API_KEY=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
+
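
With this change the dynamic KDL proxy gives way to a local one, and the OpenAI and Tavily keys move out of the source files into .env. A minimal sketch of how these variables are meant to be consumed (assuming python-dotenv, with the variable names from the file above):

    import os

    from dotenv import load_dotenv

    load_dotenv()  # read .env from the working directory

    # Route outbound traffic through the configured proxy, as main() in
    # agents/clean_agent/agent.py now does.
    proxy = os.getenv("DYNAMIC_HTTP_PROXY")
    if proxy:
        os.environ["HTTP_PROXY"] = proxy
        os.environ["HTTPS_PROXY"] = proxy

    # langchain-openai and the Tavily client pick these up from the environment.
    openai_key = os.getenv("OPENAI_API_KEY")
    tavily_key = os.getenv("TAVILY_API_KEY")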

+ 67 - 49
agents/clean_agent/agent.py

@@ -9,22 +9,20 @@ from tools import evaluation_extraction_tool
 
 from langgraph.prebuilt import ToolNode, tools_condition
 from langgraph.checkpoint.memory import InMemorySaver
+import requests
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
 
 graph=None
 llm_with_tools=None
-os.environ["OPENAI_API_KEY"] = "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"
-os.environ["TAVILY_API_KEY"] = "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"
 
 prompt="""
-Hello! I am an intelligent data assistant, built to help you quickly retrieve and analyze evaluation information.
-
----
-### My Role:
+### Role:
 I act as your "evaluation report retrieval specialist". When you need to understand the evaluations around a specific topic, I use the powerful [evaluation extraction tool] (evaluation_extraction_tool) behind the scenes to precisely retrieve and organize the relevant evaluation reports, summaries, or key metrics from the data source and present them to you.
 
----
-### Your Goal:
-Your goals are:
+### Goal:
 1.  Quickly obtain the relevant evaluation reports, data summaries, or key metrics for a given topic (keyword), so you can understand in depth how some aspect (such as product performance, service quality, market feedback, or project assessment) has been evaluated.
 2.  Receive a unique identifier for every query, so you can easily track and manage your requests and keep the data traceable.
 
@@ -45,8 +43,11 @@ prompt="""
 3.  **I return the results:** once the [evaluation extraction tool] finishes, I return the extracted evaluation summaries, links, or related data to you.
 
 ---
-### Please provide your information in the following format:
+### Input:
 {input}
+
+### Output:
+Based on the execution result, output the corresponding string ("success", "no data", or some other message)
 """
 
 class State(TypedDict):
@@ -68,51 +69,68 @@ def execute_agent_with_api(user_input: str):
     
     # Replace the {input} placeholder in the prompt with the user input
     formatted_prompt = prompt.replace("{input}", user_input)
-    
-    # Initialize graph and llm_with_tools if they have not been initialized yet
-    if graph is None or llm_with_tools is None:
-        llm = init_chat_model("openai:gpt-4.1")
-        tools = [evaluation_extraction_tool]
-        llm_with_tools = llm.bind_tools(tools=tools)
+
+    try:
+        # Initialize graph and llm_with_tools if they have not been initialized yet
+        if graph is None or llm_with_tools is None:
+            try:
+                llm = init_chat_model("openai:gpt-4.1")
+                tools = [evaluation_extraction_tool]
+                llm_with_tools = llm.bind_tools(tools=tools)
+                
+                # Build the graph
+                graph_builder = StateGraph(State)
+                graph_builder.add_node("chatbot", chatbot)
+                
+                tool_node = ToolNode(tools=tools)
+                graph_builder.add_node("tools", tool_node)
+                
+                graph_builder.add_conditional_edges(
+                    "chatbot",
+                    tools_condition,
+                )
+                graph_builder.add_edge("tools", "chatbot")
+                graph_builder.add_edge(START, "chatbot")
+                
+                memory = InMemorySaver()
+                graph = graph_builder.compile(checkpointer=memory)
+            except Exception as e:
+                return f"Failed to initialize the agent: {str(e)}"
         
-        # Build the graph
-        graph_builder = StateGraph(State)
-        graph_builder.add_node("chatbot", chatbot)
+        # Generate a unique thread ID
+        import uuid
+        thread_id = str(uuid.uuid4())
         
-        tool_node = ToolNode(tools=tools)
-        graph_builder.add_node("tools", tool_node)
+        # Run the agent and collect the results
+        results = []
+        config = {"configurable": {"thread_id": thread_id}}
         
-        graph_builder.add_conditional_edges(
-            "chatbot",
-            tools_condition,
-        )
-        graph_builder.add_edge("tools", "chatbot")
-        graph_builder.add_edge(START, "chatbot")
+        # Use the formatted prompt as the user input
+        for event in graph.stream({"messages": [{"role": "user", "content": formatted_prompt}]}, config, stream_mode="values"):
+            # With stream_mode="values" each event is the full graph state, so
+            # only the latest message is recorded (the previous inner loop over
+            # event.values() appended the same message once per state key)
+            if "messages" in event and len(event["messages"]) > 0:
+                message = event["messages"][-1]
+                results.append(message.content)
         
-        memory = InMemorySaver()
-        graph = graph_builder.compile(checkpointer=memory)
-    
-    # Generate a unique thread ID
-    import uuid
-    thread_id = str(uuid.uuid4())
-    
-    # Run the agent and collect the results
-    results = []
-    config = {"configurable": {"thread_id": thread_id}}
-    
-    # Use the formatted prompt as the user input
-    for event in graph.stream({"messages": [{"role": "user", "content": formatted_prompt}]}, config, stream_mode="values"):
-        for value in event.values():
-            # Save the message content
-            if "messages" in event and len(event["messages"]) > 0:
-                message = event["messages"][-1]
-                results.append(message.content)
-    
-    # Return the results
-    return "\n".join(results) if results else "The agent finished but returned no results"
+        # Return the results
+        return "\n".join(results) if results else "The agent finished but returned no results"
+    except requests.exceptions.ConnectionError as e:
+        return f"OpenAI API connection error: {str(e)}\nCheck the network connection or proxy settings."
+    except Exception as e:
+        return f"Error while running the agent: {str(e)}"
 
 def main():
-    execute_agent_with_api("Can you look up when LangGraph was released? When you have the answer, use the human_assistance tool for review.")
+    print("Starting agent run")
+    # Configure the proxy
+    proxy_url = os.getenv('DYNAMIC_HTTP_PROXY')
+    if proxy_url:
+        os.environ["OPENAI_PROXY"] = proxy_url
+        os.environ["HTTPS_PROXY"] = proxy_url
+        os.environ["HTTP_PROXY"] = proxy_url
+    # Run the agent
+    result = execute_agent_with_api('{"query_word":"账号做内容的方法","request_id":"REQUEST_001"}')
+    print(result)
 
 if __name__ == '__main__':
     main()
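
One caveat in the rewritten execute_agent_with_api: it assigns the module-level graph and llm_with_tools, which in Python only takes effect with a global declaration at the top of the function; without one, the assignments create locals and the compiled graph is rebuilt on every call. A sketch of the intended caching, reusing State, chatbot, and the tool from this module (a sketch under those assumptions, not the committed code):

    from langchain.chat_models import init_chat_model
    from langgraph.checkpoint.memory import InMemorySaver
    from langgraph.graph import START, StateGraph
    from langgraph.prebuilt import ToolNode, tools_condition

    graph = None
    llm_with_tools = None

    def get_graph():
        # 'global' makes the assignments below update the module-level cache
        # instead of creating function locals that are discarded on return.
        global graph, llm_with_tools
        if graph is None:
            llm = init_chat_model("openai:gpt-4.1")
            tools = [evaluation_extraction_tool]
            llm_with_tools = llm.bind_tools(tools=tools)
            builder = StateGraph(State)
            builder.add_node("chatbot", chatbot)
            builder.add_node("tools", ToolNode(tools=tools))
            builder.add_conditional_edges("chatbot", tools_condition)
            builder.add_edge("tools", "chatbot")
            builder.add_edge(START, "chatbot")
            graph = builder.compile(checkpointer=InMemorySaver())
        return graph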

+ 114 - 63
agents/clean_agent/tools.py

@@ -1,4 +1,4 @@
-from langchain.tools import Tool
+from langchain.tools import tool
 from sqlalchemy.orm import Session
 from typing import Dict, Any, Tuple
 import logging
@@ -6,6 +6,7 @@ from datetime import datetime
 import json
 import os
 import sys
+import re
 
 # Add the project root to the system path
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
@@ -23,7 +24,13 @@ BATCH_SIZE = 10  # batch size
 SCORE_THRESHOLD = 70  # score threshold
 
 # Define tools
-@Tool
+# evaluation_extraction_tool = Tool(
+#     func=lambda request_id, query_word: _evaluation_extraction_tool(request_id, query_word),
+#     name="evaluation_extraction_tool",
+#     description="Knowledge evaluation and extraction tool; processes data in the database, runs the evaluation, and extracts content"
+# )
+
+@tool
 def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
     """
     Knowledge evaluation and extraction tool. Continuously processes data in the database, runs the evaluation in batches, and creates KnowledgeExtractionContent objects.
@@ -36,47 +43,79 @@ def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
     Returns:
         str: "success" when processing is complete, "no data" when there is nothing to process
     """
-    try:
-        db = SessionLocal()
+    # Use a context manager so the database connection is released automatically
+    with SessionLocal() as db:
         try:
             # Use the new batch-processing function
             result = execute_continuous_evaluation_extraction(request_id, db, query_word)
             return result
-        finally:
-            db.close()
-    except Exception as e:
-        logger.error(f"Error during evaluation extraction: {e}")
-        return f"no data - error: {str(e)}"
+        except Exception as e:
+            # Make sure the transaction is rolled back when an exception occurs
+            db.rollback()
+            logger.error(f"Error during evaluation extraction: {e}")
+            return f"no data - error: {str(e)}"
 
 def execute_continuous_evaluation_extraction(request_id: str, db: Session, query_word: str) -> str:
     """Run the evaluation loop until the database has no more data"""
+    logger.info(f"Starting processing, request_id: {request_id}, query_word: {query_word}")
+    
     total_processed = 0
+    offset = 0
     
-    while True:
-        # Fetch a batch of contents awaiting evaluation
-        contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE)
-        
-        if not contents:
-            if total_processed > 0:
-                logger.info(f"Processing finished, {total_processed} items processed in total")
-                return "success"
-            return "no data"
-        
-        # Evaluate the contents in bulk and create KnowledgeExtractionContent objects
-        evaluation_results = batch_evaluate_content(contents, db, request_id, query_word)
-        
-        # Extract the contents whose score exceeds the threshold
-        high_score_results = [result for result in evaluation_results if result["score"] >= SCORE_THRESHOLD]
-        if high_score_results:
-            logger.info(f"Found {len(high_score_results)} high-scoring items, extracting")
-            batch_extract_and_save_content(high_score_results, db, request_id, query_word)
-        
-        total_processed += len(contents)
-        db.commit()
+    try:
+        while True:
+            # Fetch a batch of contents awaiting evaluation, paginated via offset
+            contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE, offset)
+            
+            logger.info(f"Fetched {len(contents)} items awaiting evaluation")
+
+            if not contents:
+                if total_processed > 0:
+                    logger.info(f"Processing finished, {total_processed} items processed in total")
+                    db.commit()  # make sure the final batch is committed
+                    return "success"
+                return "no data"
+            
+            try:
+                # Evaluate the contents in bulk and create KnowledgeExtractionContent objects
+                evaluation_results = batch_evaluate_content(contents, db, request_id, query_word)
+                
+                print(f"""evaluation_results: {evaluation_results}""")
+
+                # Extract the contents whose score exceeds the threshold
+                high_score_results = [result for result in evaluation_results if result["score"] >= SCORE_THRESHOLD]
+                if high_score_results:
+                    logger.info(f"Found {len(high_score_results)} high-scoring items, extracting")
+                    batch_extract_and_save_content(high_score_results, db, request_id, query_word)
+                
+                total_processed += len(contents)
+                offset += len(contents)  # advance the offset so the next iteration fetches the next batch
+                db.commit()  # commit the transaction after each batch
+            except Exception as e:
+                # Roll back the transaction when the current batch fails
+                db.rollback()
+                logger.error(f"Error while processing a batch: {e}")
+                # Continue with the next batch
+                offset += len(contents)
+    except Exception as e:
+        # On a fatal error, roll back and re-raise
+        db.rollback()
+        logger.error(f"Error in the evaluation extraction loop: {e}")
+        raise
     # This point is never reached: the while loop returns once contents is empty
 
-def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int) -> list:
-    """Fetch a batch of contents awaiting evaluation"""
+def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int, offset: int = 0) -> list:
+    """Fetch a batch of contents awaiting evaluation
+    
+    Args:
+        request_id: request ID
+        db: database session
+        batch_size: batch size
+        offset: offset used for pagination
+        
+    Returns:
+        A list of contents awaiting evaluation
+    """
     query = db.query(KnowledgeParsingContent).filter(
         KnowledgeParsingContent.status == 2  # rows whose parsing has completed
     )
@@ -85,7 +124,7 @@ def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size:
     if request_id:
         query = query.filter(KnowledgeParsingContent.request_id == request_id)
     
-    return query.limit(batch_size).all()
+    return query.offset(offset).limit(batch_size).all()
 
 def batch_evaluate_content(contents: list, db: Session, request_id: str, query_word: str) -> list:
     if not contents:
@@ -95,6 +134,7 @@ def batch_evaluate_content(contents: list, db: Session, request_id: str, query_w
         # Batch-call the LLM for evaluation
         evaluation_results_raw = batch_call_llm_for_evaluation(contents, query_word)
         
+        print(evaluation_results_raw)
         # Process the evaluation results
         evaluation_results = []
         
@@ -130,33 +170,42 @@ def batch_extract_and_save_content(evaluation_results: list, db: Session, reques
     if not evaluation_results:
         return []
     
-    # Batch-call the LLM for extraction
-    extraction_data_list = batch_call_llm_for_extraction(evaluation_results, query_word)
-    
-    # Save the extraction results to the database
-    success_ids = []
-    failed_ids = []
-    
-    for i, extraction_data in enumerate(extraction_data_list):
-        try:
-            evaluation_result = evaluation_results[i]
-            
-            # Update the existing object's data field and status
-            existing_extraction.data = evaluation_result["extraction_content"]
-            existing_extraction.status = 2  # processing finished
-            success_ids.append(parsing_id)
-        except Exception as e:
-            logger.error(f"Error while handling extraction result {i}: {e}")
-            failed_ids.append(evaluation_results[i].get("parsing_id"))
-    
-    # If any items failed, mark them as failed
-    if failed_ids:
-        logger.warning(f"{len(failed_ids)} items failed extraction")
-        for result in evaluation_results:
-            if result.get("parsing_id") in failed_ids and "extraction_content" in result:
-                result["extraction_content"].status = 3  # processing failed
-    
-    return success_ids
+    try:
+        # Batch-call the LLM for extraction
+        extraction_data_list = batch_call_llm_for_extraction(evaluation_results, query_word)
+        
+        # Save the extraction results to the database
+        success_ids = []
+        failed_ids = []
+        
+        for i, extraction_data in enumerate(extraction_data_list):
+            try:
+                evaluation_result = evaluation_results[i]
+                parsing_id = evaluation_result.get("parsing_id")
+                
+                if "extraction_content" in evaluation_result and parsing_id:
+                    # Update the existing object's data field and status
+                    extraction_content = evaluation_result["extraction_content"]
+                    extraction_content.data = extraction_data
+                    extraction_content.status = 2  # processing finished
+                    success_ids.append(parsing_id)
+            except Exception as e:
+                logger.error(f"Error while handling extraction result {i}: {e}")
+                if i < len(evaluation_results):
+                    failed_ids.append(evaluation_results[i].get("parsing_id"))
+        
+        # If any items failed, mark them as failed
+        if failed_ids:
+            logger.warning(f"{len(failed_ids)} items failed extraction")
+            for result in evaluation_results:
+                if result.get("parsing_id") in failed_ids and "extraction_content" in result:
+                    result["extraction_content"].status = 3  # processing failed
+        
+        return success_ids
+    except Exception as e:
+        logger.error(f"Error while batch-extracting and saving contents: {e}")
+        db.rollback()  # make sure the transaction is rolled back on exceptions
+        return []
 
 # Read a prompt file
 def read_prompt_file(file_path):
@@ -178,9 +227,9 @@ extraction_prompt_path = os.path.join(project_root, 'prompt', 'extraction.md')
 
 # Print the path info for debugging
 logger.info(f"Evaluation prompt path: {evaluation_prompt_path}")
-logger.info(f"Extraction prompt path: {extraction_prompt_path}")
-
 EVALUATION_PROMPT = read_prompt_file(evaluation_prompt_path)
+
+logger.info(f"Extraction prompt path: {extraction_prompt_path}")
 EXTRACTION_PROMPT = read_prompt_file(extraction_prompt_path)
 
 def batch_call_llm_for_evaluation(contents: list, query_word: str) -> list:
@@ -197,10 +246,11 @@ def batch_call_llm_for_evaluation(contents: list, query_word: str) -> list:
     try:
         # Batch-call Gemini for evaluation
         results = gemini_processor.batch_process(evaluation_contents, EVALUATION_PROMPT)
-        
+       
         # Process the returned results
         evaluation_results = []
         for i, result in enumerate(results):
+            if isinstance(result, str):
+                # Strip the Markdown ```json fences Gemini wraps around its reply
+                result = re.sub(r'^\s*```json|\s*```\s*$', '', result, flags=re.MULTILINE).strip()
             parsing_id = contents[i].id
             parsing_data = contents[i].parsing_data
             
@@ -239,6 +289,7 @@ def batch_call_llm_for_extraction(evaluation_results: list, query_word: str) ->
         # Process the returned results
         extraction_results = []
         for i, result in enumerate(results):
+            if isinstance(result, str):
+                # Strip the Markdown ```json fences before further handling
+                result = re.sub(r'^\s*```json|\s*```\s*$', '', result, flags=re.MULTILINE).strip()
             # Make sure the result contains the required fields
             if not isinstance(result, dict):
                 result = {"extracted_data": str(result)}

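Two things are going on in this file: the offset parameter fixes a potential infinite loop (evaluation does not change KnowledgeParsingContent.status, so without pagination the same first batch would be fetched forever), and the new re.sub calls strip the Markdown fences Gemini often wraps around JSON output. Since batch_process can also yield error dicts, on which re.sub would raise a TypeError, a guarded helper along these lines (hypothetical, not part of the commit) keeps that cleanup in one place:

    import json
    import re

    def parse_llm_json(raw):
        """Strip Markdown ```json fences from an LLM reply and parse it."""
        if not isinstance(raw, str):
            return raw  # error dicts from GeminiProcessor pass through unchanged
        cleaned = re.sub(r'^\s*```json|\s*```\s*$', '', raw, flags=re.MULTILINE).strip()
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return {"extracted_data": cleaned}  # same fallback shape as above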
+ 10 - 3
database/db.py

@@ -10,7 +10,7 @@ load_dotenv()
 # Database connection configuration
 DATABASE_URL = os.getenv(
     "DATABASE_URL", 
-    "mysql+pymysql://wqsd:wqsd@2025@rm-bp13g3ra2f59q49xs.mysql.rds.aliyuncs.com:3306/ai_knowledge?charset=utf8&connect_timeout=30&read_timeout=30&write_timeout=30"
+    "mysql+pymysql://wqsd:wqsd%402025@knowledge.rwlb.rds.aliyuncs.com:3306/ai_knowledge?charset=utf8&connect_timeout=60&read_timeout=300&write_timeout=300"
 )
 
 # Create the synchronous engine and session factory
@@ -18,8 +18,15 @@ engine = create_engine(
     DATABASE_URL,
     pool_size=10,
     max_overflow=20,
-    pool_timeout=30,
-    echo=False  # set to True to log SQL
+    pool_timeout=60,  # longer pool checkout timeout
+    pool_recycle=1800,  # recycle connections after 30 minutes so they are not held open too long
+    pool_pre_ping=True,  # ping before each checkout to make sure the connection is still alive
+    echo=False,  # set to True to log SQL
+    connect_args={
+        "connect_timeout": 60,  # connection timeout, seconds
+        "read_timeout": 300,  # read timeout, seconds
+        "write_timeout": 300,  # write timeout, seconds
+    }
 )
 
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
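
pool_pre_ping and pool_recycle guard against the stale-connection errors that long-lived MySQL sessions tend to hit, and the longer read/write timeouts match the batch LLM calls that now happen while a session is open. Usage follows the pattern tools.py adopts, assuming SQLAlchemy 1.4+ where Session supports the context-manager protocol:

    from database.db import SessionLocal
    from database.models import KnowledgeParsingContent

    with SessionLocal() as db:  # the session and its pooled connection are released on exit
        try:
            batch = (
                db.query(KnowledgeParsingContent)
                .filter(KnowledgeParsingContent.status == 2)
                .limit(10)
                .all()
            )
            db.commit()
        except Exception:
            db.rollback()  # keep the pooled connection usable after a failure
            raise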

+ 0 - 7
database/models.py

@@ -5,9 +5,6 @@ from .db import Base
 
 class KnowledgeParsingContent(Base):
     __tablename__ = 'knowledge_parsing_content'
-    __table_args__ = {
-        'comment': 'content parsing table'
-    }
     
     id = Column(BigInteger, primary_key=True, autoincrement=True)
     content_id = Column(String(128), nullable=False)
@@ -23,10 +20,6 @@ class KnowledgeParsingContent(Base):
 
 class KnowledgeExtractionContent(Base):
     __tablename__ = 'knowledge_extraction_content'
-    __table_args__ = (
-        Index('idx_request_id', 'request_id'),  # create the index
-        {'comment': 'content extraction table'}
-    )
     
     id = Column(BigInteger, primary_key=True, autoincrement=True)
     request_id = Column(String(128), nullable=False)
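
Dropping __table_args__ removes the table comments and the idx_request_id index declaration on knowledge_extraction_content. If that index still exists in the database itself the change is cosmetic; if not, request_id lookups on the table fall back to full scans. Should it need to be re-declared outside the model body, a sketch (assuming the column definitions above):

    from sqlalchemy import Index

    # Attached to the table's metadata, so Base.metadata.create_all()
    # will emit it just like an inline __table_args__ entry.
    Index("idx_request_id", KnowledgeExtractionContent.request_id)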

+ 17 - 8
gemini.py

@@ -26,24 +26,32 @@ class GeminiProcessor:
         
         # Configure Gemini
         genai.configure(api_key=self.api_key)
-        self.model = genai.GenerativeModel('gemini-2.5-flash')
+
     
     def process(self, content: Any, system_prompt: str) -> Dict[str, Any]:
 
         try:
-            # Build the full prompt
-            full_prompt = f"{system_prompt}\n\nContent: {json.dumps(content, ensure_ascii=False)}"
+            # Normalize the input content format
+            if isinstance(content, dict):
+                # Convert dicts to a JSON string
+                formatted_content = json.dumps(content, ensure_ascii=False)
+            else:
+                formatted_content = content
+                
+            # Create a model instance with the system_instruction attached
+            model_with_system = genai.GenerativeModel(
+                'gemini-2.5-flash',
+                system_instruction=system_prompt
+            )
             
             # Call the Gemini API
-            response = self.model.generate_content(
-                contents=content,
-                config=types.GenerateContentConfig(
-                    system_instruction=system_prompt
-                )
+            response = model_with_system.generate_content(
+                contents=formatted_content
             )
             
             # Try to parse the JSON response
             try:
+                # Return the raw text; callers strip the code fences and parse
+                # the JSON themselves (the json.loads below is now unreachable)
+                return response.text
                 result = json.loads(response.text)
                 return result
             except json.JSONDecodeError:
@@ -51,6 +59,7 @@ class GeminiProcessor:
                 return {"result": response.text, "raw_response": response.text}
                 
         except Exception as e:
+            print(f"Gemini API call failed: {e}")
             return {"error": str(e), "content": content} 
             
     def batch_process(self, contents: list, system_prompt: str) -> list:
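
The model is now constructed per call so that system_instruction (supported by google-generativeai model constructors since 0.5) can carry the prompt, and the early return hands raw text back to the callers, which strip the fences and parse JSON themselves. A minimal sketch of the rewritten call path (the key lookup and prompts are placeholders):

    import os

    import google.generativeai as genai

    genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

    # One model instance per system prompt, as process() now builds it.
    model = genai.GenerativeModel(
        "gemini-2.5-flash",
        system_instruction="Score the content and reply with JSON only.",
    )
    response = model.generate_content("content to evaluate")
    print(response.text)  # raw text; tools.py strips the ```json fences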

+ 2 - 0
requirements.txt

@@ -13,3 +13,5 @@ uvicorn[standard]>=0.35.0
 langgraph==0.6.6
 langsmith==0.4.16
 langchain-openai==0.3.31
+
+google-generativeai==0.8.5