tools.py

from langchain.tools import tool
from sqlalchemy.orm import Session
from typing import Dict, Any, Tuple
import logging
from datetime import datetime
import json
import os
import sys
import re

# Add the project root directory to the system path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from database.db import SessionLocal, get_db
from database.models import KnowledgeParsingContent, KnowledgeExtractionContent
from gemini import GeminiProcessor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration constants
BATCH_SIZE = 5        # batch size for processing
SCORE_THRESHOLD = 70  # score threshold for extraction

# Define tools
# evaluation_extraction_tool = Tool(
#     func=lambda request_id, query_word: _evaluation_extraction_tool(request_id, query_word),
#     name="evaluation_extraction_tool",
#     description="Knowledge evaluation and extraction tool: processes data in the database, runs evaluation and extracts content"
# )


@tool
def evaluation_extraction_tool(request_id: str, query_word: str) -> str:
    """
    Knowledge evaluation and extraction tool. Continuously processes data in the database,
    evaluating it in batches and creating KnowledgeExtractionContent objects.
    Content scoring at or above the threshold (70) is extracted and its
    KnowledgeExtractionContent object is updated.

    Args:
        request_id: Request ID; if not provided, all unprocessed data is handled
        query_word: Query word used to evaluate and extract content

    Returns:
        str: "success" when processing completes, "no data" when there is nothing to process
    """
    # Use a context manager so the database session is cleaned up automatically
    with SessionLocal() as db:
        try:
            # Delegate to the batch-processing loop
            result = execute_continuous_evaluation_extraction(request_id, db, query_word)
            return result
        except Exception as e:
            # Roll back the transaction if anything goes wrong
            db.rollback()
            logger.error(f"Error during evaluation/extraction: {e}")
            return f"no data - error: {str(e)}"
def execute_continuous_evaluation_extraction(request_id: str, db: Session, query_word: str) -> str:
    """Run the evaluation loop until the database has no more data to process."""
    logger.info(f"Starting processing, request_id: {request_id}, query_word: {query_word}")
    total_processed = 0
    offset = 0
    try:
        while True:
            # Fetch the next batch of content awaiting evaluation, paging with offset
            contents = get_batch_contents_for_evaluation(request_id, db, BATCH_SIZE, offset)
            logger.info(f"Fetched {len(contents)} items awaiting evaluation")
            if not contents:
                if total_processed > 0:
                    logger.info(f"Processing finished, {total_processed} items handled in total")
                    db.commit()  # make sure the last batch is committed
                    return "success"
                return "no data"
            try:
                # Evaluate the batch and create KnowledgeExtractionContent objects
                evaluation_results = batch_evaluate_content(contents, db, request_id, query_word)
                logger.info(f"evaluation_results: {evaluation_results}")
                # Extract content whose score meets the threshold
                high_score_results = [result for result in evaluation_results if result["score"] >= SCORE_THRESHOLD]
                if high_score_results:
                    logger.info(f"Found {len(high_score_results)} high-scoring items, extracting")
                    batch_extract_and_save_content(high_score_results, db, request_id, query_word)
                total_processed += len(contents)
                db.commit()  # commit after each batch
                # Note: successfully processed rows now have extraction records (or a failed
                # status) and drop out of the next query, so the offset is not advanced here;
                # advancing it would skip unprocessed rows.
            except Exception as e:
                # Roll back the current batch on failure
                db.rollback()
                logger.error(f"Error while processing a batch: {e}")
                # Skip the failed batch and continue with the next one
                offset += len(contents)
    except Exception as e:
        # On a fatal error, roll back and re-raise
        db.rollback()
        logger.error(f"Error in the evaluation/extraction loop: {e}")
        raise
    # This point is never reached: the while loop returns once contents is empty
def get_batch_contents_for_evaluation(request_id: str, db: Session, batch_size: int, offset: int = 0) -> list:
    """Fetch a batch of parsed content that does not yet have an extraction record."""
    query = db.query(KnowledgeParsingContent).outerjoin(
        KnowledgeExtractionContent,
        KnowledgeParsingContent.id == KnowledgeExtractionContent.parsing_id
    ).filter(
        KnowledgeParsingContent.status == 2,  # parsing already completed
        KnowledgeParsingContent.request_id == request_id,
        KnowledgeExtractionContent.parsing_id == None  # no extraction record yet (anti-join)
    )
    return query.offset(offset).limit(batch_size).all()
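
# For clarity, the LEFT JOIN / IS NULL filter above is an anti-join. It corresponds
# roughly to SQL of the following shape (a sketch only; the actual table and column
# names come from the SQLAlchemy model definitions, which are not shown in this file):
#
#   SELECT parsing.*
#   FROM <KnowledgeParsingContent table> AS parsing
#   LEFT JOIN <KnowledgeExtractionContent table> AS extraction
#          ON parsing.id = extraction.parsing_id
#   WHERE parsing.status = 2
#     AND parsing.request_id = :request_id
#     AND extraction.parsing_id IS NULL
#   LIMIT :batch_size OFFSET :offset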
def batch_evaluate_content(contents: list, db: Session, request_id: str, query_word: str) -> list:
    """Evaluate a batch of contents and create a KnowledgeExtractionContent row for each."""
    if not contents:
        return []
    try:
        # Batch-call the LLM for evaluation
        evaluation_results_raw = batch_call_llm_for_evaluation(contents, query_word)
        # Process the evaluation results
        evaluation_results = []
        for i, (parsing_id, score, score_reason, parsing_data) in enumerate(evaluation_results_raw):
            # Create a KnowledgeExtractionContent object for this item
            extraction_content = KnowledgeExtractionContent(
                request_id=request_id,
                parsing_id=parsing_id,
                score=score,
                score_reason=score_reason,
                create_at=datetime.now()
            )
            db.add(extraction_content)
            evaluation_results.append({
                "parsing_id": parsing_id,
                "score": score,
                "score_reason": score_reason,
                "parsing_data": parsing_data,
                "extraction_content": extraction_content
            })
        return evaluation_results
    except Exception as e:
        logger.error(f"Error while batch-evaluating content: {e}")
        # Mark every item in the batch as failed
        for content in contents:
            content.status = 3  # processing failed
        return []
def batch_extract_and_save_content(evaluation_results: list, db: Session, request_id: str, query_word: str) -> list:
    """Extract content for high-scoring items and update their KnowledgeExtractionContent rows."""
    if not evaluation_results:
        return []
    try:
        # Batch-call the LLM for extraction
        extraction_data_list = batch_call_llm_for_extraction(evaluation_results, query_word)
        # Save the extraction results to the database
        success_ids = []
        failed_ids = []
        for i, (extracted_data, clean_reason) in enumerate(extraction_data_list):
            try:
                evaluation_result = evaluation_results[i]
                parsing_id = evaluation_result.get("parsing_id")
                if "extraction_content" in evaluation_result and parsing_id:
                    # Update the existing object's data field and status
                    extraction_content = evaluation_result["extraction_content"]
                    extraction_content.data = extracted_data
                    extraction_content.clean_reason = clean_reason
                    extraction_content.status = 2  # processing completed
                    success_ids.append(parsing_id)
            except Exception as e:
                logger.error(f"Error while handling extraction result {i}: {e}")
                if i < len(evaluation_results):
                    failed_ids.append(evaluation_results[i].get("parsing_id"))
        # Mark any failed items as processing failed
        if failed_ids:
            logger.warning(f"{len(failed_ids)} items failed extraction")
            for result in evaluation_results:
                if result.get("parsing_id") in failed_ids and "extraction_content" in result:
                    result["extraction_content"].status = 3  # processing failed
        return success_ids
    except Exception as e:
        logger.error(f"Error while batch-extracting and saving content: {e}")
        db.rollback()  # roll back on failure
        return []
# Read a prompt file
def read_prompt_file(file_path):
    """Read a prompt from a file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    except Exception as e:
        logger.error(f"Failed to read prompt file {file_path}: {str(e)}")
        return ""


# Initialize the Gemini processor and prompts
gemini_processor = GeminiProcessor()

# Load the evaluation and extraction prompts
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
evaluation_prompt_path = os.path.join(project_root, 'prompt', 'evaluation.md')
extraction_prompt_path = os.path.join(project_root, 'prompt', 'extraction.md')
EVALUATION_PROMPT = read_prompt_file(evaluation_prompt_path)
EXTRACTION_PROMPT = read_prompt_file(extraction_prompt_path)
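
# The prompt files themselves are not shown in this repository snippet, but based on
# the keys parsed below, each prompt is expected to make the model return a single
# JSON object (optionally wrapped in a markdown json code fence). A sketch of the
# assumed shapes, with illustrative values only:
#
#   evaluation.md response: {"score": 85, "reason": "why this content is relevant"}
#   extraction.md response: {"extracted_content": "...", "analysis_reason": "..."}
#
# The real format is defined by the prompt files, not by this module.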
def batch_call_llm_for_evaluation(contents: list, query_word: str) -> list:
    """Batch-call the LLM to evaluate content."""
    # Prepare the batch of evaluation inputs
    evaluation_contents = []
    for content in contents:
        evaluation_contents.append({
            "query_word": query_word,
            "content": content.parsing_data
        })
    try:
        # Batch-call Gemini for evaluation
        results = gemini_processor.batch_process(evaluation_contents, EVALUATION_PROMPT)
        # Parse the returned results
        evaluation_results = []
        for i, result in enumerate(results):
            # Strip an optional markdown code fence before parsing the JSON
            result = re.sub(r'^\s*```json|\s*```\s*$', '', result, flags=re.MULTILINE).strip()
            result = json.loads(result)
            parsing_id = contents[i].id
            parsing_data = contents[i].parsing_data
            score = result.get("score", -2)
            score_reason = result.get("reason", "")
            evaluation_results.append((parsing_id, score, score_reason, parsing_data))
        return evaluation_results
    except Exception as e:
        logger.error(f"Batch evaluation failed: {str(e)}")
        # Return default results so every item still yields a (parsing_id, score, reason, data) tuple
        return [(content.id, 0, "evaluation process error",
                 content.data if hasattr(content, 'data') else (content.parsing_data or ""))
                for content in contents]
def batch_call_llm_for_extraction(evaluation_results: list, query_word: str) -> list:
    """Batch-call the LLM to extract content from high-scoring items."""
    # Prepare the batch of extraction inputs
    extraction_contents = []
    for result in evaluation_results:
        parsing_data = result.get("parsing_data", "")
        extraction_contents.append({
            "query_word": query_word,
            "content": parsing_data
        })
    try:
        # Batch-call Gemini for extraction
        results = gemini_processor.batch_process(extraction_contents, EXTRACTION_PROMPT)
        # Parse the returned results
        extraction_results = []
        for i, result in enumerate(results):
            # Strip an optional markdown code fence before parsing the JSON
            result = re.sub(r'^\s*```json|\s*```\s*$', '', result, flags=re.MULTILINE).strip()
            result = json.loads(result)
            extracted_data = result.get("extracted_content", "no content extracted")
            clean_reason = result.get("analysis_reason", "no reason returned")
            extraction_results.append((extracted_data, clean_reason))
        return extraction_results
    except Exception as e:
        logger.error(f"Batch extraction failed: {str(e)}")
        # Return placeholder (data, reason) tuples so the caller can still unpack each entry
        return [("{}", "extraction process error")] * len(evaluation_results)