"""Knowledge evaluation and extraction tools.

Batch-processes parsed knowledge content: consistency check, scoring, and
cleaned-content extraction via Gemini, persisting results with SQLAlchemy.
"""
  1. from langchain_core.tools import tool
  2. from sqlalchemy.orm import Session
  3. from datetime import datetime
  4. import json
  5. import os
  6. import sys
  7. import re
  8. # 添加项目根目录到系统路径
  9. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  10. from gemini import GeminiProcessor
  11. from database.db import SessionLocal, get_db
  12. from database.models import KnowledgeParsingContent, KnowledgeExtractionContent
  13. from utils.logging_config import get_logger
# Logging configuration
logger = get_logger('CleanTools')
# Tuning constants
BATCH_SIZE = 5 # number of rows fetched and processed per batch
SCORE_THRESHOLD = 70 # minimum score required before clean extraction runs
# Initialise the Gemini processor used for all LLM calls below
processor = GeminiProcessor()
# Resolve prompt template locations relative to the project root
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
consistency_prompt_path = os.path.join(project_root, 'prompt', 'consistency.md')
evaluation_prompt_path = os.path.join(project_root, 'prompt', 'evaluation.md')
extraction_prompt_path = os.path.join(project_root, 'prompt', 'extraction.md')
  26. def read_prompt_file(file_path):
  27. """从文件中读取提示词"""
  28. try:
  29. with open(file_path, 'r', encoding='utf-8') as file:
  30. return file.read()
  31. except Exception as e:
  32. logger.error(f"读取提示词文件 {file_path} 失败: {str(e)}")
  33. return ""
