123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- from langchain_core.tools import tool
- from sqlalchemy.orm import Session
- from datetime import datetime
- import json
- import os
- import sys
- import re
- # 添加项目根目录到系统路径
- sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
- from gemini import GeminiProcessor
- from database.db import SessionLocal, get_db
- from database.models import KnowledgeParsingContent, KnowledgeExtractionContent
- from utils.logging_config import get_logger
# Module-wide logger.
logger = get_logger('CleanTools')

# Tunables.
BATCH_SIZE = 5  # rows fetched per evaluation batch
SCORE_THRESHOLD = 70  # minimum content score required before extraction runs

# Gemini processor instance shared by the LLM calls in this module.
processor = GeminiProcessor()

# Prompt templates live in <project root>/prompt/*.md; the project root is
# three dirname() steps up from this file's absolute path.
_this_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(_this_dir))
_prompt_dir = os.path.join(project_root, 'prompt')
consistency_prompt_path = os.path.join(_prompt_dir, 'consistency.md')
evaluation_prompt_path = os.path.join(_prompt_dir, 'evaluation.md')
extraction_prompt_path = os.path.join(_prompt_dir, 'extraction.md')
def read_prompt_file(file_path):
    """Load a prompt template from disk.

    Args:
        file_path: Path of the UTF-8 encoded prompt file.

    Returns:
        The full text of the file, or an empty string when the file cannot
        be read (the failure is logged instead of raised).
    """
    prompt_text = ""
    try:
        with open(file_path, mode='r', encoding='utf-8') as prompt_file:
            prompt_text = prompt_file.read()
    except Exception as err:
        # Best effort: a missing/unreadable prompt degrades to "" rather than
        # aborting the caller.
        logger.error(f"读取提示词文件 {file_path} 失败: {str(err)}")
    return prompt_text
# Prompt templates, read once at import time in the order consistency,
# evaluation, extraction. read_prompt_file yields "" for an unreadable file.
CONSISTENCY_PROMPT, EVALUATION_PROMPT, EXTRACTION_PROMPT = (
    read_prompt_file(prompt_path)
    for prompt_path in (
        consistency_prompt_path,
        evaluation_prompt_path,
        extraction_prompt_path,
    )
)
@tool
def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
    """
    知识评估与抽取工具。持续处理数据库中的数据,分批执行评估并创建KnowledgeExtractionContent对象。
    对于评分大于70分的内容,会进行抽取并更新KnowledgeExtractionContent对象。

    Args:
        request_id: 请求ID,如果不提供则处理所有未处理的数据
        query_word: 查询词,用于评估和抽取内容

    Returns:
        str: "success" 表示处理完成,"no data" 表示没有数据需要处理
    """
    # The docstring above is intentionally kept verbatim: LangChain's `@tool`
    # decorator uses it as the tool description, so it is runtime text rather
    # than a plain comment.
    #
    # NOTE(review): that description does not match the code in this file --
    # `request_id` is a required argument, and the pipeline returns "no data"
    # (or "no data - 错误: <message>" on failure) but never "success".
    # Confirm the intended contract before rewording the description.
    #
    # This wrapper used to be a line-for-line copy of `evaluation_extraction`.
    # It now delegates to it, so session handling, rollback and the error
    # string live in one place.
    return evaluation_extraction(request_id, query_word)
def evaluation_extraction(request_id: str, query_word: str) -> str:
    """Run the evaluation/extraction pipeline for one request in its own session.

    Opens a ``SessionLocal`` session, hands it to
    ``execute_continuous_evaluation_extraction`` and turns any exception into
    an error string after rolling the session back.

    Args:
        request_id: Request whose parsed rows are evaluated; forwarded as-is.
        query_word: Query term the content is evaluated against; forwarded as-is.

    Returns:
        The pipeline's return value, or ``"no data - 错误: <message>"`` when an
        exception was raised.
    """
    # The context manager closes the session on exit, whatever happens below.
    with SessionLocal() as session:
        try:
            return execute_continuous_evaluation_extraction(request_id, session, query_word)
        except Exception as exc:
            # Discard any half-finished transaction before reporting the failure.
            session.rollback()
            logger.error(f"评估抽取过程中出错: {exc}")
            return f"no data - 错误: {str(exc)}"
