"""Test driver for the script-understanding agents.

What this script does:
1. Locates the newest ``result_*.json`` file in ``<directory>/output``.
2. Extracts the topic description and the post (video) content.
3. Runs the section-division agent followed by the element-extraction agent.
4. Saves the combined result to ``<directory>/output/script_result_<timestamp>.json``.
"""

import argparse
import json
import os
import re
import sys
import tempfile
import time
import traceback
from datetime import datetime
from pathlib import Path

# Make the project root importable before any ``src.*`` import below.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))


def load_env_file(env_path):
    """Load ``KEY=VALUE`` pairs from a .env file into ``os.environ``.

    Blank lines and lines starting with ``#`` are skipped, and lines without
    an ``=`` are ignored. Keys and values are stripped of surrounding
    whitespace; existing environment variables are overwritten.

    Args:
        env_path: ``Path`` of the .env file.

    Returns:
        bool: ``True`` if the file existed and was read, ``False`` otherwise.
    """
    if not env_path.exists():
        return False

    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(env_path, "r", encoding="utf-8") as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if "=" in line:
                key, value = line.split("=", 1)
                os.environ[key.strip()] = value.strip()
    return True


# The environment must be populated before the project modules are imported.
env_path = project_root / ".env"
if load_env_file(env_path):
    print(f"✅ 已加载环境变量从: {env_path}")
    # Sanity check that the API key was picked up.
    # NOTE(review): this echoes the first 10 characters of the key to stdout;
    # consider masking it if the output is ever logged or shared.
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if api_key:
        print(f" GEMINI_API_KEY: {api_key[:10]}...")
else:
    print(f"⚠️ 未找到.env文件: {env_path}")

import requests  # noqa: E402
from urllib3.exceptions import IncompleteRead  # noqa: E402

from src.components.agents.script_section_division_agent import ScriptSectionDivisionAgent  # noqa: E402
from src.components.agents.script_element_extraction_agent import ScriptElementExtractionAgent  # noqa: E402
from src.utils.logger import get_logger  # noqa: E402
from src.utils.llm_invoker import LLMInvoker  # noqa: E402

logger = get_logger(__name__)

# Errors that trigger a download retry. ``RequestException`` is the base class
# of requests' ChunkedEncodingError, ConnectionError and Timeout.
_RETRYABLE_DOWNLOAD_ERRORS = (
    requests.exceptions.RequestException,
    ConnectionError,
    IncompleteRead,
)


def find_latest_result_file(directory):
    """Find the most recently modified ``result_*.json`` in ``<directory>/output``.

    Args:
        directory: Name of the post directory, resolved relative to the
            directory containing this script.

    Returns:
        Path | None: Path of the newest result file, or ``None`` when the
        output directory or a matching file does not exist.
    """
    output_dir = Path(__file__).parent / directory / "output"
    if not output_dir.exists():
        print(f"⚠️ 输出目录不存在: {output_dir}")
        return None

    candidates = list(output_dir.glob("result_*.json"))
    if not candidates:
        print("⚠️ 未找到result_*.json文件")
        return None

    # Newest by modification time.
    return max(candidates, key=lambda path: path.stat().st_mtime)


def find_post_file(directory):
    """Locate the ``视频详情.json`` file inside the given directory.

    Args:
        directory: Name of the video directory, resolved relative to the
            directory containing this script.

    Returns:
        Path | None: Path of the video-detail file, or ``None`` if missing.
    """
    post_file = Path(__file__).parent / directory / "视频详情.json"
    if post_file.exists():
        return post_file
    print(f"⚠️ 视频详情文件不存在: {post_file}")
    return None


