|
@@ -0,0 +1,249 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
"""
Content evaluation module.

Main features:
1. Pull structured records that still need evaluation from the database
2. Score the content via the Gemini API
3. Write the score/reason back to the database
"""
|
|
|
+
|
|
|
+import os
|
|
|
+import json
|
|
|
+import time
|
|
|
+import sys
|
|
|
+import re
|
|
|
+import threading
|
|
|
+from typing import Dict, Any, List, Optional, Tuple
|
|
|
+
|
|
|
+# 导入自定义模块
|
|
|
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
+
|
|
|
+from utils.mysql_db import MysqlHelper
|
|
|
+from gemini import GeminiProcessor
|
|
|
+from utils.file import File
|
|
|
+from utils.logging_config import get_logger
|
|
|
+
|
|
|
+
|
|
|
class EvaluateProcessor:
    """Score un-evaluated structured records with Gemini and persist results.

    Workflow per record: claim one row from ``knowledge_search_content`` that
    has ``structured_data`` but no score yet, mark it in-progress with a
    placeholder score of '-1', call the Gemini API, then update the real
    score/reason in ``knowledge_content_evaluate``.

    A shared lock serializes only the claim step (select + mark) so the five
    worker threads never pick up the same record twice, while the slow API
    call runs concurrently.
    """

    def __init__(self):
        # Module logger for this processor
        self.logger = get_logger('EvaluateProcessor')

        # Gemini client and the evaluation system prompt.
        # NOTE(review): the prompt path is relative to the CWD — confirm the
        # script is always launched from the directory this path assumes.
        self.processor = GeminiProcessor()
        self.system_prompt = File.read_file('../prompt/evaluate.md')
        self.logger.info("系统提示词加载完成")
        self.logger.debug(f"系统提示词: {self.system_prompt}")

        # Thread coordination: lock guards record claiming, stop_event
        # signals shutdown, threads holds the worker Thread objects.
        self.lock = threading.Lock()
        self.stop_event = threading.Event()
        self.threads = []

    def build_query_conditions(self, query_word: Optional[str],
                               source_type: Optional[str],
                               source_channel: Optional[str]) -> Tuple[str, Tuple]:
        """Build the WHERE clause and parameter tuple for the claim query.

        Base conditions restrict to rows that have structured data but no
        evaluation score yet; each non-None filter adds one ``%s`` condition
        with a matching positional parameter.

        Returns:
            (where_clause, params): SQL fragment joined with AND, and the
            parameters in the same order as their placeholders.
        """
        conditions = ["a.structured_data is not null", "b.score is null"]
        params = []

        if query_word is not None:
            conditions.append("a.query_word = %s")
            params.append(query_word)
        if source_type is not None:
            conditions.append("a.source_type = %s")
            params.append(source_type)
        if source_channel is not None:
            conditions.append("a.source_channel = %s")
            params.append(source_channel)

        where_clause = " AND ".join(conditions)
        return where_clause, tuple(params)

    def process_single_record(self, query_word: Optional[str],
                              source_type: Optional[str],
                              source_channel: Optional[str]) -> bool:
        """Claim one pending record, score it with Gemini and persist the result.

        Returns:
            True when a record was processed and updated, False when no
            record was found or any step failed.
        """
        try:
            # Claim phase: hold the lock only while selecting and marking a
            # record. Once marked, other workers can no longer select it, so
            # the slow Gemini call below runs without serializing threads.
            with self.lock:
                where_clause, params = self.build_query_conditions(query_word, source_type, source_channel)

                # Fetch a single record that still needs evaluation
                select_sql = f"""
                    SELECT a.id, a.query_word, a.structured_data
                    FROM knowledge_search_content a
                    left join knowledge_content_evaluate b on a.id = b.search_content_id
                    WHERE {where_clause}
                    LIMIT 1
                """
                records = MysqlHelper.get_values(select_sql, params)
                if not records:
                    self.logger.warning("没有找到需要处理的记录")
                    return False

                row = records[0]
                record_id = row[0]

                # Mark as in-progress with a '-1' placeholder score so
                # concurrent workers skip this record.
                mark_sql = """
                    insert into knowledge_content_evaluate (search_content_id, score)
                    values (%s, '-1')
                """
                # Bug fix: parameters must be a tuple — (record_id) without a
                # trailing comma is just the bare int, not a 1-tuple.
                MysqlHelper.update_values(mark_sql, (record_id,))

            self.logger.info(f"开始处理记录 ID: {record_id}")

            # Build the user prompt; the model is instructed to reply with a
            # single JSON object containing 'score' and 'reason'.
            user_prompt = f"""
            # 任务 (Task)
            现在,请根据以下输入,严格执行你的任务。你的最终输出必须且只能是一个JSON对象。
            ## 输入:
            Query: {row[1]}
            Content: {row[2]}
            """

            result = self.processor.process(user_prompt, self.system_prompt)
            # Strip the ```json ... ``` markdown fences the model sometimes adds
            result = re.sub(r'^\s*```json|\s*```\s*$', '', result, flags=re.MULTILINE).strip()
            self.logger.info(f"处理完成,结果长度: {len(str(result))}")
            self.logger.info(f"处理结果: {result}")

            # Overwrite the '-1' placeholder with the real score and reason
            update_sql = """
                UPDATE knowledge_content_evaluate
                SET score = %s ,reason = %s
                WHERE search_content_id = %s
            """
            parsed = json.loads(result)
            score = parsed['score']
            reason = parsed['reason']

            MysqlHelper.update_values(update_sql, (score, reason, record_id))
            self.logger.info(f"记录 {record_id} 处理完成并更新数据库")
            return True

        except Exception as e:
            # NOTE(review): if Gemini or JSON parsing fails after the record
            # was marked, it keeps score '-1' and is never retried by this
            # query — confirm a cleanup/requeue pass exists elsewhere.
            self.logger.error(f"处理记录失败: {str(e)}", exc_info=True)
            return False

    def worker_thread(self, thread_id: int, query_word: Optional[str],
                      source_type: Optional[str], source_channel: Optional[str]):
        """Worker loop: keep processing records until stop_event is set."""
        thread_logger = get_logger(f'WorkerThread-{thread_id}')
        thread_logger.info(f"线程 {thread_id} 启动")

        while not self.stop_event.is_set():
            try:
                # Attempt to claim and process one record
                success = self.process_single_record(query_word, source_type, source_channel)

                if not success:
                    thread_logger.info(f"没有找到需要处理的记录,等待5秒后重试")
                    # wait() doubles as an interruptible sleep: returns True
                    # (and we exit) as soon as stop_event is set.
                    if self.stop_event.wait(5):
                        break
                    continue

                # Throttle: pause 5s between successful records
                thread_logger.info(f"处理完成,等待5秒后处理下一条")
                if self.stop_event.wait(5):
                    break

            except Exception as e:
                thread_logger.error(f"发生错误: {str(e)}", exc_info=True)
                # Back off after an unexpected error, still honoring shutdown
                if self.stop_event.wait(5):
                    break

        thread_logger.info(f"线程 {thread_id} 已停止")

    def start_multi_thread_processing(self, query_word: Optional[str],
                                      source_type: Optional[str],
                                      source_channel: Optional[str]):
        """Launch five worker threads (staggered 5s apart) and wait on them.

        Blocks until all workers finish; Ctrl-C triggers a cooperative
        shutdown via stop_all_threads().
        """
        self.threads = []

        self.logger.info("启动多线程处理...")
        self.logger.info(f"查询条件: query_word={query_word}, source_type={source_type}, source_channel={source_channel}")

        # Spawn 5 workers, pausing 5s between launches to spread out load
        for i in range(5):
            thread = threading.Thread(
                target=self.worker_thread,
                args=(i + 1, query_word, source_type, source_channel)
            )
            self.threads.append(thread)

            thread.start()
            self.logger.info(f"线程 {i + 1} 已启动")

            if i < 4:  # no pause needed after the last thread
                self.logger.info("等待5秒后启动下一个线程...")
                time.sleep(5)

        self.logger.info("所有线程已启动,使用 ./start_evaluate.sh stop 停止")

        try:
            # Block until every worker exits
            for thread in self.threads:
                thread.join()
        except KeyboardInterrupt:
            self.logger.info("收到停止信号,正在停止所有线程...")
            self.stop_all_threads()

    def stop_all_threads(self):
        """Signal shutdown and join every worker (10s timeout each)."""
        self.logger.info("正在停止所有线程...")
        self.stop_event.set()

        # Give each worker up to 10s to observe stop_event and exit
        for i, thread in enumerate(self.threads):
            if thread.is_alive():
                self.logger.info(f"等待线程 {i + 1} 结束...")
                thread.join(timeout=10)
                if thread.is_alive():
                    self.logger.warning(f"线程 {i + 1} 未能正常结束")
                else:
                    self.logger.info(f"线程 {i + 1} 已正常结束")

        self.logger.info("所有线程已停止")
|
|
|
+
|
|
|
+
|
|
|
def main():
    """CLI entry point: parse the optional filters and run the workers.

    NOTE(review): the argparse description still says "结构化处理" although
    this script performs evaluation — likely copied from a sibling script.
    """
    import argparse

    arg_parser = argparse.ArgumentParser(description='内容结构化处理脚本')
    for flag, help_text in (('--query_word', 'query词'),
                            ('--source_type', '数据源类型'),
                            ('--source_channel', '数据源渠道')):
        arg_parser.add_argument(flag, default=None, help=help_text)

    cli_args = arg_parser.parse_args()

    try:
        EvaluateProcessor().start_multi_thread_processing(
            query_word=cli_args.query_word,
            source_type=cli_args.source_type,
            source_channel=cli_args.source_channel
        )
    except Exception as e:
        print(f"程序执行失败: {str(e)}")
        sys.exit(1)
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Ad-hoc smoke test: process exactly one record with no filters,
    # bypassing argparse and the multi-thread loop.
    # NOTE(review): main() is defined above but never called — restore
    # `main()` here to get the documented CLI behavior back.
    processor = EvaluateProcessor()
    processor.process_single_record(query_word=None, source_type=None, source_channel=None)
|