structure_processor.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 内容结构化处理模块
  5. 主要功能:
  6. 1. 从数据库中拉取需要结构化的数据
  7. 2. 调用Gemini API进行内容结构化
  8. 3. 将结构化结果更新到数据库
  9. """
  10. import os
  11. import json
  12. import time
  13. import sys
  14. import threading
  15. from typing import Dict, Any, List, Optional, Tuple
  16. # 导入自定义模块
  17. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  18. from utils.mysql_db import MysqlHelper
  19. from gemini import GeminiProcessor
  20. from utils.file import File
  21. from utils.logging_config import get_logger
  22. class StructureProcessor:
  23. def __init__(self):
  24. # 设置日志
  25. self.logger = get_logger('StructureProcessor')
  26. # 初始化处理器
  27. self.processor = GeminiProcessor()
  28. self.system_prompt = File.read_file('../prompt/structure.md')
  29. self.logger.info("系统提示词加载完成")
  30. self.logger.debug(f"系统提示词: {self.system_prompt}")
  31. # 线程控制
  32. self.lock = threading.Lock()
  33. self.stop_event = threading.Event()
  34. self.threads = []
  35. def build_query_conditions(self, query_word: Optional[str],
  36. source_type: Optional[str],
  37. source_channel: Optional[str]) -> Tuple[str, Tuple]:
  38. """构建查询条件和参数"""
  39. conditions = ["multimodal_recognition is not null", "structured_data is null"]
  40. params = []
  41. if query_word is not None:
  42. conditions.append("query_word = %s")
  43. params.append(query_word)
  44. if source_type is not None:
  45. conditions.append("source_type = %s")
  46. params.append(source_type)
  47. if source_channel is not None:
  48. conditions.append("source_channel = %s")
  49. params.append(source_channel)
  50. where_clause = " AND ".join(conditions)
  51. return where_clause, tuple(params)
  52. def process_single_record(self, query_word: Optional[str],
  53. source_type: Optional[str],
  54. source_channel: Optional[str]) -> bool:
  55. """处理单条记录"""
  56. try:
  57. with self.lock:
  58. # 构建查询条件和参数
  59. where_clause, params = self.build_query_conditions(query_word, source_type, source_channel)
  60. # 先查询一条需要处理的记录
  61. select_sql = f"""
  62. SELECT id, multimodal_recognition
  63. FROM knowledge_search_content
  64. WHERE {where_clause}
  65. LIMIT 1
  66. """
  67. records = MysqlHelper.get_values(select_sql, params)
  68. if not records:
  69. self.logger.warning("没有找到需要处理的记录")
  70. return False
  71. row = records[0]
  72. record_id = row[0]
  73. # 标记为处理中,防止其他线程取到重复处理
  74. mark_sql = """
  75. UPDATE knowledge_search_content
  76. SET structured_data = '{}'
  77. WHERE id = %s
  78. """
  79. MysqlHelper.update_values(mark_sql, (record_id,))
  80. self.logger.info(f"开始处理记录 ID: {record_id}")
  81. # 处理内容
  82. result = self.processor.process(row[1], self.system_prompt)
  83. self.logger.info(f"处理完成,结果长度: {len(str(result))}")
  84. self.logger.debug(f"处理结果: {result}")
  85. # 更新数据库为实际结果
  86. update_sql = """
  87. UPDATE knowledge_search_content
  88. SET structured_data = %s
  89. WHERE id = %s
  90. """
  91. MysqlHelper.update_values(update_sql, (result, record_id))
  92. self.logger.info(f"记录 {record_id} 处理完成并更新数据库")
  93. return True
  94. except Exception as e:
  95. self.logger.error(f"处理记录失败: {str(e)}", exc_info=True)
  96. return False
  97. def worker_thread(self, thread_id: int, query_word: Optional[str],
  98. source_type: Optional[str], source_channel: Optional[str]):
  99. """工作线程函数"""
  100. thread_logger = get_logger(f'WorkerThread-{thread_id}')
  101. thread_logger.info(f"线程 {thread_id} 启动")
  102. while not self.stop_event.is_set():
  103. try:
  104. # 尝试处理一条记录
  105. success = self.process_single_record(query_word, source_type, source_channel)
  106. if not success:
  107. thread_logger.info(f"没有找到需要处理的记录,等待5秒后重试")
  108. # 等待时也要检查停止信号
  109. if self.stop_event.wait(5):
  110. break
  111. continue
  112. # 处理成功后等待5秒再处理下一条
  113. thread_logger.info(f"处理完成,等待5秒后处理下一条")
  114. # 等待时也要检查停止信号
  115. if self.stop_event.wait(5):
  116. break
  117. except Exception as e:
  118. thread_logger.error(f"发生错误: {str(e)}", exc_info=True)
  119. # 等待时也要检查停止信号
  120. if self.stop_event.wait(5):
  121. break
  122. thread_logger.info(f"线程 {thread_id} 已停止")
  123. def start_multi_thread_processing(self, query_word: Optional[str],
  124. source_type: Optional[str],
  125. source_channel: Optional[str]):
  126. """启动多线程处理"""
  127. self.threads = []
  128. self.logger.info("启动多线程处理...")
  129. self.logger.info(f"查询条件: query_word={query_word}, source_type={source_type}, source_channel={source_channel}")
  130. # 创建5个线程,间隔5秒启动
  131. for i in range(5):
  132. thread = threading.Thread(
  133. target=self.worker_thread,
  134. args=(i + 1, query_word, source_type, source_channel)
  135. )
  136. self.threads.append(thread)
  137. # 启动线程
  138. thread.start()
  139. self.logger.info(f"线程 {i + 1} 已启动")
  140. # 等待5秒后启动下一个线程
  141. if i < 4: # 最后一个线程不需要等待
  142. self.logger.info("等待5秒后启动下一个线程...")
  143. time.sleep(5)
  144. self.logger.info("所有线程已启动,使用 ./start_structure.sh stop 停止")
  145. try:
  146. # 等待所有线程完成
  147. for thread in self.threads:
  148. thread.join()
  149. except KeyboardInterrupt:
  150. self.logger.info("收到停止信号,正在停止所有线程...")
  151. self.stop_all_threads()
  152. def stop_all_threads(self):
  153. """停止所有线程"""
  154. self.logger.info("正在停止所有线程...")
  155. self.stop_event.set()
  156. # 等待所有线程结束
  157. for i, thread in enumerate(self.threads):
  158. if thread.is_alive():
  159. self.logger.info(f"等待线程 {i + 1} 结束...")
  160. thread.join(timeout=10) # 最多等待10秒
  161. if thread.is_alive():
  162. self.logger.warning(f"线程 {i + 1} 未能正常结束")
  163. else:
  164. self.logger.info(f"线程 {i + 1} 已正常结束")
  165. self.logger.info("所有线程已停止")
  166. def main():
  167. """主函数"""
  168. import argparse
  169. parser = argparse.ArgumentParser(description='内容结构化处理脚本')
  170. parser.add_argument('--query_word', default=None, help='query词')
  171. parser.add_argument('--source_type', default=None, help='数据源类型')
  172. parser.add_argument('--source_channel', default=None, help='数据源渠道')
  173. args = parser.parse_args()
  174. try:
  175. processor = StructureProcessor()
  176. processor.start_multi_thread_processing(
  177. query_word=args.query_word,
  178. source_type=args.source_type,
  179. source_channel=args.source_channel
  180. )
  181. except Exception as e:
  182. print(f"程序执行失败: {str(e)}")
  183. sys.exit(1)
  184. if __name__ == "__main__":
  185. # 测试单条记录处理
  186. processor = StructureProcessor()
  187. processor.process_single_record(query_word=None, source_type=None, source_channel=None)