|
@@ -1,21 +1,17 @@
|
|
|
from langchain_core.tools import tool
|
|
|
from sqlalchemy.orm import Session
|
|
|
-from typing import Dict, Any, Tuple
|
|
|
from datetime import datetime
|
|
|
import json
|
|
|
import os
|
|
|
import sys
|
|
|
import re
|
|
|
-import traceback
|
|
|
-from openai import OpenAI
|
|
|
-from gemini import GeminiProcessor
|
|
|
|
|
|
# 添加项目根目录到系统路径
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
|
|
|
|
+from gemini import GeminiProcessor
|
|
|
from database.db import SessionLocal, get_db
|
|
|
from database.models import KnowledgeParsingContent, KnowledgeExtractionContent
|
|
|
-from llm.deepseek import DeepSeekProcessor
|
|
|
from utils.logging_config import get_logger
|
|
|
|
|
|
# 配置日志
|
|
@@ -25,6 +21,28 @@ logger = get_logger('CleanTools')
|
|
|
BATCH_SIZE = 5  # rows fetched per round by get_batch_contents_for_evaluation
SCORE_THRESHOLD = 70  # minimum evaluation score before clean-extraction runs
|
|
|
|
|
|
# Initialize the Gemini processor shared by every LLM call in this module.
processor = GeminiProcessor()


# Resolve prompt-template paths relative to the project root
# (three directory levels above this file).
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
consistency_prompt_path = os.path.join(project_root, 'prompt', 'consistency.md')
evaluation_prompt_path = os.path.join(project_root, 'prompt', 'evaluation.md')
extraction_prompt_path = os.path.join(project_root, 'prompt', 'extraction.md')
|
|
|
def read_prompt_file(file_path):
    """Return the full text of the prompt file at *file_path*.

    Any read failure is logged and swallowed; "" is returned so the module
    can still import even when a prompt file is missing.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as fh:
            text = fh.read()
    except Exception as e:
        logger.error(f"读取提示词文件 {file_path} 失败: {str(e)}")
        return ""
    return text
|
|
|
# Prompt templates loaded once at import time; each is "" if its file
# could not be read (read_prompt_file swallows errors).
CONSISTENCY_PROMPT = read_prompt_file(consistency_prompt_path)
EVALUATION_PROMPT = read_prompt_file(evaluation_prompt_path)
EXTRACTION_PROMPT = read_prompt_file(extraction_prompt_path)
|
|
|
@tool
|
|
|
def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
|
|
|
"""
|
|
def execute_continuous_evaluation_extraction(request_id: str, db: Session, query_word: str) -> str:
    """Run the evaluate→extract loop until no unevaluated rows remain.

    For each batch of parsed contents: run a consistency check, then (when
    consistency is '高') a quality score, then (when the score reaches
    SCORE_THRESHOLD) the clean extraction; one KnowledgeExtractionContent
    row is recorded per input row either way.

    Returns "success" when at least one row was processed, "no data" when
    there was nothing to do. Raises on unrecoverable loop errors.
    """
    logger.info(f"开始处理,request_id: {request_id}, query_word: {query_word}")

    total_processed = 0
    try:
        while True:
            # Fetch the next batch still lacking an extraction record.
            contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE)
            logger.info(f"获取到 {len(contents)} 条待评估内容")

            if not contents:
                # BUG FIX: previously returned "no data" unconditionally,
                # even after earlier batches had been processed successfully.
                return "success" if total_processed > 0 else "no data"

            try:
                for content in contents:
                    # Consistency check between the query and the parsed text.
                    logger.info(f"正在进行一致性评估:{content.id}")
                    consistency_result, reason_str = evaluate_consistency(query_word, content.parsing_data)

                    extraction_content = KnowledgeExtractionContent(
                        request_id=request_id,
                        parsing_id=content.id,
                        consistency=consistency_result,
                        consistency_reason=reason_str,
                        create_at=datetime.now()
                    )

                    if consistency_result == '高':
                        # Content quality scoring.
                        logger.info(f"正在进行内容评估:{content.id}")
                        score, score_reason = evaluate_content(query_word, content.parsing_data)
                        extraction_content.score = score
                        extraction_content.score_reason = score_reason

                        if score >= SCORE_THRESHOLD:
                            # Clean extraction for high-scoring content.
                            logger.info(f"正在进行清洗提取:{content.id}")
                            extracted_data, clean_reason = extract_content(query_word, content.parsing_data)

                            extraction_content.data = extracted_data
                            extraction_content.clean_reason = clean_reason
                            extraction_content.status = 2

                    # Record the outcome for every row, pass or fail.
                    db.add(extraction_content)

                db.commit()  # commit once per completed batch
                total_processed += len(contents)
            except Exception as e:
                db.rollback()
                logger.error(f"处理批次数据时出错: {e}")
                # BUG FIX: a rolled-back batch is re-fetched by the next
                # query iteration, so "continue" here spun forever on the
                # same failing batch — stop the loop instead.
                return "success" if total_processed > 0 else "no data"
    except Exception as e:
        # Unrecoverable failure: roll back and surface to the caller.
        db.rollback()
        logger.error(f"执行评估抽取循环时出错: {e}")
        raise