# A failed batch is rolled back, so the next fetch returns the very same rows.
# Without a bound, a persistent error would make the loop spin forever.
_MAX_CONSECUTIVE_BATCH_FAILURES = 3


def _build_extraction_record(request_id, query_word, content):
    """Run the staged LLM checks for one parsed row and return its unsaved extraction record."""
    # Stage 1: consistency check. A record is created whatever the verdict.
    logger.info(f"正在进行一致性评估:{content.id}")
    consistency_result, reason_str = evaluate_consistency(query_word, content.parsing_data)

    record = KnowledgeExtractionContent(
        request_id=request_id,
        parsing_id=content.id,
        consistency=consistency_result,
        consistency_reason=reason_str,
        create_at=datetime.now(),
    )

    if consistency_result == '高':
        # Stage 2: content scoring, only for rows judged highly consistent.
        logger.info(f"正在进行内容评估:{content.id}")
        score, score_reason = evaluate_content(query_word, content.parsing_data)
        record.score = score
        record.score_reason = score_reason
        if score >= SCORE_THRESHOLD:
            # Stage 3: cleaning/extraction, only for rows that reach the threshold.
            logger.info(f"正在进行清洗提取:{content.id}")
            extracted_data, clean_reason = extract_content(query_word, content.parsing_data)
            record.data = extracted_data
            record.clean_reason = clean_reason
            # NOTE(review): the meaning of status 2 is defined by the model,
            # presumably "extracted" -- confirm.
            record.status = 2
    return record


def execute_continuous_evaluation_extraction(request_id: str, db: Session, query_word: str) -> str:
    """Evaluate pending rows batch by batch until none are left.

    Each iteration fetches up to ``BATCH_SIZE`` rows through
    ``get_batch_contents_for_evaluation``, builds one
    ``KnowledgeExtractionContent`` per row and commits the batch as a single
    transaction. A batch that raises is rolled back and retried on the next
    iteration, because its rows are still pending.

    Args:
        request_id: Request whose rows are processed.
        db: Open SQLAlchemy session; committed once per batch.
        query_word: Query term passed to the evaluation helpers.

    Returns:
        ``"no data"`` as soon as a fetch returns no rows. This is the only
        normal return value, including after batches were processed.

    Raises:
        Exception: Re-raises the last error once
            ``_MAX_CONSECUTIVE_BATCH_FAILURES`` batches in a row have failed,
            and any error raised outside batch processing (e.g. by the fetch).
            The session is rolled back first.
    """
    logger.info(f"开始处理,request_id: {request_id}, query_word: {query_word}")

    consecutive_failures = 0
    try:
        while True:
            contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE)

            logger.info(f"获取到 {len(contents)} 条待评估内容")
            if not contents:
                return "no data"

            try:
                for content in contents:
                    db.add(_build_extraction_record(request_id, query_word, content))
                db.commit()  # one transaction per batch
            except Exception as e:
                db.rollback()
                logger.error(f"处理批次数据时出错: {e}")
                consecutive_failures += 1
                if consecutive_failures >= _MAX_CONSECUTIVE_BATCH_FAILURES:
                    # The same rows keep failing: stop instead of looping forever.
                    raise
            else:
                consecutive_failures = 0
    except Exception as e:
        # Fatal error: leave the session clean and let the caller decide.
        db.rollback()
        logger.error(f"执行评估抽取循环时出错: {e}")
        raise
