tools.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
import json
import logging
import os
import sys
from datetime import datetime
from typing import Any, Dict, Tuple

from langchain.tools import Tool, tool
from sqlalchemy.orm import Session

# Make the project root importable BEFORE pulling in project-local modules.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from database.db import SessionLocal, get_db
from database.models import KnowledgeParsingContent, KnowledgeExtractionContent
from gemini import GeminiProcessor
  14. # 配置日志
  15. logging.basicConfig(level=logging.INFO)
  16. logger = logging.getLogger(__name__)
  17. # 配置常量
  18. BATCH_SIZE = 10 # 分批处理大小
  19. SCORE_THRESHOLD = 70 # 评分阈值
  20. # Define tools
  21. @Tool
  22. def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
  23. """
  24. 知识评估与抽取工具。持续处理数据库中的数据,分批执行评估并创建KnowledgeExtractionContent对象。
  25. 对于评分大于70分的内容,会进行抽取并更新KnowledgeExtractionContent对象。
  26. Args:
  27. request_id: 请求ID,如果不提供则处理所有未处理的数据
  28. query_word: 查询词,用于评估和抽取内容
  29. Returns:
  30. str: "success" 表示处理完成,"no data" 表示没有数据需要处理
  31. """
  32. try:
  33. db = SessionLocal()
  34. try:
  35. # 使用新的批量处理函数
  36. result = execute_continuous_evaluation_extraction(request_id, db, query_word)
  37. return result
  38. finally:
  39. db.close()
  40. except Exception as e:
  41. logger.error(f"评估抽取过程中出错: {e}")
  42. return f"no data - 错误: {str(e)}"
  43. def execute_continuous_evaluation_extraction(request_id: str, db: Session, query_word: str) -> str:
  44. """持续执行评估循环,直到数据库没有数据"""
  45. total_processed = 0
  46. while True:
  47. # 分批获取待评估的内容
  48. contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE)
  49. if not contents:
  50. if total_processed > 0:
  51. logger.info(f"处理完成,共处理 {total_processed} 条内容")
  52. return "success"
  53. return "no data"
  54. # 批量评估内容并创建KnowledgeExtractionContent对象
  55. evaluation_results = batch_evaluate_content(contents, db, request_id, query_word)
  56. # 对评分大于阈值的内容进行抽取
  57. high_score_results = [result for result in evaluation_results if result["score"] >= SCORE_THRESHOLD]
  58. if high_score_results:
  59. logger.info(f"发现 {len(high_score_results)} 条高分内容,进行抽取")
  60. batch_extract_and_save_content(high_score_results, db, request_id, query_word)
  61. total_processed += len(contents)
  62. db.commit()
  63. # 这里的代码永远不会被执行到,因为在while循环中,当contents为空时会返回
  64. def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int) -> list:
  65. """分批获取待评估的内容"""
  66. query = db.query(KnowledgeParsingContent).filter(
  67. KnowledgeParsingContent.status == 2 # 已完成提取的数据
  68. )
  69. # 如果指定了request_id,则只处理该request_id的数据
  70. if request_id:
  71. query = query.filter(KnowledgeParsingContent.request_id == request_id)
  72. return query.limit(batch_size).all()
  73. def batch_evaluate_content(contents: list, db: Session, request_id: str, query_word: str) -> list:
  74. if not contents:
  75. return []
  76. try:
  77. # 批量调用大模型进行评估
  78. evaluation_results_raw = batch_call_llm_for_evaluation(contents, query_word)
  79. # 处理评估结果
  80. evaluation_results = []
  81. for i, (parsing_id, score, reason, parsing_data) in enumerate(evaluation_results_raw):
  82. # 创建KnowledgeExtractionContent对象
  83. extraction_content = KnowledgeExtractionContent(
  84. request_id=request_id,
  85. parsing_id=parsing_id,
  86. score=score,
  87. reason=reason,
  88. create_at=datetime.now()
  89. )
  90. db.add(extraction_content)
  91. evaluation_results.append({
  92. "parsing_id": parsing_id,
  93. "score": score,
  94. "reason": reason,
  95. "parsing_data": parsing_data,
  96. "extraction_content": extraction_content
  97. })
  98. return evaluation_results
  99. except Exception as e:
  100. logger.error(f"批量评估内容时出错: {e}")
  101. # 将所有内容标记为处理失败
  102. for content in contents:
  103. content.status = 3 # 处理失败
  104. return []
  105. def batch_extract_and_save_content(evaluation_results: list, db: Session, request_id: str, query_word: str) -> list:
  106. if not evaluation_results:
  107. return []
  108. # 批量调用大模型进行抽取
  109. extraction_data_list = batch_call_llm_for_extraction(evaluation_results, query_word)
  110. # 保存抽取结果到数据库
  111. success_ids = []
  112. failed_ids = []
  113. for i, extraction_data in enumerate(extraction_data_list):
  114. try:
  115. evaluation_result = evaluation_results[i]
  116. # 更新已有对象的data字段和状态
  117. existing_extraction.data = evaluation_result["extraction_content"]
  118. existing_extraction.status = 2 # 处理完成
  119. success_ids.append(parsing_id)
  120. except Exception as e:
  121. logger.error(f"处理抽取结果 {i} 时出错: {e}")
  122. failed_ids.append(evaluation_results[i].get("parsing_id"))
  123. # 如果有失败的内容,将其标记为处理失败
  124. if failed_ids:
  125. logger.warning(f"有 {len(failed_ids)} 条内容抽取失败")
  126. for result in evaluation_results:
  127. if result.get("parsing_id") in failed_ids and "extraction_content" in result:
  128. result["extraction_content"].status = 3 # 处理失败
  129. return success_ids
  130. # 读取提示词文件
  131. def read_prompt_file(file_path):
  132. """从文件中读取提示词"""
  133. try:
  134. with open(file_path, 'r', encoding='utf-8') as file:
  135. return file.read()
  136. except Exception as e:
  137. logger.error(f"读取提示词文件 {file_path} 失败: {str(e)}")
  138. return ""
  139. # 初始化 Gemini 处理器和提示词
  140. gemini_processor = GeminiProcessor()
  141. # 加载评估和抽取提示词
  142. project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  143. evaluation_prompt_path = os.path.join(project_root, 'prompt', 'evaluation.md')
  144. extraction_prompt_path = os.path.join(project_root, 'prompt', 'extraction.md')
  145. # 打印路径信息,用于调试
  146. logger.info(f"评估提示词路径: {evaluation_prompt_path}")
  147. logger.info(f"抽取提示词路径: {extraction_prompt_path}")
  148. EVALUATION_PROMPT = read_prompt_file(evaluation_prompt_path)
  149. EXTRACTION_PROMPT = read_prompt_file(extraction_prompt_path)
  150. def batch_call_llm_for_evaluation(contents: list, query_word: str) -> list:
  151. """批量调用大模型进行内容评估
  152. """
  153. # 准备批量评估内容
  154. evaluation_contents = []
  155. for content in contents:
  156. evaluation_contents.append({
  157. "query_word": query_word,
  158. "content": content.parsing_data
  159. })
  160. try:
  161. # 批量调用 Gemini 进行评估
  162. results = gemini_processor.batch_process(evaluation_contents, EVALUATION_PROMPT)
  163. # 处理返回结果
  164. evaluation_results = []
  165. for i, result in enumerate(results):
  166. parsing_id = contents[i].id
  167. parsing_data = contents[i].parsing_data
  168. if isinstance(result, dict) and "score" in result:
  169. # 正常结果
  170. score = result.get("score", -2)
  171. reason = result.get("reason", "")
  172. else:
  173. # 异常结果
  174. score = -2
  175. reason = "评估失败"
  176. evaluation_results.append((parsing_id, score, reason, parsing_data))
  177. return evaluation_results
  178. except Exception as e:
  179. logger.error(f"批量评估过程异常: {str(e)}")
  180. # 返回默认结果
  181. return [(content.id, 0, "评估过程异常", content.data if hasattr(content, 'data') else (content.parsing_data or "")) for content in contents]
  182. def batch_call_llm_for_extraction(evaluation_results: list, query_word: str) -> list:
  183. # 准备批量抽取内容
  184. extraction_contents = []
  185. for result in evaluation_results:
  186. parsing_data = result.get("parsing_data", "")
  187. extraction_contents.append({
  188. "query_word": query_word,
  189. "content": parsing_data
  190. })
  191. try:
  192. # 批量调用 Gemini 进行抽取
  193. results = gemini_processor.batch_process(extraction_contents, EXTRACTION_PROMPT)
  194. # 处理返回结果
  195. extraction_results = []
  196. for i, result in enumerate(results):
  197. # 确保结果包含必要的字段
  198. if not isinstance(result, dict):
  199. result = {"extracted_data": str(result)}
  200. extraction_results.append(json.dumps(result, ensure_ascii=False))
  201. return extraction_results
  202. except Exception as e:
  203. logger.error(f"批量抽取过程异常: {str(e)}")
  204. # 返回空结果
  205. return ["{}"] * len(evaluation_results)