|
|
|
|
|
|
def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int) -> list:
|
|
|
query = db.query(KnowledgeParsingContent).outerjoin(
|
|
@@ -129,207 +158,106 @@ def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size:
|
|
|
)
|
|
|
|
|
|
return query.limit(batch_size).all()
|
|
|
-
|
|
|
def evaluate_consistency(keyword, structured_result):
    """LLM consistency check: does *structured_result* actually answer *keyword*?

    Returns ``(consistency, reason_str)``.  ``consistency`` is the model's
    label (the caller treats '高' as a pass) or an error marker
    ("解析错误" / "评估异常"); ``reason_str`` is a newline-joined explanation.
    Never raises — every failure is folded into the return value.
    """
    result = ""
    try:
        input_data = {
            "query": keyword,
            "query结果文本": structured_result
        }

        # Call the LLM for the consistency judgement.
        result = processor.process(input_data, CONSISTENCY_PROMPT)

        result = result.strip()
        if result.startswith('```json'):
            # Strip the Markdown fence; a missing closing fence is tolerated
            # by split(). (The old extra guard `and '```' in result` was
            # redundant — the opening fence itself always satisfied it.)
            json_str = result.split('```json', 1)[1].split('```', 1)[0].strip()
            json_result = json.loads(json_str)
        else:
            json_result = json.loads(result)

        consistency = json_result.get("consistency", "")
        reason = json_result.get("reason", [])
        # BUG FIX: stringify items so a list containing non-strings cannot
        # raise TypeError inside str.join.
        reason_str = "\n".join(str(r) for r in reason) if isinstance(reason, list) else str(reason)

        return consistency, reason_str
    except json.JSONDecodeError as e:
        # Response was not clean JSON — try to salvage the outermost {...} span.
        logger.warning(f"一致性评估结果解析失败: {result[:200]}... 错误: {e}")
        try:
            if '{' in result and '}' in result:
                json_part = result[result.find('{'):result.rfind('}') + 1]
                json_result = json.loads(json_part)
                consistency = json_result.get("consistency", "")
                reason = json_result.get("reason", [])
                reason_str = "\n".join(str(r) for r in reason) if isinstance(reason, list) else str(reason)
                logger.info(f"修复后解析成功,一致性评估结果: {consistency}")
                return consistency, reason_str
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # SystemExit/KeyboardInterrupt.
            pass
        return "解析错误", result[:500]  # cap the length stored in the DB
    except Exception as e:
        logger.error(f"一致性评估过程中发生异常: {e}")
        return "评估异常", str(e)
|
|
|
|
|
|
def evaluate_content(keyword, structured_result):
    """Score *structured_result* against *keyword* with the evaluation prompt.

    Returns ``(score, score_reason)`` where ``score`` is numeric: the model's
    score, -2 when the model omits it, -1 on parse/processing failure.
    Never raises.
    """
    try:
        input_data = {
            "query_word": keyword,
            "content": structured_result
        }
        # Single per-item Gemini call (despite the old "batch" comment).
        result = processor.process(input_data, EVALUATION_PROMPT)
        try:
            # Strip a surrounding ```json fence; leave the JSON body intact.
            result = re.sub(r'(^\s*```json)|(\s*```\s*$)', '', result, flags=re.MULTILINE).strip()

            try:
                parsed_result = json.loads(result)
            except json.JSONDecodeError:
                # BUG FIX: the quote/attribute repairs used to run
                # unconditionally, corrupting valid JSON whose text contains
                # apostrophes; now they are a fallback only.
                repaired = result.replace("'", "\"")  # single → double quotes
                repaired = re.sub(r'([{,])\s*(\w+)\s*:', r'\1"\2":', repaired)  # quote bare keys
                parsed_result = json.loads(repaired)

            score = parsed_result.get("score", -2)
            score_reason = parsed_result.get("reason", "")
            # BUG FIX: the model sometimes returns the score as a string;
            # coerce so the caller's `score >= SCORE_THRESHOLD` cannot raise.
            if not isinstance(score, (int, float)):
                try:
                    score = float(score)
                except (TypeError, ValueError):
                    score = -1
            return score, score_reason
        except Exception as json_error:
            logger.error(f"评估JSON解析错误: {str(json_error)},原始内容: {result[:100]}...")
            return -1, result[:500]  # cap the length stored in the DB
    except Exception as e:
        return -1, f"评估过程异常: {str(e)}"
|
|
|
|
|
|
def extract_content(keyword, structured_result):
    """Extract the cleaned knowledge content for *keyword* from *structured_result*.

    Returns ``(extracted_data, clean_reason)``; ``extracted_data`` is "" when
    extraction fails, and ``clean_reason`` explains the outcome either way.
    Never raises.
    """
    try:
        input_data = {
            "query_word": keyword,
            "content": structured_result
        }
        # Per-item Gemini call with the extraction prompt
        # (the old comment said "评估" — this is the extraction step).
        result = processor.process(input_data, EXTRACTION_PROMPT)
        try:
            # Strip a surrounding ```json fence; leave the JSON body intact.
            result = re.sub(r'(^\s*```json)|(\s*```\s*$)', '', result, flags=re.MULTILINE).strip()

            try:
                parsed_result = json.loads(result)
            except json.JSONDecodeError:
                # BUG FIX: quote/attribute repairs used to run unconditionally
                # and could corrupt valid JSON containing apostrophes; now a
                # fallback only.
                repaired = result.replace("'", "\"")  # single → double quotes
                repaired = re.sub(r'([{,])\s*(\w+)\s*:', r'\1"\2":', repaired)  # quote bare keys
                parsed_result = json.loads(repaired)

            extracted_data = parsed_result.get("extracted_content", "未提取到内容")
            clean_reason = parsed_result.get("analysis_reason", "未返回原因")
            return extracted_data, clean_reason
        except Exception as json_error:
            logger.error(f"JSON解析错误: {str(json_error)}")
            # BUG FIX: removed the dead `extracted_data = "未提取到内容"`
            # assignment — this path has always returned "".
            return "", f"JSON解析错误: {str(json_error)}"
    except Exception as e:
        # Default result on any unexpected failure.
        return "", f"提取过程异常: {str(e)}"
|