audio_identifier.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 音频识别脚本
  5. 主要功能:将音频转文字(ASR),参考视频识别模块的结构实现
  6. 支持从 formatted_content 中提取音频 URL,下载后上传至 Gemini 进行转写。
  7. """
  8. import os
  9. import json
  10. import time
  11. import sys
  12. import uuid
  13. import requests
  14. from typing import Dict, Any, List, Optional
  15. from dotenv import load_dotenv
  16. # 导入自定义模块
  17. sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  18. from utils.logging_config import get_logger
  19. # 创建 logger
  20. logger = get_logger('AudioIdentifier')
  21. # 导入Google Generative AI
  22. import google.generativeai as genai
  23. from google.generativeai.types import HarmCategory, HarmBlockThreshold
  24. # 缓存目录配置
  25. CACHE_DIR = os.path.join(os.path.dirname(__file__), 'cache')
  26. # 缓存文件最大保留时间(秒)
  27. CACHE_MAX_AGE = 3600 # 1小时
  28. class AudioIdentifier:
  29. def __init__(self):
  30. # 加载环境变量
  31. load_dotenv()
  32. # 延迟配置Gemini,在真正使用时再设置
  33. self._configured = False
  34. self.model = None
  35. self.api_key = None
  36. # 初始化缓存清理时间
  37. self.last_cache_cleanup = time.time()
  38. # 系统提示词:仅做语音转文字
  39. self.system_prompt = (
  40. "你是一个专业的音频转写助手。请严格将音频中的语音内容转换为文字,不要添加任何分析、解释或评论。"
  41. )
  42. def _ensure_configured(self):
  43. """确保Gemini已配置"""
  44. if not self._configured:
  45. # 与图片模块保持一致读取 GEMINI_API_KEY_1,若无则回退 GEMINI_API_KEY
  46. self.api_key = os.getenv('GEMINI_API_KEY') or os.getenv('GEMINI_API_KEY_1')
  47. if not self.api_key:
  48. raise ValueError('请在环境变量中设置 GEMINI_API_KEY 或 GEMINI_API_KEY_1')
  49. genai.configure(api_key=self.api_key)
  50. # 使用通用多模态模型进行音频理解
  51. self.model = genai.GenerativeModel(
  52. model_name='gemini-2.5-flash',
  53. generation_config=genai.GenerationConfig(
  54. response_mime_type='text/plain',
  55. temperature=0.2,
  56. max_output_tokens=409600
  57. ),
  58. safety_settings={
  59. HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
  60. HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
  61. }
  62. )
  63. self._configured = True
  64. def cleanup_cache(self):
  65. """清理过期的缓存文件"""
  66. try:
  67. current_time = time.time()
  68. if current_time - self.last_cache_cleanup < 3600:
  69. return
  70. if not os.path.exists(CACHE_DIR):
  71. return
  72. cleaned_count = 0
  73. for filename in os.listdir(CACHE_DIR):
  74. file_path = os.path.join(CACHE_DIR, filename)
  75. if os.path.isfile(file_path):
  76. file_age = current_time - os.path.getmtime(file_path)
  77. if file_age > CACHE_MAX_AGE:
  78. try:
  79. os.remove(file_path)
  80. cleaned_count += 1
  81. except Exception as e:
  82. logger.warning(f'清理缓存文件失败: {file_path}, 错误: {e}')
  83. if cleaned_count > 0:
  84. logger.info(f'已清理 {cleaned_count} 个过期缓存文件')
  85. self.last_cache_cleanup = current_time
  86. except Exception as e:
  87. logger.error(f'清理缓存失败: {e}')
  88. def download_audio(self, audio_url: str) -> Optional[str]:
  89. """下载音频到本地缓存并返回路径"""
  90. # 猜测常见音频类型,后续统一按 mp3 保存
  91. file_path = os.path.join(CACHE_DIR, f'{str(uuid.uuid4())}.mp3')
  92. try:
  93. os.makedirs(CACHE_DIR, exist_ok=True)
  94. except Exception as e:
  95. logger.error(f'创建缓存目录失败: {e}')
  96. return None
  97. try:
  98. for attempt in range(3):
  99. try:
  100. response = requests.get(url=audio_url, timeout=60)
  101. if response.status_code == 200:
  102. try:
  103. with open(file_path, 'wb') as f:
  104. f.write(response.content)
  105. return file_path
  106. except Exception as e:
  107. logger.error(f'音频保存失败: {e}')
  108. if os.path.exists(file_path):
  109. try:
  110. os.remove(file_path)
  111. except Exception:
  112. pass
  113. return None
  114. else:
  115. logger.warning(f'音频下载失败,状态码: {response.status_code}')
  116. if attempt == 2:
  117. return None
  118. except Exception as e:
  119. logger.warning(f'下载尝试 {attempt + 1} 失败: {e}')
  120. if attempt < 2:
  121. time.sleep(1)
  122. continue
  123. return None
  124. except Exception as e:
  125. logger.error(f'下载过程异常: {e}')
  126. return None
  127. return None
  128. def upload_audio_to_gemini(self, audio_path: str) -> Optional[Any]:
  129. """上传音频至 Gemini,返回文件对象"""
  130. self._ensure_configured()
  131. max_retries = 3
  132. retry_delay = 5
  133. for attempt in range(max_retries):
  134. try:
  135. if not os.path.exists(audio_path):
  136. logger.error('错误: 文件不存在')
  137. return None
  138. file_size = os.path.getsize(audio_path)
  139. if file_size == 0:
  140. logger.error('错误: 文件大小为0')
  141. return None
  142. try:
  143. with open(audio_path, 'rb') as f:
  144. f.read(1024)
  145. except Exception as e:
  146. logger.error(f'错误: 文件无法读取 - {e}')
  147. return None
  148. try:
  149. # 使用常见音频 MIME 类型。若后续需要可根据扩展名判断
  150. audio_file = genai.upload_file(path=audio_path, mime_type='audio/mpeg')
  151. except Exception as e:
  152. msg = str(e)
  153. logger.error(f'错误: 文件上传请求失败 - {msg}')
  154. if any(k in msg.lower() for k in ['broken pipe', 'connection', 'timeout', 'network']):
  155. if attempt < max_retries - 1:
  156. time.sleep(retry_delay)
  157. retry_delay *= 2
  158. continue
  159. return None
  160. return None
  161. # 等待处理
  162. max_wait_time = 120
  163. waited = 0
  164. while getattr(audio_file, 'state', None) and getattr(audio_file.state, 'name', '') == 'PROCESSING' and waited < max_wait_time:
  165. time.sleep(2)
  166. waited += 2
  167. try:
  168. audio_file = genai.get_file(name=audio_file.name)
  169. if audio_file.state.name in ['FAILED', 'ERROR', 'INVALID']:
  170. if attempt < max_retries - 1:
  171. time.sleep(retry_delay)
  172. retry_delay *= 2
  173. break
  174. return None
  175. except Exception as e:
  176. logger.warning(f'获取文件状态失败: {e}')
  177. if waited <= 60:
  178. return None
  179. continue
  180. if getattr(audio_file, 'state', None) and audio_file.state.name == 'ACTIVE':
  181. logger.info(f'音频上传成功: {audio_file.name}')
  182. return audio_file
  183. else:
  184. if attempt < max_retries - 1:
  185. time.sleep(retry_delay)
  186. retry_delay *= 2
  187. continue
  188. return None
  189. except Exception as e:
  190. msg = str(e)
  191. if any(k in msg.lower() for k in ['broken pipe', 'connection', 'timeout', 'network']):
  192. if attempt < max_retries - 1:
  193. time.sleep(retry_delay)
  194. retry_delay *= 2
  195. continue
  196. return None
  197. logger.error(f'音频上传异常: {msg}')
  198. return None
  199. return None
  200. def extract_audio_urls(self, formatted_content: Dict[str, Any]) -> List[str]:
  201. """从 formatted_content 中提取音频 URL 列表
  202. 兼容以下结构:
  203. - audio_url_list: [{"audio_url": "..."}, ...]
  204. - voice_data: {"url": "..."} 或 [{"url": "..."}, ...]
  205. - bgm_data: {"url": "..."}
  206. """
  207. urls: List[str] = []
  208. # audio_url_list
  209. for item in (formatted_content.get('audio_url_list') or []):
  210. if isinstance(item, dict) and item.get('audio_url'):
  211. urls.append(item['audio_url'])
  212. elif isinstance(item, str):
  213. urls.append(item)
  214. # voice_data
  215. voice_data = formatted_content.get('voice_data')
  216. if isinstance(voice_data, dict) and voice_data.get('url'):
  217. urls.append(voice_data['url'])
  218. elif isinstance(voice_data, list):
  219. for v in voice_data:
  220. if isinstance(v, dict) and v.get('url'):
  221. urls.append(v['url'])
  222. elif isinstance(v, str):
  223. urls.append(v)
  224. # bgm_data
  225. bgm_data = formatted_content.get('bgm_data')
  226. if isinstance(bgm_data, dict) and bgm_data.get('url'):
  227. urls.append(bgm_data['url'])
  228. # 去重并保持顺序
  229. seen = set()
  230. deduped: List[str] = []
  231. for u in urls:
  232. if u and u not in seen:
  233. seen.add(u)
  234. deduped.append(u)
  235. return deduped
  236. def analyze_audios_with_gemini(self, audio_urls: List[str]) -> List[Dict[str, Any]]:
  237. """将音频上传到 Gemini 并进行转写,返回按输入顺序的结果列表"""
  238. if not audio_urls:
  239. return []
  240. results: List[Dict[str, Any]] = [{} for _ in range(len(audio_urls))]
  241. def process_one(idx_and_url) -> Dict[str, Any]:
  242. idx, url = idx_and_url
  243. audio_file = None
  244. local_path: Optional[str] = None
  245. try:
  246. self._ensure_configured()
  247. logger.info(f"配置Gemini: {self.api_key}")
  248. # 1. 下载
  249. local_path = self.download_audio(url)
  250. if not local_path:
  251. return {"url": url, "asr_content": "音频下载失败"}
  252. # 2. 上传
  253. audio_file = self.upload_audio_to_gemini(local_path)
  254. # 清理本地文件
  255. try:
  256. if local_path and os.path.exists(local_path):
  257. os.remove(local_path)
  258. except Exception:
  259. pass
  260. if not audio_file:
  261. return {"url": url, "asr_content": "音频上传失败"}
  262. # 3. 生成
  263. response = self.model.generate_content(
  264. contents=[self.system_prompt, audio_file],
  265. request_options={'timeout': 500}
  266. )
  267. # 尝试读取文本
  268. try:
  269. text_out = ''
  270. # 优先从 candidates 结构提取,避免某些情况下 .text 不可用
  271. candidates = getattr(response, 'candidates', None)
  272. if candidates and len(candidates) > 0:
  273. first = candidates[0]
  274. content = getattr(first, 'content', None)
  275. parts = getattr(content, 'parts', None) if content else None
  276. if parts and len(parts) > 0:
  277. part0 = parts[0]
  278. text_out = getattr(part0, 'text', None) if hasattr(part0, 'text') else part0.get('text') if isinstance(part0, dict) else ''
  279. if not text_out and hasattr(response, 'text') and isinstance(response.text, str):
  280. text_out = response.text
  281. text_out = (text_out or '').strip()
  282. if not text_out:
  283. return {"url": url, "asr_content": "ASR分析失败:无内容"}
  284. return {"url": url, "asr_content": text_out}
  285. except Exception as e:
  286. return {"url": url, "asr_content": f"ASR分析失败:{str(e)}"}
  287. except Exception as e:
  288. return {"url": url, "asr_content": f"处理失败: {str(e)}"}
  289. finally:
  290. # 4. 清理远端文件
  291. if audio_file and hasattr(audio_file, 'name'):
  292. try:
  293. genai.delete_file(name=audio_file.name)
  294. except Exception:
  295. pass
  296. # 顺序处理,保持简单稳妥
  297. for idx, url in enumerate(audio_urls):
  298. result = process_one((idx, url))
  299. results[idx] = result
  300. return results
  301. def process_audios(self, formatted_content: Dict[str, Any]) -> List[Dict[str, Any]]:
  302. """处理音频识别的主函数,返回 [{url, asr_content}]"""
  303. try:
  304. audio_urls = self.extract_audio_urls(formatted_content)
  305. if not audio_urls:
  306. return []
  307. return self.analyze_audios_with_gemini(audio_urls)
  308. finally:
  309. # 触发一次缓存清理(若到时间)
  310. self.cleanup_cache()
  311. def main():
  312. """测试函数"""
  313. test_content = {
  314. "audio_url_list": [
  315. {"audio_url": "http://rescdn.yishihui.com/pipeline/audio/09417cf6-60ec-4b62-8ee1-06f9268b13b1.mp3"}
  316. ]
  317. }
  318. identifier = AudioIdentifier()
  319. result = identifier.process_audios(test_content)
  320. print(json.dumps(result, ensure_ascii=False, indent=2))
  321. if __name__ == '__main__':
  322. main()