def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int) -> list:
    """Fetch up to ``batch_size`` parsed rows of a request that have no extraction row yet.

    Args:
        request_id: Only rows with this ``request_id`` are returned.
        db: SQLAlchemy session used to run the query.
        batch_size: Maximum number of rows to return.

    Returns:
        A list of ``KnowledgeParsingContent`` rows with ``status == 5`` that
        have no matching ``KnowledgeExtractionContent``. No ordering is applied.
    """
    # Anti-join: after the LEFT OUTER JOIN, a NULL parsing_id on the extraction
    # side means no extraction row references this parsed row.
    without_extraction = KnowledgeExtractionContent.parsing_id.is_(None)

    pending_rows = (
        db.query(KnowledgeParsingContent)
        .outerjoin(
            KnowledgeExtractionContent,
            KnowledgeParsingContent.id == KnowledgeExtractionContent.parsing_id,
        )
        .filter(
            # Per the original author's note, status 5 marks rows whose
            # upstream extraction step has completed.
            KnowledgeParsingContent.status == 5,
            KnowledgeParsingContent.request_id == request_id,
            without_extraction,
        )
    )
    return pending_rows.limit(batch_size).all()
def _consistency_from_json(json_text):
    """Decode ``json_text`` and return its ``(consistency, reason_str)`` pair."""
    parsed = json.loads(json_text)
    consistency = parsed.get("consistency", "")
    reason = parsed.get("reason", [])
    if isinstance(reason, list):
        # Stringify every entry so a non-string item cannot break the join.
        reason_str = "\n".join(str(item) for item in reason)
    else:
        reason_str = str(reason)
    return consistency, reason_str


def evaluate_consistency(keyword, structured_result):
    """Ask the LLM how consistent the content is with the query.

    Args:
        keyword: Query term, sent under the ``"query"`` key.
        structured_result: Content to judge, sent under ``"query结果文本"``.

    Returns:
        A ``(consistency, reason)`` tuple of strings:

        * the ``"consistency"`` and ``"reason"`` fields of the JSON reply
          (list reasons are joined with newlines);
        * ``("解析错误", <first 500 chars of the reply>)`` when no JSON object
          can be recovered from the reply;
        * ``("评估异常", <error message>)`` on any other failure.
        This function never raises.
    """
    try:
        input_data = {
            "query": keyword,
            "query结果文本": structured_result,
        }

        # Call the LLM for the consistency verdict.
        result = processor.process(input_data, CONSISTENCY_PROMPT).strip()

        try:
            if result.startswith('```json'):
                # Keep only what sits between the opening fence and the next one.
                json_text = result.split('```json', 1)[1].split('```', 1)[0].strip()
            else:
                json_text = result
            return _consistency_from_json(json_text)
        except json.JSONDecodeError as e:
            logger.warning(f"一致性评估结果解析失败: {result[:200]}... 错误: {e}")

        # Fallback: try the outermost {...} span of the raw reply.
        start, end = result.find('{'), result.rfind('}')
        if start != -1 and end > start:
            try:
                consistency, reason_str = _consistency_from_json(result[start:end + 1])
            except Exception as repair_error:
                # Still best effort, but record why the repair failed instead
                # of swallowing the error silently.
                logger.warning(f"一致性评估结果修复解析失败: {repair_error}")
            else:
                logger.info(f"修复后解析成功,一致性评估结果: {consistency}")
                return consistency, reason_str
        return "解析错误", result[:500]  # cap the length of the returned text
    except Exception as e:
        logger.error(f"一致性评估过程中发生异常: {e}")
        return "评估异常", str(e)
