gemini.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import os
  2. import time
  3. import json
  4. import re
  5. from google import genai
  6. from google.genai import types
  7. import config
  8. # --- 正确的初始化流程 ---
  9. client = genai.Client(api_key=config.GEMINI_API_KEY)
  10. # 系统提示词和COT配置
  11. DEFAULT_SYSTEM_PROMPT = """
  12. """
  13. SYSTEM_PROMPT_FILE="system_prompt/v37"
  14. def load_system_prompt(prompt_file_path: str) -> str:
  15. """
  16. 从指定文件加载系统提示词
  17. :param prompt_file_path: 系统提示词文件路径
  18. :return: 系统提示词内容
  19. """
  20. try:
  21. with open(prompt_file_path, 'r', encoding='utf-8') as f:
  22. system_prompt = f.read().strip()
  23. print(f"成功从 {prompt_file_path} 加载系统提示词")
  24. return system_prompt
  25. except Exception as e:
  26. print(f"读取系统提示词文件 {prompt_file_path} 出错: {str(e)}")
  27. return DEFAULT_SYSTEM_PROMPT
  28. SYSTEM_PROMPT = load_system_prompt(SYSTEM_PROMPT_FILE)
  29. # 再次提醒:SYSTEM_PROMPT 的内容必须与您期望的输入/输出格式严格匹配。
  30. # 它应该明确说明模型将收到的是 "目标对话\n[目标对话JSON字符串]\n上下文对话\n[上下文JSON数组字符串]" 这种格式,
  31. # 并期望输出为您提供的 { "对话整体解构": {...}, "详细解构": [...] } JSON 对象结构。
  32. def process_files_sequentially(input_dir: str, output_dir: str, num_context_files: int = 4, delay_seconds: float = 2.0):
  33. """
  34. 逐个处理文件夹中的文本文件,每个目标文件带上下文
  35. :param input_dir: 输入文件夹路径
  36. :param output_dir: 输出文件夹路径
  37. :param num_context_files: 每个目标文件附带的上下文文件数量
  38. :param delay_seconds: 每个文件处理之间的延迟(秒)
  39. """
  40. # 确保输出目录存在
  41. os.makedirs(output_dir, exist_ok=True)
  42. # 获取所有txt文件
  43. # 注意: f.endswith('') 会匹配所有文件,如果只想处理txt,应改为 f.endswith('.txt')
  44. input_files_names = sorted([f for f in os.listdir(input_dir) if f.endswith('')])
  45. total_files = len(input_files_names)
  46. print(f"找到 {total_files} 个文件。将逐个处理(每个目标文件附带 {num_context_files} 个上下文文件)")
  47. # 预先读取所有文件内容,以便高效构建上下文
  48. all_file_contents = []
  49. for filename in input_files_names:
  50. input_path = os.path.join(input_dir, filename)
  51. try:
  52. with open(input_path, 'r', encoding='utf-8') as f:
  53. all_file_contents.append(f.read().strip())
  54. except Exception as e:
  55. print(f" ✕ 预读取文件 {filename} 出错: {str(e)}")
  56. all_file_contents.append(f"错误: 无法读取文件 '{filename}' - {str(e)}")
  57. # 逐个处理文件
  58. # i 现在直接代表当前处理文件的索引
  59. for i in range(1):
  60. # for i in range(total_files):
  61. target_filename = input_files_names[i]
  62. target_content = all_file_contents[i]
  63. # 收集上下文文件内容
  64. context_contents = []
  65. for k in range(1, num_context_files + 1):
  66. context_idx = i + k
  67. if context_idx < total_files:
  68. context_contents.append(all_file_contents[context_idx])
  69. # 如果没有足够的上下文文件,就按实际数量提供,不会填充空字符串
  70. print(f"\n处理文件 {i+1}/{total_files}: '{target_filename}' (目标 + {len(context_contents)} 个上下文文件)")
  71. output_path = os.path.join(output_dir, f"{os.path.splitext(target_filename)[0]}.json")
  72. target_content_json_str = json.dumps(target_content, ensure_ascii=False)
  73. context_contents_json_str = json.dumps(context_contents, ensure_ascii=False)
  74. # 构建符合 SYSTEM_PROMPT 期望的单个文本字符串,包含Markdown标题和JSON内容
  75. combined_input_text = (
  76. f"## 目标对话\n"
  77. f"{target_content_json_str}\n" # 使用json.dumps后的字符串
  78. f"## 上下文对话\n"
  79. f"{context_contents_json_str}" # 使用json.dumps后的字符串
  80. )
  81. try:
  82. contents = [
  83. {"text": combined_input_text}
  84. ]
  85. # 调用Gemini模型处理单个目标文件
  86. response = client.models.generate_content(
  87. model="gemini-2.5-pro", # 或者您需要的其他模型
  88. config=types.GenerateContentConfig(
  89. system_instruction=SYSTEM_PROMPT),
  90. contents=contents
  91. )
  92. result = response.text
  93. # 移除Markdown代码块的围栏
  94. result = re.sub(r'^\s*```json\s*|\s*```\s*$', '', result, flags=re.MULTILINE)
  95. result = result.strip() # 去除多余的空行
  96. # 尝试解析JSON响应
  97. try:
  98. # 此时 result 应该就是单个文件的 JSON 结果,即您提供的 { "对话整体解构": {...}, "详细解构": [...] } 结构
  99. dialogue_report = json.loads(result)
  100. print(f" 成功获取并解析API响应 '{target_filename}'")
  101. # 保存处理结果
  102. # dialogue_report 现在是一个字典,可以直接保存
  103. try:
  104. with open(output_path, 'w', encoding='utf-8') as f:
  105. json.dump(dialogue_report, f, indent=2, ensure_ascii=False)
  106. print(f" ✓ 保存: {os.path.basename(output_path)}")
  107. except Exception as e:
  108. error_msg = f"保存错误: {str(e)}"
  109. with open(output_path, 'w', encoding='utf-8') as f:
  110. f.write(error_msg)
  111. print(f" ⚠ 保存失败: {error_msg}")
  112. except json.JSONDecodeError as e:
  113. print(f" ⚠ API返回非JSON格式,尝试提取有效部分... 错误: {e}")
  114. # ****** 重点修改:寻找 '{' 和 '}' 来提取JSON对象 ******
  115. json_start = result.find('{')
  116. json_end = result.rfind('}') + 1 # +1 to include the closing brace
  117. if json_end > json_start > -1: # 检查是否找到了有效的括号对
  118. try:
  119. extracted_report = json.loads(result[json_start:json_end])
  120. print(f" 成功提取JSON数据 for '{target_filename}'")
  121. with open(output_path, 'w', encoding='utf-8') as f:
  122. json.dump(extracted_report, f, indent=2, ensure_ascii=False)
  123. print(f" ✓ 保存 (提取成功): {os.path.basename(output_path)}")
  124. except Exception as extract_e:
  125. error_msg = f"无法提取有效JSON数据,使用原始响应。提取错误: {extract_e}\n原始响应:\n{result}"
  126. with open(output_path, 'w', encoding='utf-8') as f:
  127. f.write(error_msg)
  128. print(f" ⚠ 保存失败 (提取错误): {error_msg}")
  129. else:
  130. error_msg = f"无法定位JSON内容,使用原始响应。\n原始响应:\n{result}"
  131. with open(output_path, 'w', encoding='utf-8') as f:
  132. f.write(error_msg)
  133. print(f" ⚠ 保存失败 (未找到JSON): {error_msg}")
  134. except Exception as e:
  135. error_msg = f"处理 '{target_filename}' 时API错误: {str(e)}"
  136. print(f" ✕ {error_msg}")
  137. # API调用失败,为当前文件生成错误文件
  138. with open(output_path, 'w', encoding='utf-8') as f:
  139. f.write(error_msg)
  140. # 延迟 (在处理完当前文件后,如果不是最后一个文件)
  141. if i < total_files - 1:
  142. print(f"等待 {delay_seconds} 秒...")
  143. time.sleep(delay_seconds)
  144. print("\n所有文件处理完成")