|
@@ -0,0 +1,361 @@
|
|
|
|
+#!/usr/bin/env python3
|
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
|
+"""
|
|
|
|
+音频识别脚本
|
|
|
|
+主要功能:将音频转文字(ASR),参考视频识别模块的结构实现
|
|
|
|
+支持从 formatted_content 中提取音频 URL,下载后上传至 Gemini 进行转写。
|
|
|
|
+"""
|
|
|
|
+
|
|
|
|
+import os
|
|
|
|
+import json
|
|
|
|
+import time
|
|
|
|
+import sys
|
|
|
|
+import uuid
|
|
|
|
+import requests
|
|
|
|
+from typing import Dict, Any, List, Optional
|
|
|
|
+from dotenv import load_dotenv
|
|
|
|
+
|
|
|
|
+# 导入自定义模块
|
|
|
|
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
|
|
+from utils.logging_config import get_logger
|
|
|
|
+
|
|
|
|
+# 创建 logger
|
|
|
|
+logger = get_logger('AudioIdentifier')
|
|
|
|
+
|
|
|
|
+# 导入Google Generative AI
|
|
|
|
+import google.generativeai as genai
|
|
|
|
+from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
|
|
|
+
|
|
|
|
+# 缓存目录配置
|
|
|
|
+CACHE_DIR = os.path.join(os.path.dirname(__file__), 'cache')
|
|
|
|
+# 缓存文件最大保留时间(秒)
|
|
|
|
+CACHE_MAX_AGE = 3600 # 1小时
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class AudioIdentifier:
|
|
|
|
+ def __init__(self):
|
|
|
|
+ # 加载环境变量
|
|
|
|
+ load_dotenv()
|
|
|
|
+
|
|
|
|
+ # 延迟配置Gemini,在真正使用时再设置
|
|
|
|
+ self._configured = False
|
|
|
|
+ self.model = None
|
|
|
|
+ self.api_key = None
|
|
|
|
+
|
|
|
|
+ # 初始化缓存清理时间
|
|
|
|
+ self.last_cache_cleanup = time.time()
|
|
|
|
+
|
|
|
|
+ # 系统提示词:仅做语音转文字
|
|
|
|
+ self.system_prompt = (
|
|
|
|
+ "你是一个专业的音频转写助手。请严格将音频中的语音内容转换为文字,不要添加任何分析、解释或评论。"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def _ensure_configured(self):
|
|
|
|
+ """确保Gemini已配置"""
|
|
|
|
+ if not self._configured:
|
|
|
|
+ # 与图片模块保持一致读取 GEMINI_API_KEY_1,若无则回退 GEMINI_API_KEY
|
|
|
|
+ self.api_key = os.getenv('GEMINI_API_KEY_1') or os.getenv('GEMINI_API_KEY')
|
|
|
|
+ if not self.api_key:
|
|
|
|
+ raise ValueError('请在环境变量中设置 GEMINI_API_KEY_1 或 GEMINI_API_KEY')
|
|
|
|
+ genai.configure(api_key=self.api_key)
|
|
|
|
+ # 使用通用多模态模型进行音频理解
|
|
|
|
+ self.model = genai.GenerativeModel(
|
|
|
|
+ model_name='gemini-2.5-flash',
|
|
|
|
+ generation_config=genai.GenerationConfig(
|
|
|
|
+ response_mime_type='text/plain',
|
|
|
|
+ temperature=0.2,
|
|
|
|
+ max_output_tokens=40960
|
|
|
|
+ ),
|
|
|
|
+ safety_settings={
|
|
|
|
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
|
|
|
+ }
|
|
|
|
+ )
|
|
|
|
+ self._configured = True
|
|
|
|
+
|
|
|
|
+ def cleanup_cache(self):
|
|
|
|
+ """清理过期的缓存文件"""
|
|
|
|
+ try:
|
|
|
|
+ current_time = time.time()
|
|
|
|
+ if current_time - self.last_cache_cleanup < 3600:
|
|
|
|
+ return
|
|
|
|
+ if not os.path.exists(CACHE_DIR):
|
|
|
|
+ return
|
|
|
|
+ cleaned_count = 0
|
|
|
|
+ for filename in os.listdir(CACHE_DIR):
|
|
|
|
+ file_path = os.path.join(CACHE_DIR, filename)
|
|
|
|
+ if os.path.isfile(file_path):
|
|
|
|
+ file_age = current_time - os.path.getmtime(file_path)
|
|
|
|
+ if file_age > CACHE_MAX_AGE:
|
|
|
|
+ try:
|
|
|
|
+ os.remove(file_path)
|
|
|
|
+ cleaned_count += 1
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.warning(f'清理缓存文件失败: {file_path}, 错误: {e}')
|
|
|
|
+ if cleaned_count > 0:
|
|
|
|
+ logger.info(f'已清理 {cleaned_count} 个过期缓存文件')
|
|
|
|
+ self.last_cache_cleanup = current_time
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f'清理缓存失败: {e}')
|
|
|
|
+
|
|
|
|
+ def download_audio(self, audio_url: str) -> Optional[str]:
|
|
|
|
+ """下载音频到本地缓存并返回路径"""
|
|
|
|
+ # 猜测常见音频类型,后续统一按 mp3 保存
|
|
|
|
+ file_path = os.path.join(CACHE_DIR, f'{str(uuid.uuid4())}.mp3')
|
|
|
|
+ try:
|
|
|
|
+ os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f'创建缓存目录失败: {e}')
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ for attempt in range(3):
|
|
|
|
+ try:
|
|
|
|
+ response = requests.get(url=audio_url, timeout=60)
|
|
|
|
+ if response.status_code == 200:
|
|
|
|
+ try:
|
|
|
|
+ with open(file_path, 'wb') as f:
|
|
|
|
+ f.write(response.content)
|
|
|
|
+ return file_path
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f'音频保存失败: {e}')
|
|
|
|
+ if os.path.exists(file_path):
|
|
|
|
+ try:
|
|
|
|
+ os.remove(file_path)
|
|
|
|
+ except Exception:
|
|
|
|
+ pass
|
|
|
|
+ return None
|
|
|
|
+ else:
|
|
|
|
+ logger.warning(f'音频下载失败,状态码: {response.status_code}')
|
|
|
|
+ if attempt == 2:
|
|
|
|
+ return None
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.warning(f'下载尝试 {attempt + 1} 失败: {e}')
|
|
|
|
+ if attempt < 2:
|
|
|
|
+ time.sleep(1)
|
|
|
|
+ continue
|
|
|
|
+ return None
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f'下载过程异常: {e}')
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ def upload_audio_to_gemini(self, audio_path: str) -> Optional[Any]:
|
|
|
|
+ """上传音频至 Gemini,返回文件对象"""
|
|
|
|
+ self._ensure_configured()
|
|
|
|
+ max_retries = 3
|
|
|
|
+ retry_delay = 5
|
|
|
|
+ for attempt in range(max_retries):
|
|
|
|
+ try:
|
|
|
|
+ if not os.path.exists(audio_path):
|
|
|
|
+ logger.error('错误: 文件不存在')
|
|
|
|
+ return None
|
|
|
|
+ file_size = os.path.getsize(audio_path)
|
|
|
|
+ if file_size == 0:
|
|
|
|
+ logger.error('错误: 文件大小为0')
|
|
|
|
+ return None
|
|
|
|
+ try:
|
|
|
|
+ with open(audio_path, 'rb') as f:
|
|
|
|
+ f.read(1024)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f'错误: 文件无法读取 - {e}')
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ # 使用常见音频 MIME 类型。若后续需要可根据扩展名判断
|
|
|
|
+ audio_file = genai.upload_file(path=audio_path, mime_type='audio/mpeg')
|
|
|
|
+ except Exception as e:
|
|
|
|
+ msg = str(e)
|
|
|
|
+ logger.error(f'错误: 文件上传请求失败 - {msg}')
|
|
|
|
+ if any(k in msg.lower() for k in ['broken pipe', 'connection', 'timeout', 'network']):
|
|
|
|
+ if attempt < max_retries - 1:
|
|
|
|
+ time.sleep(retry_delay)
|
|
|
|
+ retry_delay *= 2
|
|
|
|
+ continue
|
|
|
|
+ return None
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ # 等待处理
|
|
|
|
+ max_wait_time = 120
|
|
|
|
+ waited = 0
|
|
|
|
+ while getattr(audio_file, 'state', None) and getattr(audio_file.state, 'name', '') == 'PROCESSING' and waited < max_wait_time:
|
|
|
|
+ time.sleep(2)
|
|
|
|
+ waited += 2
|
|
|
|
+ try:
|
|
|
|
+ audio_file = genai.get_file(name=audio_file.name)
|
|
|
|
+ if audio_file.state.name in ['FAILED', 'ERROR', 'INVALID']:
|
|
|
|
+ if attempt < max_retries - 1:
|
|
|
|
+ time.sleep(retry_delay)
|
|
|
|
+ retry_delay *= 2
|
|
|
|
+ break
|
|
|
|
+ return None
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.warning(f'获取文件状态失败: {e}')
|
|
|
|
+ if waited <= 60:
|
|
|
|
+ return None
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ if getattr(audio_file, 'state', None) and audio_file.state.name == 'ACTIVE':
|
|
|
|
+ logger.info(f'音频上传成功: {audio_file.name}')
|
|
|
|
+ return audio_file
|
|
|
|
+ else:
|
|
|
|
+ if attempt < max_retries - 1:
|
|
|
|
+ time.sleep(retry_delay)
|
|
|
|
+ retry_delay *= 2
|
|
|
|
+ continue
|
|
|
|
+ return None
|
|
|
|
+ except Exception as e:
|
|
|
|
+ msg = str(e)
|
|
|
|
+ if any(k in msg.lower() for k in ['broken pipe', 'connection', 'timeout', 'network']):
|
|
|
|
+ if attempt < max_retries - 1:
|
|
|
|
+ time.sleep(retry_delay)
|
|
|
|
+ retry_delay *= 2
|
|
|
|
+ continue
|
|
|
|
+ return None
|
|
|
|
+ logger.error(f'音频上传异常: {msg}')
|
|
|
|
+ return None
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ def extract_audio_urls(self, formatted_content: Dict[str, Any]) -> List[str]:
|
|
|
|
+ """从 formatted_content 中提取音频 URL 列表
|
|
|
|
+ 兼容以下结构:
|
|
|
|
+ - audio_url_list: [{"audio_url": "..."}, ...]
|
|
|
|
+ - voice_data: {"url": "..."} 或 [{"url": "..."}, ...]
|
|
|
|
+ - bgm_data: {"url": "..."}
|
|
|
|
+ """
|
|
|
|
+ urls: List[str] = []
|
|
|
|
+ # audio_url_list
|
|
|
|
+ for item in (formatted_content.get('audio_url_list') or []):
|
|
|
|
+ if isinstance(item, dict) and item.get('audio_url'):
|
|
|
|
+ urls.append(item['audio_url'])
|
|
|
|
+ elif isinstance(item, str):
|
|
|
|
+ urls.append(item)
|
|
|
|
+ # voice_data
|
|
|
|
+ voice_data = formatted_content.get('voice_data')
|
|
|
|
+ if isinstance(voice_data, dict) and voice_data.get('url'):
|
|
|
|
+ urls.append(voice_data['url'])
|
|
|
|
+ elif isinstance(voice_data, list):
|
|
|
|
+ for v in voice_data:
|
|
|
|
+ if isinstance(v, dict) and v.get('url'):
|
|
|
|
+ urls.append(v['url'])
|
|
|
|
+ elif isinstance(v, str):
|
|
|
|
+ urls.append(v)
|
|
|
|
+ # bgm_data
|
|
|
|
+ bgm_data = formatted_content.get('bgm_data')
|
|
|
|
+ if isinstance(bgm_data, dict) and bgm_data.get('url'):
|
|
|
|
+ urls.append(bgm_data['url'])
|
|
|
|
+
|
|
|
|
+ # 去重并保持顺序
|
|
|
|
+ seen = set()
|
|
|
|
+ deduped: List[str] = []
|
|
|
|
+ for u in urls:
|
|
|
|
+ if u and u not in seen:
|
|
|
|
+ seen.add(u)
|
|
|
|
+ deduped.append(u)
|
|
|
|
+ return deduped
|
|
|
|
+
|
|
|
|
+ def analyze_audios_with_gemini(self, audio_urls: List[str]) -> List[Dict[str, Any]]:
|
|
|
|
+ """将音频上传到 Gemini 并进行转写,返回按输入顺序的结果列表"""
|
|
|
|
+ if not audio_urls:
|
|
|
|
+ return []
|
|
|
|
+
|
|
|
|
+ results: List[Dict[str, Any]] = [{} for _ in range(len(audio_urls))]
|
|
|
|
+
|
|
|
|
+ def process_one(idx_and_url) -> Dict[str, Any]:
|
|
|
|
+ idx, url = idx_and_url
|
|
|
|
+ audio_file = None
|
|
|
|
+ local_path: Optional[str] = None
|
|
|
|
+ try:
|
|
|
|
+ self._ensure_configured()
|
|
|
|
+ logger.info(f"配置Gemini: {self.api_key}")
|
|
|
|
+
|
|
|
|
+ # 1. 下载
|
|
|
|
+ local_path = self.download_audio(url)
|
|
|
|
+ if not local_path:
|
|
|
|
+ return {"url": url, "asr_content": "音频下载失败"}
|
|
|
|
+
|
|
|
|
+ # 2. 上传
|
|
|
|
+ audio_file = self.upload_audio_to_gemini(local_path)
|
|
|
|
+
|
|
|
|
+ # 清理本地文件
|
|
|
|
+ try:
|
|
|
|
+ if local_path and os.path.exists(local_path):
|
|
|
|
+ os.remove(local_path)
|
|
|
|
+ except Exception:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ if not audio_file:
|
|
|
|
+ return {"url": url, "asr_content": "音频上传失败"}
|
|
|
|
+
|
|
|
|
+ # 3. 生成
|
|
|
|
+ response = self.model.generate_content(
|
|
|
|
+ contents=[self.system_prompt, audio_file],
|
|
|
|
+ request_options={'timeout': 500}
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # 尝试读取文本
|
|
|
|
+ try:
|
|
|
|
+ text_out = ''
|
|
|
|
+ # 优先从 candidates 结构提取,避免某些情况下 .text 不可用
|
|
|
|
+ candidates = getattr(response, 'candidates', None)
|
|
|
|
+ if candidates and len(candidates) > 0:
|
|
|
|
+ first = candidates[0]
|
|
|
|
+ content = getattr(first, 'content', None)
|
|
|
|
+ parts = getattr(content, 'parts', None) if content else None
|
|
|
|
+ if parts and len(parts) > 0:
|
|
|
|
+ part0 = parts[0]
|
|
|
|
+ text_out = getattr(part0, 'text', None) if hasattr(part0, 'text') else part0.get('text') if isinstance(part0, dict) else ''
|
|
|
|
+ if not text_out and hasattr(response, 'text') and isinstance(response.text, str):
|
|
|
|
+ text_out = response.text
|
|
|
|
+ text_out = (text_out or '').strip()
|
|
|
|
+ if not text_out:
|
|
|
|
+ return {"url": url, "asr_content": "ASR分析失败:无内容"}
|
|
|
|
+ return {"url": url, "asr_content": text_out}
|
|
|
|
+ except Exception as e:
|
|
|
|
+ return {"url": url, "asr_content": f"ASR分析失败:{str(e)}"}
|
|
|
|
+ except Exception as e:
|
|
|
|
+ return {"url": url, "asr_content": f"处理失败: {str(e)}"}
|
|
|
|
+ finally:
|
|
|
|
+ # 4. 清理远端文件
|
|
|
|
+ if audio_file and hasattr(audio_file, 'name'):
|
|
|
|
+ try:
|
|
|
|
+ genai.delete_file(name=audio_file.name)
|
|
|
|
+ except Exception:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ # 顺序处理,保持简单稳妥
|
|
|
|
+ for idx, url in enumerate(audio_urls):
|
|
|
|
+ result = process_one((idx, url))
|
|
|
|
+ results[idx] = result
|
|
|
|
+
|
|
|
|
+ return results
|
|
|
|
+
|
|
|
|
+ def process_audios(self, formatted_content: Dict[str, Any]) -> List[Dict[str, Any]]:
|
|
|
|
+ """处理音频识别的主函数,返回 [{url, asr_content}]"""
|
|
|
|
+ try:
|
|
|
|
+ audio_urls = self.extract_audio_urls(formatted_content)
|
|
|
|
+ if not audio_urls:
|
|
|
|
+ return []
|
|
|
|
+ return self.analyze_audios_with_gemini(audio_urls)
|
|
|
|
+ finally:
|
|
|
|
+ # 触发一次缓存清理(若到时间)
|
|
|
|
+ self.cleanup_cache()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def main():
|
|
|
|
+ """测试函数"""
|
|
|
|
+ test_content = {
|
|
|
|
+ "audio_url_list": [
|
|
|
|
+ {"audio_url": "http://rescdn.yishihui.com/pipeline/audio/09417cf6-60ec-4b62-8ee1-06f9268b13b1.mp3"}
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
|
|
+ identifier = AudioIdentifier()
|
|
|
|
+ result = identifier.process_audios(test_content)
|
|
|
|
+ print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
+ main()
|
|
|
|
+
|
|
|
|
+
|