def load_result_file(file_path):
    """Read a UTF-8 JSON file and return the parsed data.

    Args:
        file_path: Path of the JSON file.

    Returns:
        The decoded JSON document (a dict for the files used here).
    """
    with open(file_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


def extract_topic_description(result_data):
    """Extract the topic description from the result data.

    Args:
        result_data: Parsed content of a ``result_*.json`` file.

    Returns:
        dict: ``{"主题": ..., "描述": ...}`` taken from the ``"选题理解"`` entry,
        with empty strings for missing fields.
    """
    topic = result_data.get("选题理解", {})
    return {
        "主题": topic.get("主题", ""),
        "描述": topic.get("描述", ""),
    }


def infer_content_category(result_data, post_data):
    """Guess the content category from keywords in the topic and post text.

    The theme, description, post title and post body are concatenated, and the
    first category (in declaration order) with a keyword in that text wins.

    Args:
        result_data: Parsed content of a ``result_*.json`` file.
        post_data: Parsed content of the post JSON file.

    Returns:
        str: The matched category, or ``"创意分享"`` when nothing matches.
    """
    topic = result_data.get("选题理解", {})
    theme = topic.get("主题", "")
    description = topic.get("描述", "")

    content = f"{theme} {description} {post_data.get('title', '')} {post_data.get('body_text', '')}"
    content_lower = content.lower()

    category_keywords = {
        "美妆教程": ["化妆", "眼妆", "底妆", "口红", "粉底"],
        "美甲分享": ["美甲", "指甲", "甲油", "美甲设计"],
        "美食教程": ["食谱", "做菜", "烹饪", "美食", "制作"],
        "穿搭分享": ["穿搭", "搭配", "outfit", "服装", "衣服"],
        "旅行vlog": ["旅行", "旅游", "打卡", "游玩", "景点"],
        "健身教程": ["健身", "运动", "锻炼", "瑜伽", "训练"],
        "手工DIY": ["手工", "diy", "制作", "手作"],
        "护肤分享": ["护肤", "面膜", "精华", "皮肤"],
        "摄影分享": ["摄影", "拍照", "相机", "照片"],
    }

    for category, keywords in category_keywords.items():
        if any(kw in content_lower or kw in content for kw in keywords):
            return category

    # Generic fallback when no keyword matched.
    return "创意分享"


def extract_post_content(post_data):
    """Extract the title, cleaned body text and video URL from the post data.

    Hashtag markers of the form ``#xxx[话题]#`` or ``#xxx#`` are removed from
    the body, and runs of whitespace are collapsed to a single space.

    Args:
        post_data: Parsed content of ``视频详情.json``.

    Returns:
        tuple: ``(text_data, video_url)`` where ``text_data`` is
        ``{"title": ..., "body": ...}``.
    """
    title = post_data.get("title", "")
    # A null body in the JSON must not crash the regex below.
    body = post_data.get("body_text") or ""

    # '#', any non-'#' characters, an optional "[话题]" marker, closing '#'.
    body_cleaned = re.sub(r'#[^#]+?(?:\[话题\])?\s*#', '', body)
    body_cleaned = re.sub(r'\s+', ' ', body_cleaned).strip()

    text_data = {
        "title": title,
        "body": body_cleaned,
    }
    video_url = post_data.get("video", "")
    return text_data, video_url


def _remove_quietly(path):
    """Best-effort removal of *path*; filesystem errors are ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


def _download_video(video_url, dest_path, max_retries=3):
    """Stream *video_url* into *dest_path*, retrying on network errors.

    Network-level failures are retried up to ``max_retries`` attempts with a
    growing delay (2s, 4s, ...). Any other error aborts immediately. A partial
    file is removed after every failed attempt.

    Raises:
        The last network error once all attempts are used up, ``ValueError``
        when the downloaded file is empty, or any other unexpected error.
    """
    for attempt in range(1, max_retries + 1):
        if attempt > 1:
            print(f"🔄 重试下载视频 (第 {attempt - 1}/{max_retries-1} 次)...")
        try:
            # Context managers ensure the connection is released even when
            # the transfer is interrupted part-way.
            with requests.Session() as session:
                session.headers.update({
                    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
                })
                with session.get(
                    video_url,
                    timeout=(30, 120),  # (connect timeout, read timeout)
                    stream=True,
                ) as response:
                    response.raise_for_status()

                    parent_dir = os.path.dirname(dest_path)
                    if parent_dir:
                        os.makedirs(parent_dir, exist_ok=True)

                    with open(dest_path, "wb") as out_file:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                out_file.write(chunk)

            file_size = os.path.getsize(dest_path)
            if file_size == 0:
                raise ValueError("下载的文件大小为0")

            print(f"✅ 视频下载完成,大小: {file_size / 1024 / 1024:.2f} MB")
            return
        except _RETRYABLE_DOWNLOAD_ERRORS as e:
            _remove_quietly(dest_path)
            if attempt >= max_retries:
                print(f"❌ 下载失败,已重试 {max_retries} 次")
                raise
            wait_time = attempt * 2  # back-off: 2s, 4s, ...
            print(f"⚠️ 下载失败 (尝试 {attempt}/{max_retries}): {e}")
            print(f" 等待 {wait_time} 秒后重试...")
            time.sleep(wait_time)
        except Exception:
            # Non-network errors are not retried.
            _remove_quietly(dest_path)
            raise


def download_and_upload_video(video_url: str, directory: str):
    """Obtain the video (local copy or download) and upload it to Gemini.

    If ``<directory>/<directory>.mp4`` exists next to this script it is used
    directly; otherwise the video is downloaded to a temporary file, which is
    removed again after the upload attempt, whether or not it succeeded.

    Args:
        video_url: URL of the video.
        directory: Directory name, used to look for a local ``.mp4`` file.

    Returns:
        The file object returned by ``LLMInvoker.upload_video_to_gemini``, or
        ``None`` when no URL was given or the download/upload failed.
    """
    if not video_url:
        print("⚠️ 未提供视频URL,跳过上传")
        return None

    try:
        # 1. Prefer an existing local mp4 next to this script.
        examples_dir = Path(__file__).parent
        local_video_path = examples_dir / directory / f"{directory}.mp4"
        temp_file_path = None

        if local_video_path.exists() and local_video_path.is_file():
            print(f"✅ 在examples目录下找到现有文件: {local_video_path.name}")
            video_file_path = str(local_video_path)
        else:
            # 2. Otherwise download into a temporary file.
            print(f"📥 开始下载视频: {video_url}")
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
                temp_file_path = temp_file.name
            _download_video(video_url, temp_file_path)
            video_file_path = temp_file_path

        # 3. Upload; 4. always clean up the temporary file afterwards.
        print("📤 上传视频到Gemini...")
        try:
            video_file = LLMInvoker.upload_video_to_gemini(video_file_path)
        finally:
            if temp_file_path is not None:
                try:
                    os.remove(temp_file_path)
                    print("✅ 临时文件已删除")
                except OSError as e:
                    print(f"⚠️ 删除临时文件失败: {e}")

        if not video_file:
            print("❌ 视频上传到Gemini失败")
            return None

        # 5. Work out a display name for logging purposes.
        file_name = None
        if hasattr(video_file, 'name'):
            file_name = video_file.name
        elif hasattr(video_file, 'uri'):
            file_uri = video_file.uri
            if "/files/" in file_uri:
                file_name = file_uri.split("/files/")[-1]

        print("✅ 视频上传成功")
        if file_name:
            print(f" 文件名称: {file_name}")

        return video_file
    except Exception as e:
        print(f"❌ 视频下载/上传失败: {e}")
        traceback.print_exc()
        return None


def _count_items(items_list):
    """Count every node in a tree of dicts linked through the '子项' key."""
    count = len(items_list)
    for item in items_list:
        if item.get('子项'):
            count += _count_items(item['子项'])
    return count


def _count_leaf_sections(sections_list):
    """Count only the leaf nodes (sections without '子项') of a section tree."""
    count = 0
    for section in sections_list:
        if section.get('子项'):
            count += _count_leaf_sections(section['子项'])
        else:
            count += 1
    return count


def _print_sections(sections_list, indent=0):
    """Print the section tree, one numbered description per line."""
    prefix = " " + " " * indent
    for idx, section in enumerate(sections_list, 1):
        print(f"{prefix}{idx}. {section.get('描述', 'N/A')}")
        if section.get('子项'):
            _print_sections(section['子项'], indent + 1)


def _print_elements(elements_list):
    """Print each element's type, name and classification path (flat list)."""
    for element in elements_list:
        name = element.get('名称', 'N/A')
        elem_type = element.get('类型', 'N/A')
        classification = element.get('分类', {})
        if classification:
            class_path = " > ".join(v for v in classification.values() if v)
            print(f" - [{elem_type}] {name} ({class_path})")
        else:
            print(f" - [{elem_type}] {name}")


def main():
    """Run the full pipeline for the directory given on the command line.

    Each step prints its progress; on the first failing step an error is
    printed and the function returns without running the later steps. The
    video upload is the exception: a failure there is reported and the run
    continues.
    """
    parser = argparse.ArgumentParser(description='运行脚本理解Agent(视频分析版本)')
    parser.add_argument('directory', type=str, help='视频目录名(如"56898272"),目录下需要有"视频详情.json"文件')
    args = parser.parse_args()

    directory = args.directory

    print("=" * 80)
    print(f"开始运行脚本理解Agent - 目录: {directory}")
    print("=" * 80)

    # 1. Locate the video-detail file.
    print("\n[1] 查找视频详情文件...")
    try:
        post_file = find_post_file(directory)
        if not post_file:
            print("❌ 未找到视频详情文件")
            return
        print(f"✅ 找到视频详情文件: {post_file.name}")
        print(f" 文件路径: {post_file}")
    except Exception as e:
        print(f"❌ 查找视频详情文件失败: {e}")
        return

    # 2. Load the video-detail file.
    print("\n[2] 加载视频详情文件...")
    try:
        post_data = load_result_file(post_file)
        print("✅ 成功加载视频详情文件")
    except Exception as e:
        print(f"❌ 加载视频详情文件失败: {e}")
        return

    # 3. Extract the video content.
    print("\n[3] 提取视频内容...")
    try:
        text_data, video_url = extract_post_content(post_data)
        print("✅ 成功提取视频内容")
        print(f" 标题: {text_data.get('title', '无')}")
        print(f" 正文长度: {len(text_data.get('body', ''))}")
        print(f" 视频URL: {'有' if video_url else '无'}")
    except Exception as e:
        print(f"❌ 提取视频内容失败: {e}")
        return

    # 4. Locate the newest result file.
    print("\n[4] 查找最新的result文件...")
    try:
        result_file = find_latest_result_file(directory)
        if not result_file:
            print("❌ 未找到result文件")
            return
        print(f"✅ 找到最新result文件: {result_file.name}")
        print(f" 文件路径: {result_file}")
        print(f" 修改时间: {datetime.fromtimestamp(result_file.stat().st_mtime)}")
    except Exception as e:
        print(f"❌ 查找result文件失败: {e}")
        return

    # 5. Load the result file.
    print("\n[5] 加载result文件...")
    try:
        result_data = load_result_file(result_file)
        print("✅ 成功加载result文件")
    except Exception as e:
        print(f"❌ 加载result文件失败: {e}")
        return

    # 6. Extract the topic description.
    print("\n[6] 提取选题描述...")
    try:
        topic_description = extract_topic_description(result_data)
        print("✅ 成功提取选题描述")
        print(" 选题描述:")
        if topic_description.get("主题"):
            print(f" 主题: {topic_description['主题']}")
        if topic_description.get("描述"):
            print(f" 描述: {topic_description['描述']}")
    except Exception as e:
        print(f"❌ 提取选题描述失败: {e}")
        return

    # 7. Download the video and upload it to Gemini (failure is non-fatal).
    print("\n[7] 下载并上传视频到Gemini...")
    video_file = None
    if video_url:
        try:
            video_file = download_and_upload_video(video_url, directory)
            if not video_file:
                print("⚠️ 视频上传失败,但继续执行(可能影响视频分析功能)")
        except Exception as e:
            print(f"⚠️ 视频上传失败: {e},但继续执行(可能影响视频分析功能)")
            traceback.print_exc()
    else:
        print("⚠️ 未提供视频URL,跳过上传")

    # 8. Instantiate both agents.
    print("\n[8] 初始化ScriptSectionDivisionAgent和ScriptElementExtractionAgent...")
    try:
        section_agent = ScriptSectionDivisionAgent(
            model_provider="google_genai"
        )
        element_agent = ScriptElementExtractionAgent(
            model_provider="google_genai"
        )
        print("✅ Agent初始化成功")
    except Exception as e:
        print(f"❌ Agent初始化失败: {e}")
        traceback.print_exc()
        return

    # 9. Assemble the state object passed to the agents.
    print("\n[9] 组装state对象...")
    try:
        # Same shape as the topic-understanding entry used by the workflow.
        topic_understanding = result_data.get("选题理解", {})

        state = {
            "text": text_data,
            "video": video_url,
            "topic_selection_understanding": topic_understanding
        }

        # Attach the uploaded file object only when the upload succeeded.
        if video_file:
            state["video_file"] = video_file

        print("✅ State对象组装成功")
        print(f" - 文本: {bool(text_data)}")
        print(f" - 视频URL: {'有' if video_url else '无'}")
        print(f" - 视频文件对象: {'有' if video_file else '无'}")
        print(f" - 选题理解: {bool(topic_understanding)}")
    except Exception as e:
        print(f"❌ 组装state对象失败: {e}")
        return

    # 10. Run the section-division agent.
    print("\n[10] 执行脚本段落划分Agent...")
    try:
        section_result = section_agent.process(state)
        sections = section_result.get("段落列表", [])
        content_category = section_result.get("内容品类", "未知品类")
        print("✅ 段落划分执行成功")
        print(f" 内容品类: {content_category}")
        print(f" 划分出 {len(sections)} 个Section")
    except Exception as e:
        print(f"❌ 段落划分执行失败: {e}")
        traceback.print_exc()
        return

    # 11. Run the element-extraction agent.
    print("\n[11] 执行脚本元素提取Agent...")
    try:
        # Feed the section result plus the three-point data into the state.
        state["section_division"] = {"段落列表": sections}
        three_points = result_data.get("三点解构", {})
        state["inspiration_points"] = three_points.get("灵感点", {})
        state["purpose_points"] = three_points.get("目的点", {})
        state["key_points"] = three_points.get("关键点", {})

        element_result = element_agent.process(state)
        elements = element_result.get("元素列表", [])
        tendency_judgment = element_result.get("视频倾向判断", {})
        print("✅ 元素提取执行成功")
        print(f" 识别出 {len(elements)} 个元素")
        if tendency_judgment:
            print(f" 视频倾向: {tendency_judgment.get('判断结果', '未知')}")
    except Exception as e:
        print(f"❌ 元素提取执行失败: {e}")
        traceback.print_exc()
        return

    # 12. Assemble the final result.
    print("\n[12] 组装最终结果...")
    try:
        # The metadata counts every node, including nested children.
        total_sections = _count_items(sections)
        total_elements = _count_items(elements)

        script_understanding = {
            "内容品类": content_category,
            "段落列表": sections,
            "元素列表": elements,
            "视频URL": video_url,
            "视频倾向判断": tendency_judgment
        }

        final_result = {
            "选题描述": topic_description,
            "脚本理解": script_understanding,
            "元信息": {
                "段落总数": total_sections,
                "元素总数": total_elements,
                "来源帖子文件": post_file.name,
                "来源结果文件": result_file.name,
                "执行时间": datetime.now().isoformat()
            }
        }
        print("✅ 结果组装成功")
    except Exception as e:
        print(f"❌ 结果组装失败: {e}")
        return

    # 13. Save the result under a timestamped file name.
    print("\n[13] 保存结果...")
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"script_result_{timestamp}.json"
        output_path = Path(__file__).parent / directory / "output" / output_filename
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(final_result, f, ensure_ascii=False, indent=2)

        print(f"✅ 结果已保存到: {output_path}")
        print(f" 文件名: {output_filename}")
    except Exception as e:
        print(f"❌ 保存结果失败: {e}")
        return

    # 14. Print a summary of the result.
    print("\n" + "=" * 80)
    print("结果摘要")
    print("=" * 80)

    print("\n选题描述:")
    if topic_description.get("主题"):
        print(f" 主题: {topic_description['主题']}")
    if topic_description.get("描述"):
        print(f" 描述: {topic_description['描述']}")

    # The summary counts leaf sections only, unlike the metadata above.
    total_sections = _count_leaf_sections(sections)
    print(f"\nSection列表 ({total_sections} 个):")
    _print_sections(sections)

    # Elements are listed flat, so the list length is the total.
    total_elements = len(elements)
    print(f"\n元素列表 ({total_elements} 个):")
    if elements:
        _print_elements(elements)
    else:
        print(" (无)")

    print("\n" + "=" * 80)
    print("测试完成!")
    print("=" * 80)


if __name__ == "__main__":
    main()