#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Content evaluation module.

Main responsibilities:
1. Pull records that need evaluation from the database.
2. Call the Gemini API to score each record's content.
3. Write the score and reason back to the database.
"""
import os
import json
import time
import sys
import re
import threading
from typing import Dict, Any, List, Optional, Tuple

# Make the project root importable for the project-local modules below
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.mysql_db import MysqlHelper
from gemini import GeminiProcessor
from utils.file import File
from utils.logging_config import get_logger


class EvaluateProcessor:
    def __init__(self):
        # Set up logging
        self.logger = get_logger('EvaluateProcessor')

        # Initialize the Gemini processor and load the system prompt
        # (note: the prompt path is resolved relative to the current working directory)
        self.processor = GeminiProcessor()
        self.system_prompt = File.read_file('../prompt/evaluate.md')
        self.logger.info("System prompt loaded")
        self.logger.debug(f"System prompt: {self.system_prompt}")

        # Thread control
        self.lock = threading.Lock()
        self.stop_event = threading.Event()
        self.threads = []

    def build_query_conditions(self, query_word: Optional[str],
                               source_type: Optional[str],
                               source_channel: Optional[str]) -> Tuple[str, Tuple]:
        """Build the WHERE clause and its parameter tuple."""
        conditions = ["a.structured_data is not null", "b.score is null", "c.category_id = 0"]
        params = []

        if query_word is not None:
            conditions.append("a.query_word = %s")
            params.append(query_word)
        if source_type is not None:
            conditions.append("a.source_type = %s")
            params.append(source_type)
        if source_channel is not None:
            conditions.append("a.source_channel = %s")
            params.append(source_channel)

        where_clause = " AND ".join(conditions)
        return where_clause, tuple(params)

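    # Illustrative example (values are hypothetical): with query_word="oral care" and the
    # other two filters left as None, build_query_conditions returns:
    #   ("a.structured_data is not null AND b.score is null AND c.category_id = 0 AND a.query_word = %s",
    #    ("oral care",))
    # Only non-None filters contribute a placeholder and a matching parameter.
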
    def process_single_record(self, query_word: Optional[str],
                              source_type: Optional[str],
                              source_channel: Optional[str]) -> bool:
        """Fetch one pending record, score it with Gemini, and persist the result."""
        try:
            with self.lock:
                # Build the query conditions and parameters
                where_clause, params = self.build_query_conditions(query_word, source_type, source_channel)

                # Fetch a single record that still needs processing
                select_sql = f"""
                    SELECT a.id, a.query_word, a.structured_data
                    FROM knowledge_search_content a
                    LEFT JOIN knowledge_content_evaluate b ON a.id = b.search_content_id
                    LEFT JOIN knowledge_content_query c ON a.query_word = c.query_word
                    WHERE {where_clause}
                    LIMIT 1
                """

                records = MysqlHelper.get_values(select_sql, params)
                if not records:
                    self.logger.warning("No records found that need processing")
                    return False

                row = records[0]
                record_id = row[0]

                # Mark the record as in progress (placeholder score -1) so other
                # threads, which filter on b.score IS NULL, do not pick it up again
                mark_sql = """
                    INSERT INTO knowledge_content_evaluate (search_content_id, score)
                    VALUES (%s, '-1')
                """
                MysqlHelper.update_values(mark_sql, (record_id,))

                self.logger.info(f"Started processing record ID: {record_id}")

                # Build the user prompt sent to Gemini alongside the system prompt
                user_prompt = f"""
                # 任务 (Task)
                现在,请根据以下输入,严格执行你的任务。你的最终输出必须且只能是一个JSON对象。
                ## 输入:
                Query: {row[1]}
                Content: {row[2]}
                """
                # print(user_prompt)
                result = self.processor.process(user_prompt, self.system_prompt)
                # Strip any Markdown code fences (```json ... ```) the model wraps around its output
                result = re.sub(r'^\s*```json|\s*```\s*$', '', result, flags=re.MULTILINE).strip()
                self.logger.info(f"Processing finished, result length: {len(str(result))}")
                self.logger.info(f"Result: {result}")

                # Parse the model output and replace the placeholder with the real score/reason
                update_sql = """
                    UPDATE knowledge_content_evaluate
                    SET score = %s, reason = %s
                    WHERE search_content_id = %s
                """
                result = json.loads(result)
                score = result['score']
                reason = result['reason']

                MysqlHelper.update_values(update_sql, (score, reason, record_id))
                self.logger.info(f"Record {record_id} processed and written to the database")
                return True

        except Exception as e:
            self.logger.error(f"Failed to process record: {str(e)}", exc_info=True)
            return False

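    # For reference, the model response is expected to be a JSON object with at least
    # the keys read above. An illustrative (hypothetical) raw response:
    #
    #   ```json
    #   {"score": 4, "reason": "The content directly answers the query ..."}
    #   ```
    #
    # The regex above removes the ```json / ``` fence before json.loads() parses it.
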
    def worker_thread(self, thread_id: int, query_word: Optional[str],
                      source_type: Optional[str], source_channel: Optional[str]):
        """Worker loop: keep processing records until the stop event is set."""
        thread_logger = get_logger(f'WorkerThread-{thread_id}')
        thread_logger.info(f"Thread {thread_id} started")

        while not self.stop_event.is_set():
            try:
                # Try to process one record
                success = self.process_single_record(query_word, source_type, source_channel)

                if not success:
                    thread_logger.info("No records to process, retrying in 5 seconds")
                    # wait() doubles as a sleep that is interrupted by the stop signal
                    if self.stop_event.wait(5):
                        break
                    continue

                # Wait 5 seconds before processing the next record
                thread_logger.info("Record processed, waiting 5 seconds before the next one")
                if self.stop_event.wait(5):
                    break

            except Exception as e:
                thread_logger.error(f"Error in worker loop: {str(e)}", exc_info=True)
                if self.stop_event.wait(5):
                    break

        thread_logger.info(f"Thread {thread_id} stopped")

    def start_multi_thread_processing(self, query_word: Optional[str],
                                      source_type: Optional[str],
                                      source_channel: Optional[str]):
        """Start five worker threads, staggered five seconds apart."""
        self.threads = []

        self.logger.info("Starting multi-threaded processing...")
        self.logger.info(f"Filters: query_word={query_word}, source_type={source_type}, source_channel={source_channel}")

        # Create 5 threads, started 5 seconds apart
        for i in range(5):
            thread = threading.Thread(
                target=self.worker_thread,
                args=(i + 1, query_word, source_type, source_channel)
            )
            self.threads.append(thread)

            # Start the thread
            thread.start()
            self.logger.info(f"Thread {i + 1} started")

            # Stagger the next thread by 5 seconds
            if i < 4:  # the last thread does not need to wait
                self.logger.info("Waiting 5 seconds before starting the next thread...")
                time.sleep(5)

        self.logger.info("All threads started, use ./start_evaluate.sh stop to stop")

        try:
            # Wait for all threads to finish
            for thread in self.threads:
                thread.join()
        except KeyboardInterrupt:
            self.logger.info("Stop signal received, stopping all threads...")
            self.stop_all_threads()

    def stop_all_threads(self):
        """Signal all worker threads to stop and wait for them to exit."""
        self.logger.info("Stopping all threads...")
        self.stop_event.set()

        # Wait for every thread to finish
        for i, thread in enumerate(self.threads):
            if thread.is_alive():
                self.logger.info(f"Waiting for thread {i + 1} to finish...")
                thread.join(timeout=10)  # wait at most 10 seconds
                if thread.is_alive():
                    self.logger.warning(f"Thread {i + 1} did not stop cleanly")
                else:
                    self.logger.info(f"Thread {i + 1} stopped")

        self.logger.info("All threads stopped")

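# Note on the database helper (an assumption based on how it is used in this file, not on
# the MysqlHelper source): get_values(sql, params) is expected to return a list of row
# tuples (hence the row[0]/row[1]/row[2] indexing above), and update_values(sql, params)
# to execute an INSERT/UPDATE statement with the given parameter tuple.
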

def main():
    """CLI entry point: parse the filter arguments and start multi-threaded processing."""
    import argparse

    parser = argparse.ArgumentParser(description='Content evaluation script')
    parser.add_argument('--query_word', default=None, help='query word to filter on')
    parser.add_argument('--source_type', default=None, help='data source type')
    parser.add_argument('--source_channel', default=None, help='data source channel')

    args = parser.parse_args()

    try:
        processor = EvaluateProcessor()

        processor.start_multi_thread_processing(
            query_word=args.query_word,
            source_type=args.source_type,
            source_channel=args.source_channel
        )
    except Exception as e:
        print(f"Program failed: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    # Test mode: process a single record directly.
    # To run the full multi-threaded pipeline, call main() here instead.
    processor = EvaluateProcessor()
    processor.process_single_record(query_word=None, source_type=None, source_channel=None)
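
# Illustrative usage (the script filename is an assumption; start_evaluate.sh is the wrapper
# referenced in the log message in start_multi_thread_processing):
#   python3 evaluate.py
#       runs the test block above and scores a single record
#   python3 evaluate.py --query_word "oral care" --source_type article --source_channel web
#       would apply these filters if the __main__ block called main() instead
#   ./start_evaluate.sh stop
#       stops the running worker threads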