evaluate.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Content evaluation module.

Main features:
1. Pull records that need evaluation from the database
2. Call the Gemini API to evaluate the content
3. Write the evaluation result back to the database
"""
import os
import json
import time
import sys
import re
import threading
from typing import Dict, Any, List, Optional, Tuple

# Import project-local modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.mysql_db import MysqlHelper
from gemini import GeminiProcessor
from utils.file import File
from utils.logging_config import get_logger


class EvaluateProcessor:
    def __init__(self):
        # Set up logging
        self.logger = get_logger('EvaluateProcessor')
        # Initialize the Gemini processor
        self.processor = GeminiProcessor()
        self.system_prompt = File.read_file('../prompt/evaluate.md')
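        # Assumption: File.read_file resolves this relative path against the current
        # working directory, so the script is expected to be launched from its own
        # directory (e.g. via start_evaluate.sh); otherwise the prompt file may not be found.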
        self.logger.info("System prompt loaded")
        self.logger.debug(f"System prompt: {self.system_prompt}")
        # Thread control
        self.lock = threading.Lock()
        self.stop_event = threading.Event()
        self.threads = []

    def build_query_conditions(self, query_word: Optional[str],
                               source_type: Optional[str],
                               source_channel: Optional[str]) -> Tuple[str, Tuple]:
        """Build the WHERE clause and its parameters."""
        conditions = ["a.structured_data is not null", "b.score is null"]
        params = []
        if query_word is not None:
            conditions.append("a.query_word = %s")
            params.append(query_word)
        if source_type is not None:
            conditions.append("a.source_type = %s")
            params.append(source_type)
        if source_channel is not None:
            conditions.append("a.source_channel = %s")
            params.append(source_channel)
        where_clause = " AND ".join(conditions)
        return where_clause, tuple(params)
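        # Illustrative example (hypothetical argument values):
        #   build_query_conditions("diet", "article", None) returns
        #   ("a.structured_data is not null AND b.score is null AND a.query_word = %s AND a.source_type = %s",
        #    ("diet", "article"))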

    def process_single_record(self, query_word: Optional[str],
                              source_type: Optional[str],
                              source_channel: Optional[str]) -> bool:
        """Process a single record; return True on success, False if nothing was processed or an error occurred."""
        try:
            with self.lock:
                # Build the query conditions and parameters
                where_clause, params = self.build_query_conditions(query_word, source_type, source_channel)
                # Fetch one record that still needs processing
                select_sql = f"""
                    SELECT a.id, a.query_word, a.structured_data
                    FROM knowledge_search_content a
                    LEFT JOIN knowledge_content_evaluate b ON a.id = b.search_content_id
                    WHERE {where_clause}
                    LIMIT 1
                """
                records = MysqlHelper.get_values(select_sql, params)
                if not records:
                    self.logger.warning("No records found that need processing")
                    return False
                row = records[0]
                record_id = row[0]
                # Mark the record as in progress so other threads do not pick it up again
                mark_sql = """
                    INSERT INTO knowledge_content_evaluate (search_content_id, score)
                    VALUES (%s, '-1')
                """
                # The parameter must be a one-element tuple; a bare (record_id) is not a tuple
                MysqlHelper.update_values(mark_sql, (record_id,))
                self.logger.info(f"Started processing record ID: {record_id}")

            # Process the content (outside the lock, so other threads can claim new records meanwhile)
            user_prompt = f"""
            # 任务 (Task)
            现在,请根据以下输入,严格执行你的任务。你的最终输出必须且只能是一个JSON对象。
            ## 输入:
            Query: {row[1]}
            Content: {row[2]}
            """
            # print(user_prompt)
            result = self.processor.process(user_prompt, self.system_prompt)
            # Strip any Markdown ```json fences the model may wrap around its output
            result = re.sub(r'^\s*```json|\s*```\s*$', '', result, flags=re.MULTILINE).strip()
            self.logger.info(f"Processing finished, result length: {len(str(result))}")
            self.logger.info(f"Processing result: {result}")
            # Update the database with the actual result
            update_sql = """
                UPDATE knowledge_content_evaluate
                SET score = %s, reason = %s
                WHERE search_content_id = %s
            """
            result = json.loads(result)
            score = result['score']
            reason = result['reason']
            MysqlHelper.update_values(update_sql, (score, reason, record_id))
            self.logger.info(f"Record {record_id} processed and the database updated")
            return True
        except Exception as e:
            self.logger.error(f"Failed to process record: {str(e)}", exc_info=True)
            return False

    def worker_thread(self, thread_id: int, query_word: Optional[str],
                      source_type: Optional[str], source_channel: Optional[str]):
        """Worker thread loop."""
        thread_logger = get_logger(f'WorkerThread-{thread_id}')
        thread_logger.info(f"Thread {thread_id} started")
        while not self.stop_event.is_set():
            try:
                # Try to process one record
                success = self.process_single_record(query_word, source_type, source_channel)
                if not success:
                    thread_logger.info("No records to process, retrying in 5 seconds")
                    # Keep watching the stop signal while waiting
                    if self.stop_event.wait(5):
                        break
                    continue
                # Wait 5 seconds before processing the next record
                thread_logger.info("Record processed, waiting 5 seconds before the next one")
                # Keep watching the stop signal while waiting
                if self.stop_event.wait(5):
                    break
            except Exception as e:
                thread_logger.error(f"Error occurred: {str(e)}", exc_info=True)
                # Keep watching the stop signal while waiting
                if self.stop_event.wait(5):
                    break
        thread_logger.info(f"Thread {thread_id} stopped")

    def start_multi_thread_processing(self, query_word: Optional[str],
                                      source_type: Optional[str],
                                      source_channel: Optional[str]):
        """Start multi-threaded processing."""
        self.threads = []
        self.logger.info("Starting multi-threaded processing...")
        self.logger.info(f"Query conditions: query_word={query_word}, source_type={source_type}, source_channel={source_channel}")
        # Create 5 threads, started 5 seconds apart
        for i in range(5):
            thread = threading.Thread(
                target=self.worker_thread,
                args=(i + 1, query_word, source_type, source_channel)
            )
            self.threads.append(thread)
            # Start the thread
            thread.start()
            self.logger.info(f"Thread {i + 1} started")
            # Wait 5 seconds before starting the next thread
            if i < 4:  # no need to wait after the last thread
                self.logger.info("Waiting 5 seconds before starting the next thread...")
                time.sleep(5)
        self.logger.info("All threads started; use ./start_evaluate.sh stop to stop")
        try:
            # Wait for all threads to finish
            for thread in self.threads:
                thread.join()
        except KeyboardInterrupt:
            self.logger.info("Stop signal received, stopping all threads...")
            self.stop_all_threads()

    def stop_all_threads(self):
        """Stop all worker threads."""
        self.logger.info("Stopping all threads...")
        self.stop_event.set()
        # Wait for every thread to finish
        for i, thread in enumerate(self.threads):
            if thread.is_alive():
                self.logger.info(f"Waiting for thread {i + 1} to finish...")
                thread.join(timeout=10)  # wait at most 10 seconds
                if thread.is_alive():
                    self.logger.warning(f"Thread {i + 1} did not stop cleanly")
                else:
                    self.logger.info(f"Thread {i + 1} stopped")
        self.logger.info("All threads stopped")
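

# Hypothetical shutdown wiring (not part of this script): the log above refers to
# "./start_evaluate.sh stop". If that script sends SIGTERM to this process, a handler
# along these lines, registered in main(), would let stop_all_threads() run cleanly:
#
#   import signal
#   signal.signal(signal.SIGTERM, lambda signum, frame: processor.stop_all_threads())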


def main():
    """Main entry point."""
    import argparse
    parser = argparse.ArgumentParser(description='Content evaluation script')
    parser.add_argument('--query_word', default=None, help='query word')
    parser.add_argument('--source_type', default=None, help='data source type')
    parser.add_argument('--source_channel', default=None, help='data source channel')
    args = parser.parse_args()
    try:
        processor = EvaluateProcessor()
        processor.start_multi_thread_processing(
            query_word=args.query_word,
            source_type=args.source_type,
            source_channel=args.source_channel
        )
    except Exception as e:
        print(f"Program execution failed: {str(e)}")
        sys.exit(1)
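
# Example CLI invocation of main() (illustrative; note that the __main__ block below
# calls process_single_record() directly as a one-off test instead of calling main()):
#   python evaluate.py --query_word "..." --source_type "..." --source_channel "..."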


if __name__ == "__main__":
    # One-off test: process a single record
    processor = EvaluateProcessor()
    processor.process_single_record(query_word=None, source_type=None, source_channel=None)