|
@@ -3,7 +3,7 @@ import json
|
|
|
import time
|
|
|
import sys
|
|
|
import argparse
|
|
|
-from typing import Dict, Any, List, Optional
|
|
|
+from typing import Dict, Any, List, Optional, Tuple
|
|
|
|
|
|
# 导入自定义模块
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
@@ -15,85 +15,80 @@ from utils.file import File
|
|
|
|
|
|
class Handler:
    """Fill in ``multimodal_recognition`` for rows of ``knowledge_search_content``.

    Rows are selected when ``formatted_content`` is set and
    ``multimodal_recognition`` is still NULL. Each row's content is passed to
    the processor together with the system prompt, and the returned value is
    written back to the row.
    """

    def __init__(self):
        """Create the processor and load the system prompt text."""
        # Initialise the processor.
        self.processor = GeminiProcessor()
        self.system_prompt = File.read_file('prompt/handle.md')

    def build_query_conditions(self, query_word: Optional[str],
                               source_type: Optional[str],
                               source_channel: Optional[str]) -> Tuple[str, Tuple[str, ...]]:
        """Build the WHERE clause and its bound parameters.

        Args:
            query_word: If not None, restrict rows to this ``query_word``.
            source_type: If not None, restrict rows to this ``source_type``.
            source_channel: If not None, restrict rows to this ``source_channel``.

        Returns:
            A ``(where_clause, params)`` pair. ``where_clause`` contains only
            fixed SQL fragments joined with ``AND`` and one ``%s`` placeholder
            per supplied filter; ``params`` holds the matching values in the
            same order.
        """
        # The two base predicates select rows that still need processing.
        conditions = ["formatted_content is not null", "multimodal_recognition is null"]
        params = []

        # Caller-supplied values are never interpolated into the SQL text;
        # they travel separately as bound parameters.
        if query_word is not None:
            conditions.append("query_word = %s")
            params.append(query_word)
        if source_type is not None:
            conditions.append("source_type = %s")
            params.append(source_type)
        if source_channel is not None:
            conditions.append("source_channel = %s")
            params.append(source_channel)

        where_clause = " AND ".join(conditions)
        return where_clause, tuple(params)

    def process_all_records(self, query_word: Optional[str],
                            source_type: Optional[str],
                            source_channel: Optional[str]):
        """Process every matching record once and store each result.

        Args:
            query_word: Optional filter on the ``query_word`` column.
            source_type: Optional filter on the ``source_type`` column.
            source_channel: Optional filter on the ``source_channel`` column.

        A failure on a single record is printed and does not stop the run.
        A failure outside the per-record step (for example while querying) is
        printed and ends the run. The summary line is printed in either case.
        """
        total_processed = 0
        total_success = 0

        try:
            # Build the query conditions and parameters.
            where_clause, params = self.build_query_conditions(query_word, source_type, source_channel)
            # Only fixed fragments from build_query_conditions are formatted
            # into the statement; the values are passed separately in `params`.
            sql = f"""
                SELECT id, formatted_content
                FROM knowledge_search_content
                WHERE {where_clause}
            """

            # Fetch the records.
            records = MysqlHelper.get_values(sql, params)
            print(f"获取到 {len(records)} 条记录")

            # Handle each record: row[0] is the id, row[1] the formatted content.
            for row in records:
                total_processed += 1
                try:
                    # Run the processor on the content.
                    result = self.processor.process(row[1], self.system_prompt)

                    # Write the result back to the database.
                    update_sql = """
                        UPDATE knowledge_search_content
                        SET multimodal_recognition = %s
                        WHERE id = %s
                    """
                    MysqlHelper.update_values(update_sql, (result, row[0]))

                    # Pause between records to stay under the API rate limit.
                    time.sleep(1)
                    total_success += 1

                except Exception as e:
                    print(f"处理记录 {row[0]} 失败: {e}")

        except Exception as e:
            print(f"处理过程中发生错误: {e}")
        finally:
            print(f"处理完成!总共处理 {total_processed} 条记录,成功 {total_success} 条")
|
|
|
|
|
|
|
|
|
def main():
|
|
|
"""主函数"""
|
|
|
- # 创建命令行参数解析器
|
|
|
parser = argparse.ArgumentParser(description='内容识别脚本')
|
|
|
parser.add_argument('--query_word', default=None, help='query词')
|
|
|
parser.add_argument('--source_type', default=None, help='数据源类型')
|
|
@@ -102,17 +97,14 @@ def main():
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
try:
|
|
|
- # 创建内容识别器实例
|
|
|
handler = Handler()
|
|
|
-
|
|
|
handler.process_all_records(
|
|
|
query_word=args.query_word,
|
|
|
source_type=args.source_type,
|
|
|
source_channel=args.source_channel
|
|
|
)
|
|
|
-
|
|
|
except Exception as e:
|
|
|
- print(f"程序执行失败: {e}")
|
|
|
+ print(f"程序执行失败: {str(e)}")
|
|
|
sys.exit(1)
|
|
|
|
|
|
|