Browse Source

clean_agent

丁云鹏 1 week ago
parent
commit
de6efa2e33
1 changed files with 241 additions and 55 deletions
  1. 241 55
      agents/clean_agent/tools.py

+ 241 - 55
agents/clean_agent/tools.py

@@ -1,67 +1,253 @@
-from langchain_core.tools import tool
-from typing import Annotated
-from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId, tool
+from langchain.tools import Tool
+from sqlalchemy.orm import Session
+from typing import Dict, Any, Tuple
+import logging
+from datetime import datetime
+import json
+import os
+import sys
 
-from langgraph.types import Command, interrupt
+# 添加项目根目录到系统路径
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
-# Define tools
-@tool
-def multiply(a: int, b: int) -> int:
-    """Multiply a and b.
+from database.db import SessionLocal, get_db
+from database.models import KnowledgeParsingContent, KnowledgeExtractionContent
+from gemini import GeminiProcessor
+
+# 配置日志
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
+# 配置常量
+BATCH_SIZE = 10  # 分批处理大小
+SCORE_THRESHOLD = 70  # 评分阈值
+
+# Define tools
+@Tool
+def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
+    """
+    知识评估与抽取工具。持续处理数据库中的数据,分批执行评估并创建KnowledgeExtractionContent对象。
+    对于评分大于70分的内容,会进行抽取并更新KnowledgeExtractionContent对象。
+    
     Args:
-        a: first int
-        b: second int
+        request_id: 请求ID,如果不提供则处理所有未处理的数据
+        query_word: 查询词,用于评估和抽取内容
+        
+    Returns:
+        str: "success" 表示处理完成,"no data" 表示没有数据需要处理
     """
-    return a * b
+    try:
+        db = SessionLocal()
+        try:
+            # 使用新的批量处理函数
+            result = execute_continuous_evaluation_extraction(request_id, db, query_word)
+            return result
+        finally:
+            db.close()
+    except Exception as e:
+        logger.error(f"评估抽取过程中出错: {e}")
+        return f"no data - 错误: {str(e)}"
 
+def execute_continuous_evaluation_extraction(request_id: str, db: Session, query_word: str) -> str:
+    """持续执行评估循环,直到数据库没有数据"""
+    total_processed = 0
+    
+    while True:
+        # 分批获取待评估的内容
+        contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE)
+        
+        if not contents:
+            if total_processed > 0:
+                logger.info(f"处理完成,共处理 {total_processed} 条内容")
+                return "success"
+            return "no data"
+        
+        # 批量评估内容并创建KnowledgeExtractionContent对象
+        evaluation_results = batch_evaluate_content(contents, db, request_id, query_word)
+        
+        # 对评分大于阈值的内容进行抽取
+        high_score_results = [result for result in evaluation_results if result["score"] >= SCORE_THRESHOLD]
+        if high_score_results:
+            logger.info(f"发现 {len(high_score_results)} 条高分内容,进行抽取")
+            batch_extract_and_save_content(high_score_results, db, request_id, query_word)
+        
+        total_processed += len(contents)
+        db.commit()
+    # 这里的代码永远不会被执行到,因为在while循环中,当contents为空时会返回
 
-@tool
-def add(a: int, b: int) -> int:
-    """Adds a and b.
+def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int) -> list:
+    """分批获取待评估的内容"""
+    query = db.query(KnowledgeParsingContent).filter(
+        KnowledgeParsingContent.status == 2  # 已完成提取的数据
+    )
+    
+    # 如果指定了request_id,则只处理该request_id的数据
+    if request_id:
+        query = query.filter(KnowledgeParsingContent.request_id == request_id)
+    
+    return query.limit(batch_size).all()
 
-    Args:
-        a: first int
-        b: second int
-    """
-    return a + b
+def batch_evaluate_content(contents: list, db: Session, request_id: str, query_word: str) -> list:
+    if not contents:
+        return []
+    
+    try:
+        # 批量调用大模型进行评估
+        evaluation_results_raw = batch_call_llm_for_evaluation(contents, query_word)
+        
+        # 处理评估结果
+        evaluation_results = []
+        
+        for i, (parsing_id, score, reason, parsing_data) in enumerate(evaluation_results_raw):
+            # 创建KnowledgeExtractionContent对象
+            extraction_content = KnowledgeExtractionContent(
+                request_id=request_id,
+                parsing_id=parsing_id,
+                score=score,
+                reason=reason,
+                create_at=datetime.now()
+            )
+            db.add(extraction_content)
+            
+            evaluation_results.append({
+                "parsing_id": parsing_id,
+                "score": score,
+                "reason": reason,
+                "parsing_data": parsing_data,
+                "extraction_content": extraction_content
+            })
+            
+        return evaluation_results
+            
+    except Exception as e:
+        logger.error(f"批量评估内容时出错: {e}")
+        # 将所有内容标记为处理失败
+        for content in contents:
+            content.status = 3  # 处理失败
+        return []
 
+def batch_extract_and_save_content(evaluation_results: list, db: Session, request_id: str, query_word: str) -> list:
+    if not evaluation_results:
+        return []
+    
+    # 批量调用大模型进行抽取
+    extraction_data_list = batch_call_llm_for_extraction(evaluation_results, query_word)
+    
+    # 保存抽取结果到数据库
+    success_ids = []
+    failed_ids = []
+    
+    for i, extraction_data in enumerate(extraction_data_list):
+        try:
+            evaluation_result = evaluation_results[i]
+            
+            # 更新已有对象的data字段和状态
+            existing_extraction.data = evaluation_result["extraction_content"]
+            existing_extraction.status = 2  # 处理完成
+            success_ids.append(parsing_id)
+        except Exception as e:
+            logger.error(f"处理抽取结果 {i} 时出错: {e}")
+            failed_ids.append(evaluation_results[i].get("parsing_id"))
+    
+    # 如果有失败的内容,将其标记为处理失败
+    if failed_ids:
+        logger.warning(f"有 {len(failed_ids)} 条内容抽取失败")
+        for result in evaluation_results:
+            if result.get("parsing_id") in failed_ids and "extraction_content" in result:
+                result["extraction_content"].status = 3  # 处理失败
+    
+    return success_ids
 
-@tool
-def divide(a: int, b: int) -> float:
-    """Divide a and b.
+# 读取提示词文件
+def read_prompt_file(file_path):
+    """从文件中读取提示词"""
+    try:
+        with open(file_path, 'r', encoding='utf-8') as file:
+            return file.read()
+    except Exception as e:
+        logger.error(f"读取提示词文件 {file_path} 失败: {str(e)}")
+        return ""
 