def evaluate_content(keyword, structured_result):
    """Ask the LLM to score the content against the query.

    Args:
        keyword: Query term, sent under the ``"query_word"`` key.
        structured_result: Content to score, sent under the ``"content"`` key.

    Returns:
        A ``(score, reason)`` tuple:

        * the numeric ``"score"`` (``-2`` when the key is missing; numeric
          strings are converted) and the ``"reason"`` field of the JSON reply;
        * ``(-1, <first 500 chars of the reply>)`` when the reply cannot be
          parsed or the score is not numeric;
        * ``(-1, "评估过程异常: <error>")`` on any other failure.
        This function never raises.
    """
    try:
        input_data = {
            "query_word": keyword,
            "content": structured_result,
        }
        # Ask Gemini for the evaluation.
        result = processor.process(input_data, EVALUATION_PROMPT)
        try:
            # Strip Markdown code fences around the JSON; the payload itself is untouched.
            result = re.sub(r'(^\s*```json)|(\s*```\s*$)', '', result, flags=re.MULTILINE).strip()
            try:
                # Parse the reply as-is first: the repairs below would corrupt
                # valid JSON whose strings contain an apostrophe or ", word:".
                parsed_result = json.loads(result)
            except json.JSONDecodeError:
                repaired = result.replace("'", "\"")  # single -> double quotes
                repaired = re.sub(r'([{,])\s*(\w+)\s*:', r'\1"\2":', repaired)  # quote bare keys
                parsed_result = json.loads(repaired)

            score = parsed_result.get("score", -2)
            if not isinstance(score, (int, float)):
                # Callers compare the score with a number. float() raises for
                # null or non-numeric text, which the handler below reports.
                numeric_score = float(score)
                score = int(numeric_score) if numeric_score.is_integer() else numeric_score
            score_reason = parsed_result.get("reason", "")
            return score, score_reason
        except Exception as json_error:
            logger.error(f"评估JSON解析错误: {str(json_error)},原始内容: {result[:100]}...")
            return -1, result[:500]  # cap the length of the returned text
    except Exception as e:
        # Log the failure, then fall back to the default result.
        logger.error(f"评估过程异常: {e}")
        return -1, f"评估过程异常: {str(e)}"
def extract_content(keyword, structured_result):
    """Ask the LLM to clean the content and extract the relevant part.

    Args:
        keyword: Query term, sent under the ``"query_word"`` key.
        structured_result: Content to extract from, sent under ``"content"``.

    Returns:
        An ``(extracted_data, clean_reason)`` tuple:

        * the ``"extracted_content"`` and ``"analysis_reason"`` fields of the
          JSON reply (``"未提取到内容"`` / ``"未返回原因"`` when a key is missing);
        * ``("", "JSON解析错误: <error>")`` when the reply cannot be parsed;
        * ``("", "提取过程异常: <error>")`` on any other failure.
        This function never raises.
    """
    try:
        input_data = {
            "query_word": keyword,
            "content": structured_result,
        }
        # Ask Gemini for the extraction.
        result = processor.process(input_data, EXTRACTION_PROMPT)
        try:
            # Strip Markdown code fences around the JSON; the payload itself is untouched.
            result = re.sub(r'(^\s*```json)|(\s*```\s*$)', '', result, flags=re.MULTILINE).strip()
            try:
                # Parse the reply as-is first: the repairs below would corrupt
                # valid JSON whose strings contain an apostrophe or ", word:",
                # which is likely in extracted free text.
                parsed_result = json.loads(result)
            except json.JSONDecodeError:
                repaired = result.replace("'", "\"")  # single -> double quotes
                repaired = re.sub(r'([{,])\s*(\w+)\s*:', r'\1"\2":', repaired)  # quote bare keys
                parsed_result = json.loads(repaired)

            extracted_data = parsed_result.get("extracted_content", "未提取到内容")
            clean_reason = parsed_result.get("analysis_reason", "未返回原因")
            return extracted_data, clean_reason
        except Exception as json_error:
            logger.error(f"JSON解析错误: {str(json_error)}")
            return "", f"JSON解析错误: {str(json_error)}"
    except Exception as e:
        # Log the failure, then fall back to the default result.
        logger.error(f"提取过程异常: {e}")
        return "", f"提取过程异常: {str(e)}"
|