| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- """
- Script Workflow.
- 脚本理解工作流:编排脚本段落划分和元素提取流程的执行顺序和流程逻辑。
- 流程:段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总
- """
- from typing import Dict, Any
- from langgraph.graph import StateGraph, END
- from src.components.agents.base import BaseGraphAgent
- from src.states.script_state import ScriptState
- from src.components.agents.script_section_division_agent import ScriptSectionDivisionAgent
- from src.components.agents.script_substance_extraction_agent import ScriptSubstanceExtractionAgent
- from src.components.agents.script_form_extraction_agent import ScriptFormExtractionAgent
- from src.components.functions.video_upload_function import VideoUploadFunction
- from src.utils.logger import get_logger
- logger = get_logger(__name__)
- class ScriptWorkflow(BaseGraphAgent):
- """脚本理解工作流
- 功能:
- - 编排脚本理解流程
- - 流程:段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总
- - 管理状态传递
- 实现方式:BaseGraphAgent (LangGraph)
- """
- def __init__(
- self,
- name: str = "script_workflow",
- description: str = "脚本理解工作流",
- model_provider: str = "google_genai"
- ):
- super().__init__(
- name=name,
- description=description,
- state_class=ScriptState
- )
- self.model_provider = model_provider
- # 初始化视频上传Function
- self.video_upload_func = VideoUploadFunction()
- # 初始化脚本段落划分Agent
- self.section_agent = ScriptSectionDivisionAgent(
- model_provider=model_provider
- )
- # 初始化实质提取Agent
- self.substance_agent = ScriptSubstanceExtractionAgent(
- model_provider=model_provider
- )
- # 初始化形式提取Agent
- self.form_agent = ScriptFormExtractionAgent(
- model_provider=model_provider
- )
- logger.info(f"ScriptWorkflow 初始化完成,model_provider: {model_provider}")
- def _build_graph(self) -> StateGraph:
- """构建工作流图(视频分析版)
- 流程:
- START → 视频上传 → 段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总 → END
- """
- workflow = StateGraph(dict) # 使用dict作为状态类型
- # 添加所有节点
- workflow.add_node("video_upload", self._video_upload_node)
- workflow.add_node("section_division", self._section_division_node)
- workflow.add_node("substance_extraction", self._substance_extraction_node)
- workflow.add_node("form_extraction", self._form_extraction_node)
- workflow.add_node("merge_all_results", self._merge_all_results_node)
- workflow.add_node("result_aggregation", self._result_aggregation_node)
- # 定义流程的边
- workflow.set_entry_point("video_upload")
- workflow.add_edge("video_upload", "section_division")
- workflow.add_edge("section_division", "substance_extraction")
- workflow.add_edge("substance_extraction", "form_extraction")
- workflow.add_edge("form_extraction", "merge_all_results")
- workflow.add_edge("merge_all_results", "result_aggregation")
- workflow.add_edge("result_aggregation", END)
- logger.info("工作流图构建完成 - 流程:视频上传 → 段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总")
- return workflow
- def _video_upload_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """节点:视频上传(第一步)- 下载视频并上传至Gemini"""
- logger.info("=== 执行节点:视频上传 ===")
- try:
- # 初始化Function
- if not self.video_upload_func.is_initialized:
- self.video_upload_func.initialize()
- # 执行视频上传
- result = self.video_upload_func.execute(state)
- # 更新状态
- state.update(result)
- video_uri = result.get("video_uploaded_uri")
- if video_uri:
- logger.info(f"视频上传完成 - URI: {video_uri}")
- else:
- error = result.get("video_upload_error", "未知错误")
- logger.warning(f"视频上传失败: {error}")
- except Exception as e:
- logger.error(f"视频上传失败: {e}", exc_info=True)
- state.update({
- "video_uploaded_uri": None,
- "video_upload_error": str(e)
- })
- return state
- def _section_division_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """节点:脚本段落划分"""
- logger.info("=== 执行节点:脚本段落划分 ===")
- try:
- # 初始化Agent
- if not self.section_agent.is_initialized:
- self.section_agent.initialize()
- # 执行Agent
- result = self.section_agent.process(state)
- # 更新状态
- state.update(result)
- sections = result.get("段落列表", [])
- content_category = result.get("内容品类", "未知")
- logger.info(f"脚本段落划分完成 - 内容品类: {content_category}, 段落数: {len(sections)}")
- except Exception as e:
- logger.error(f"脚本段落划分失败: {e}", exc_info=True)
- state.update({
- "内容品类": "未知品类",
- "段落列表": []
- })
- return state
- def _substance_extraction_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """节点:实质元素提取"""
- logger.info("=== 执行节点:实质元素提取 ===")
- try:
- # 初始化Agent
- if not self.substance_agent.is_initialized:
- self.substance_agent.initialize()
- # 准备状态:将段落列表包装到section_division字段中
- sections = state.get("段落列表", [])
- state["section_division"] = {"段落列表": sections}
- # 执行Agent
- result = self.substance_agent.process(state)
- # 更新状态
- state.update(result)
- final_elements = result.get("substance_final_elements", [])
- logger.info(f"实质元素提取完成 - 最终元素数: {len(final_elements)}")
- except Exception as e:
- logger.error(f"实质元素提取失败: {e}", exc_info=True)
- state.update({
- "concrete_elements": [],
- "concrete_concepts": [],
- "abstract_concepts": [],
- "substance_elements": [],
- "substance_analyzed_result": [],
- "substance_scored_result": {},
- "substance_filtered_ids": [],
- "substance_categorized_result": {},
- "substance_final_elements": []
- })
- return state
- def _form_extraction_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """节点:形式元素提取"""
- logger.info("=== 执行节点:形式元素提取 ===")
- try:
- # 初始化Agent
- if not self.form_agent.is_initialized:
- self.form_agent.initialize()
- # 执行Agent(依赖实质元素)
- result = self.form_agent.process(state)
- # 更新状态
- state.update(result)
- final_elements = result.get("form_final_elements", [])
- logger.info(f"形式元素提取完成 - 最终元素数: {len(final_elements)}")
- except Exception as e:
- logger.error(f"形式元素提取失败: {e}", exc_info=True)
- state.update({
- "concrete_element_forms": [],
- "concrete_concept_forms": [],
- "overall_forms": [],
- "form_elements": [],
- "form_analyzed_result": [],
- "form_scored_result": {},
- "form_weighted_result": {},
- "form_filtered_ids": [],
- "form_categorized_result": {},
- "form_final_elements": []
- })
- return state
- def _merge_all_results_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """节点:分离实质和形式结果(Step 7)"""
- logger.info("=== 执行节点:分离实质和形式结果 ===")
- try:
- # 获取实质和形式的最终元素
- substance_final_elements = state.get("substance_final_elements", [])
- form_final_elements = state.get("form_final_elements", [])
- # 分别存储实质列表和形式列表
- state["实质列表"] = substance_final_elements
- state["形式列表"] = form_final_elements
- logger.info(f"分离完成 - 实质元素: {len(substance_final_elements)}, 形式元素: {len(form_final_elements)}")
- except Exception as e:
- logger.error(f"分离结果失败: {e}", exc_info=True)
- state["实质列表"] = []
- state["形式列表"] = []
- return state
- def _result_aggregation_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
- """节点:结果汇总"""
- logger.info("=== 执行节点:结果汇总 ===")
- try:
- # 从 topic_selection_understanding 提取选题描述
- topic_understanding = state.get("topic_selection_understanding", {})
- # 兼容两种结构:直接包含主题/描述,或嵌套在"选题"键下
- if "选题" in topic_understanding:
- selected_topic = topic_understanding.get("选题", {})
- else:
- selected_topic = topic_understanding
- # 组装最终结果 - 实质和形式分别输出
- final_result = {
- "选题描述": {
- "主题": selected_topic.get("主题", ""),
- "描述": selected_topic.get("描述", "")
- },
- "脚本理解": {
- "内容品类": state.get("内容品类", "未知"),
- "段落列表": state.get("段落列表", []),
- "实质列表": state.get("实质列表", []), # 独立的实质列表
- "形式列表": state.get("形式列表", []), # 独立的形式列表
- "图片列表": state.get("images", [])
- },
- "灵感点": state.get("inspiration_points", []),
- "目的点": state.get("purpose_points", []),
- "关键点": state.get("key_points", [])
- }
- # 更新状态
- state["final_result"] = final_result
- logger.info("结果汇总完成")
- except Exception as e:
- logger.error(f"结果汇总失败: {e}", exc_info=True)
- state["final_result"] = {
- "错误": f"汇总失败: {str(e)}"
- }
- return state
- def invoke(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
- """执行工作流(公共接口)- 视频分析版
- Returns:
- 最终脚本理解结果
- """
- logger.info("=== 开始执行脚本理解工作流(视频分析) ===")
- # 确保工作流已初始化
- if not self.is_initialized:
- self.initialize()
- # 构建 text(兼容两种输入方式)
- if "text" in input_data and isinstance(input_data.get("text"), dict):
- text = input_data.get("text", {})
- else:
- text = {
- "title": input_data.get("title", ""),
- "body": input_data.get("body_text", ""),
- }
- # 初始化状态(包含视频信息,供视频上传和后续Agent使用)
- initial_state = {
- "video": input_data.get("video", ""),
- "channel_content_id": input_data.get("channel_content_id", ""),
- "text": text,
- "topic_selection_understanding": input_data.get("topic_selection_understanding", {}),
- "content_weight": input_data.get("content_weight", {}),
- "inspiration_points": input_data.get("inspiration_points", []),
- "purpose_points": input_data.get("purpose_points", []),
- "key_points": input_data.get("key_points", [])
- }
- # 执行工作流
- result = self.compiled_graph.invoke(initial_state)
- logger.info("=== 脚本理解工作流执行完成(视频分析) ===")
- return result.get("final_result", {})
|