-    Args:
-        a: first int
-        b: second int
+# 初始化 Gemini 处理器和提示词
+gemini_processor = GeminiProcessor()
+
+# 加载评估和抽取提示词
+project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+evaluation_prompt_path = os.path.join(project_root, 'prompt', 'evaluation.md')
+extraction_prompt_path = os.path.join(project_root, 'prompt', 'extraction.md')
+
+# 打印路径信息,用于调试
+logger.info(f"评估提示词路径: {evaluation_prompt_path}")
+logger.info(f"抽取提示词路径: {extraction_prompt_path}")
+
+EVALUATION_PROMPT = read_prompt_file(evaluation_prompt_path)
+EXTRACTION_PROMPT = read_prompt_file(extraction_prompt_path)
+
+def batch_call_llm_for_evaluation(contents: list, query_word: str) -> list:
+    """批量调用大模型进行内容评估
     """
-    return a / b
-
-@tool
-def human_assistance(
-    name: str, birthday: str, tool_call_id: Annotated[str, InjectedToolCallId]
-) -> str:
-    """Request assistance from a human."""
-    human_response = interrupt(
-        {
-            "question": "Is this correct?",
-            "name": name,
-            "birthday": birthday,
-        },
-    )
-    if human_response.get("correct", "").lower().startswith("y"):
-        verified_name = name
-        verified_birthday = birthday
-        response = "Correct"
-    else:
-        verified_name = human_response.get("name", name)
-        verified_birthday = human_response.get("birthday", birthday)
-        response = f"Made a correction: {human_response}"
-
-    state_update = {
-        "name": verified_name,
-        "birthday": verified_birthday,
-        "messages": [ToolMessage(response, tool_call_id=tool_call_id)],
-    }
-    return Command(update=state_update)
+    # 准备批量评估内容
+    evaluation_contents = []
+    for content in contents:
+        evaluation_contents.append({
+            "query_word": query_word,
+            "content": content.parsing_data
+        })
+    
+    try:
+        # 批量调用 Gemini 进行评估
+        results = gemini_processor.batch_process(evaluation_contents, EVALUATION_PROMPT)
+        
+        # 处理返回结果
+        evaluation_results = []
+        for i, result in enumerate(results):
+            parsing_id = contents[i].id
+            parsing_data = contents[i].parsing_data
+            
+            if isinstance(result, dict) and "score" in result:
+                # 正常结果
+                score = result.get("score", -2)
+                reason = result.get("reason", "")
+            else:
+                # 异常结果
+                score = -2
+                reason = "评估失败"
+            
+            evaluation_results.append((parsing_id, score, reason, parsing_data))
+        
+        return evaluation_results
+        
+    except Exception as e:
+        logger.error(f"批量评估过程异常: {str(e)}")
+        # 返回默认结果
+        return [(content.id, 0, "评估过程异常", content.data if hasattr(content, 'data') else (content.parsing_data or "")) for content in contents]
+
+def batch_call_llm_for_extraction(evaluation_results: list, query_word: str) -> list:
+    # 准备批量抽取内容
+    extraction_contents = []
+    for result in evaluation_results:
+        parsing_data = result.get("parsing_data", "")
+        extraction_contents.append({
+            "query_word": query_word,
+            "content": parsing_data
+        })
+    
+    try:
+        # 批量调用 Gemini 进行抽取
+        results = gemini_processor.batch_process(extraction_contents, EXTRACTION_PROMPT)
+        
+        # 处理返回结果
+        extraction_results = []
+        for i, result in enumerate(results):
+            # 确保结果包含必要的字段
+            if not isinstance(result, dict):
+                result = {"extracted_data": str(result)}
+
+            extraction_results.append(json.dumps(result, ensure_ascii=False))
+        
+        return extraction_results
+        
+    except Exception as e:
+        logger.error(f"批量抽取过程异常: {str(e)}")
+        # 返回空结果
+        return ["{}"] * len(evaluation_results)