# Prompt templates, loaded once at import time; each is "" if its file is unreadable
CONSISTENCY_PROMPT = read_prompt_file(consistency_prompt_path)
EVALUATION_PROMPT = read_prompt_file(evaluation_prompt_path)
EXTRACTION_PROMPT = read_prompt_file(extraction_prompt_path)
@tool
def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
    """
    Knowledge evaluation and extraction tool. Continuously processes database
    rows in batches, running evaluation and creating KnowledgeExtractionContent
    records. Content scoring above 70 is additionally extracted and the
    KnowledgeExtractionContent record is updated.
    Args:
        request_id: Request ID whose rows should be processed.
        query_word: Query term used to evaluate and extract content.
    Returns:
        str: "success" means processing finished, "no data" means there was no
        data left to process.
    """
    # NOTE(review): the delegate below returns "no data" when the queue is
    # empty and an error marker on failure; a literal "success" is never
    # produced despite the docstring above — confirm which is intended.
    # The context manager scopes the DB session's lifetime to this call
    with SessionLocal() as db:
        try:
            # Delegate to the shared batch-processing loop
            result = execute_continuous_evaluation_extraction(request_id, db, query_word)
            return result
        except Exception as e:
            # Roll back any in-flight transaction on failure
            db.rollback()
            logger.error(f"评估抽取过程中出错: {e}")
            return f"no data - 错误: {str(e)}"
  59. def evaluation_extraction(request_id: str, query_word: str) -> str:
  60. """
  61. 知识评估与抽取工具。持续处理数据库中的数据,分批执行评估并创建KnowledgeExtractionContent对象。
  62. 对于评分大于70分的内容,会进行抽取并更新KnowledgeExtractionContent对象。
  63. Args:
  64. request_id: 请求ID,如果不提供则处理所有未处理的数据
  65. query_word: 查询词,用于评估和抽取内容
  66. Returns:
  67. str: "success" 表示处理完成,"no data" 表示没有数据需要处理
  68. """
  69. # 使用上下文管理器自动管理数据库连接的生命周期
  70. with SessionLocal() as db:
  71. try:
  72. # 使用新的批量处理函数
  73. result = execute_continuous_evaluation_extraction(request_id, db, query_word)
  74. return result
  75. except Exception as e:
  76. # 确保发生异常时回滚事务
  77. db.rollback()
  78. logger.error(f"评估抽取过程中出错: {e}")
  79. return f"no data - 错误: {str(e)}"
  80. def execute_continuous_evaluation_extraction(request_id: str, db: Session, query_word: str) -> str:
  81. """持续执行评估循环,直到数据库没有数据"""
  82. logger.info(f"开始处理,request_id: {request_id}, query_word: {query_word}")
  83. try:
  84. while True:
  85. # 分批获取待评估的内容
  86. contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE)
  87. logger.info(f"获取到 {len(contents)} 条待评估内容")
  88. if not contents:
  89. return "no data"
  90. try:
  91. for content in contents:
  92. # 一致性评估
  93. logger.info(f"正在进行一致性评估:{content.id}")
  94. consistency_result, reason_str = evaluate_consistency(query_word, content.parsing_data)
  95. extraction_content = KnowledgeExtractionContent(
  96. request_id=request_id,
  97. parsing_id=content.id,
  98. consistency=consistency_result,
  99. consistency_reason=reason_str,
  100. create_at=datetime.now()
  101. )
  102. if consistency_result == '高':
  103. # 内容评估
  104. logger.info(f"正在进行内容评估:{content.id}")
  105. score, score_reason = evaluate_content(query_word, content.parsing_data)
  106. extraction_content.score = score
  107. extraction_content.score_reason = score_reason
  108. if score >= SCORE_THRESHOLD:
  109. # 清洗提取
  110. logger.info(f"正在进行清洗提取:{content.id}")
  111. extracted_data, clean_reason = extract_content(query_word, content.parsing_data)
  112. extraction_content.data = extracted_data
  113. extraction_content.clean_reason = clean_reason
  114. extraction_content.status = 2
  115. db.add(extraction_content)
  116. db.commit() # 每批次处理完成后提交事务
  117. except Exception as e:
  118. # 当前批次处理失败时回滚事务
  119. db.rollback()
  120. logger.error(f"处理批次数据时出错: {e}")
  121. except Exception as e:
  122. # 发生严重异常时回滚事务并抛出异常
  123. db.rollback()
  124. logger.error(f"执行评估抽取循环时出错: {e}")
  125. raise
  126. def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int) -> list:
  127. query = db.query(KnowledgeParsingContent).outerjoin(
  128. KnowledgeExtractionContent,
  129. KnowledgeParsingContent.id == KnowledgeExtractionContent.parsing_id
  130. ).filter(
  131. KnowledgeParsingContent.status == 5, # 已完成提取的数据
  132. KnowledgeParsingContent.request_id == request_id,
  133. KnowledgeExtractionContent.parsing_id == None
  134. )
  135. return query.limit(batch_size).all()
  136. def evaluate_consistency(keyword, structured_result):
  137. """评估一致性"""
  138. try:
  139. input_data = {
  140. "query": keyword,
  141. "query结果文本": structured_result
  142. }
  143. # 调用LLM进行一致性评估
  144. result = processor.process(input_data, CONSISTENCY_PROMPT)
  145. try:
  146. # 尝试解析JSON结果
  147. # 处理可能的不完整JSON字符串
  148. result = result.strip()
  149. if result.startswith('```json') and '```' in result:
  150. # 提取JSON部分
  151. json_str = result.split('```json', 1)[1].split('```', 1)[0].strip()
  152. json_result = json.loads(json_str)
  153. else:
  154. json_result = json.loads(result)
  155. consistency = json_result.get("consistency", "")
  156. reason = json_result.get("reason", [])
  157. reason_str = "\n".join(reason) if isinstance(reason, list) else str(reason)
  158. return consistency, reason_str
  159. except json.JSONDecodeError as e:
  160. # 如果结果不是有效的JSON,尝试修复并重新解析
  161. logger.warning(f"一致性评估结果解析失败: {result[:200]}... 错误: {e}")
  162. try:
  163. # 尝试从文本中提取JSON部分
  164. if '{' in result and '}' in result:
  165. json_part = result[result.find('{'):result.rfind('}')+1]
  166. json_result = json.loads(json_part)
  167. consistency = json_result.get("consistency", "")
  168. reason = json_result.get("reason", [])
  169. reason_str = "\n".join(reason) if isinstance(reason, list) else str(reason)
  170. logger.info(f"修复后解析成功,一致性评估结果: {consistency}")
  171. return consistency, reason_str
  172. except:
  173. pass
  174. return "解析错误", result[:500] # 限制返回长度
  175. except Exception as e:
  176. logger.error(f"一致性评估过程中发生异常: {e}")
  177. return "评估异常", str(e)
  178. def evaluate_content(keyword, structured_result):
  179. try:
  180. input_data = {
  181. "query_word": keyword,
  182. "content": structured_result
  183. }
  184. # 批量调用 Gemini 进行评估
  185. result = processor.process(input_data, EVALUATION_PROMPT)
  186. try:
  187. # 只处理大括号外面的内容,保留JSON内部格式
  188. result = re.sub(r'(^\s*```json)|(\s*```\s*$)', '', result, flags=re.MULTILINE).strip()
  189. # 尝试修复常见的JSON格式问题
  190. result = result.replace("'", "\"") # 将单引号替换为双引号
  191. result = re.sub(r'([{,])\s*(\w+)\s*:', r'\1"\2":', result) # 确保属性名有双引号
  192. # 解析JSON
  193. parsed_result = json.loads(result)
  194. score = parsed_result.get("score", -2)
  195. score_reason = parsed_result.get("reason", "")
  196. return score, score_reason
  197. except Exception as json_error:
  198. logger.error(f"评估JSON解析错误: {str(json_error)},原始内容: {result[:100]}...")
  199. return -1, result[:500] # 限制返回长度
  200. except Exception as e:
  201. # 返回默认结果
  202. return -1, f"评估过程异常: {str(e)}"
  203. def extract_content(keyword, structured_result):
  204. try:
  205. input_data = {
  206. "query_word": keyword,
  207. "content": structured_result
  208. }
  209. # 批量调用 Gemini 进行评估
  210. result = processor.process(input_data, EXTRACTION_PROMPT)
  211. try:
  212. # 只处理大括号外面的内容,保留JSON内部格式
  213. result = re.sub(r'(^\s*```json)|(\s*```\s*$)', '', result, flags=re.MULTILINE).strip()
  214. # 尝试修复常见的JSON格式问题
  215. result = result.replace("'", "\"") # 将单引号替换为双引号
  216. result = re.sub(r'([{,])\s*(\w+)\s*:', r'\1"\2":', result) # 确保属性名有双引号
  217. # 解析JSON
  218. parsed_result = json.loads(result)
  219. extracted_data = parsed_result.get("extracted_content", "未提取到内容")
  220. clean_reason = parsed_result.get("analysis_reason", "未返回原因")
  221. return extracted_data, clean_reason
  222. except Exception as json_error:
  223. logger.error(f"JSON解析错误: {str(json_error)}")
  224. extracted_data = "未提取到内容"
  225. clean_reason = f"JSON解析错误: {str(json_error)}"
  226. return "", clean_reason
  227. except Exception as e:
  228. # 返回默认结果
  229. return "", f"提取过程异常: {str(e)}"