script_workflow.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """
  2. Script Workflow.
  3. 脚本理解工作流:编排脚本段落划分和元素提取流程的执行顺序和流程逻辑。
  4. 流程:视频上传 → 段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总
  5. """
  6. from typing import Dict, Any
  7. from langgraph.graph import StateGraph, END
  8. from src.components.agents.base import BaseGraphAgent
  9. from src.states.script_state import ScriptState
  10. from src.components.agents.script_section_division_agent import ScriptSectionDivisionAgent
  11. from src.components.agents.script_substance_extraction_agent import ScriptSubstanceExtractionAgent
  12. from src.components.agents.script_form_extraction_agent import ScriptFormExtractionAgent
  13. from src.components.functions.video_upload_function import VideoUploadFunction
  14. from src.utils.logger import get_logger
  # Module-level logger obtained from the project's logging helper.
  logger = get_logger(__name__)
  class ScriptWorkflow(BaseGraphAgent):
      """Script-understanding workflow.

      Responsibilities:
      - Orchestrate the script-understanding pipeline.
      - Flow: video upload -> section division -> substance extraction ->
        form extraction -> split results -> result aggregation.
      - Carry state between nodes (a plain dict shared by every node).

      Each node catches its own exceptions, logs them and writes empty fallback
      values, so a failing step does not prevent the remaining steps from running.

      Implementation: BaseGraphAgent (LangGraph).
      """

      def __init__(
          self,
          name: str = "script_workflow",
          description: str = "脚本理解工作流",
          model_provider: str = "google_genai"
      ):
          """Create the workflow and its sub-components.

          Args:
              name: Workflow name, forwarded to ``BaseGraphAgent``.
              description: Workflow description, forwarded to ``BaseGraphAgent``.
              model_provider: Provider passed to each of the three extraction
                  agents.
          """
          super().__init__(
              name=name,
              description=description,
              state_class=ScriptState
          )
          self.model_provider = model_provider
          # Downloads the video and uploads it to Gemini.
          self.video_upload_func = VideoUploadFunction()
          # Splits the script into sections.
          self.section_agent = ScriptSectionDivisionAgent(
              model_provider=model_provider
          )
          # Extracts substance elements.
          self.substance_agent = ScriptSubstanceExtractionAgent(
              model_provider=model_provider
          )
          # Extracts form elements; runs after, and depends on, substance extraction.
          self.form_agent = ScriptFormExtractionAgent(
              model_provider=model_provider
          )
          logger.info(f"ScriptWorkflow 初始化完成,model_provider: {model_provider}")

      @staticmethod
      def _ensure_component_initialized(component: Any) -> None:
          """Initialize an agent/function the first time a node uses it."""
          if not component.is_initialized:
              component.initialize()

      def _build_graph(self) -> StateGraph:
          """Build the workflow graph (video-analysis version).

          Flow:
              START -> video upload -> section division -> substance extraction
              -> form extraction -> split results -> result aggregation -> END

          Returns:
              The uncompiled ``StateGraph`` with every node and edge registered.
          """
          # The graph state is a plain dict: every node mutates it in place and
          # returns it. NOTE(review): ScriptState is passed to the base class but
          # is not used as the graph schema here -- confirm that is intended.
          workflow = StateGraph(dict)

          # Single source of truth for the linear pipeline: nodes are registered
          # and chained in exactly this order.
          pipeline = [
              ("video_upload", self._video_upload_node),
              ("section_division", self._section_division_node),
              ("substance_extraction", self._substance_extraction_node),
              ("form_extraction", self._form_extraction_node),
              ("merge_all_results", self._merge_all_results_node),
              ("result_aggregation", self._result_aggregation_node),
          ]
          for node_name, handler in pipeline:
              workflow.add_node(node_name, handler)

          node_names = [node_name for node_name, _ in pipeline]
          workflow.set_entry_point(node_names[0])
          for source, target in zip(node_names, node_names[1:]):
              workflow.add_edge(source, target)
          workflow.add_edge(node_names[-1], END)

          logger.info("工作流图构建完成 - 流程:视频上传 → 段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总")
          return workflow

      def _video_upload_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
          """Node: video upload (first step) - download the video and upload it to Gemini.

          The upload function's result is merged into ``state``. If the function
          raises, ``video_uploaded_uri`` is set to ``None`` and the exception text
          is stored under ``video_upload_error``; the following nodes still run.
          """
          logger.info("=== 执行节点:视频上传 ===")
          try:
              self._ensure_component_initialized(self.video_upload_func)
              result = self.video_upload_func.execute(state)
              state.update(result)
              video_uri = result.get("video_uploaded_uri")
              if video_uri:
                  logger.info(f"视频上传完成 - URI: {video_uri}")
              else:
                  # No URI in the result: log whatever error the function reported.
                  error = result.get("video_upload_error", "未知错误")
                  logger.warning(f"视频上传失败: {error}")
          except Exception as e:
              logger.error(f"视频上传失败: {e}", exc_info=True)
              state.update({
                  "video_uploaded_uri": None,
                  "video_upload_error": str(e)
              })
          return state

      def _section_division_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
          """Node: divide the script into sections.

          Merges the section agent's result into ``state`` (this node reads
          ``内容品类`` and ``段落列表`` from it for logging). On failure it writes
          fallback values for those two keys.
          """
          logger.info("=== 执行节点:脚本段落划分 ===")
          try:
              self._ensure_component_initialized(self.section_agent)
              result = self.section_agent.process(state)
              state.update(result)
              sections = result.get("段落列表", [])
              content_category = result.get("内容品类", "未知")
              logger.info(f"脚本段落划分完成 - 内容品类: {content_category}, 段落数: {len(sections)}")
          except Exception as e:
              logger.error(f"脚本段落划分失败: {e}", exc_info=True)
              # NOTE(review): this fallback ("未知品类") differs from the "未知"
              # default used when logging and aggregating -- confirm which value
              # downstream consumers expect.
              state.update({
                  "内容品类": "未知品类",
                  "段落列表": []
              })
          return state

      def _substance_extraction_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
          """Node: extract substance elements.

          On failure every substance-related key is reset to an empty value, so
          later steps (e.g. the merge step, which reads
          ``substance_final_elements``) see empty results instead of stale or
          missing ones.
          """
          logger.info("=== 执行节点:实质元素提取 ===")
          try:
              self._ensure_component_initialized(self.substance_agent)
              # Prepare state: hand the section list to the agent wrapped under
              # the ``section_division`` key.
              sections = state.get("段落列表", [])
              state["section_division"] = {"段落列表": sections}
              result = self.substance_agent.process(state)
              state.update(result)
              final_elements = result.get("substance_final_elements", [])
              logger.info(f"实质元素提取完成 - 最终元素数: {len(final_elements)}")
          except Exception as e:
              logger.error(f"实质元素提取失败: {e}", exc_info=True)
              state.update({
                  "concrete_elements": [],
                  "concrete_concepts": [],
                  "abstract_concepts": [],
                  "substance_elements": [],
                  "substance_analyzed_result": [],
                  "substance_scored_result": {},
                  "substance_filtered_ids": [],
                  "substance_categorized_result": {},
                  "substance_final_elements": []
              })
          return state

      def _form_extraction_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
          """Node: extract form elements (depends on the substance results in ``state``).

          On failure every form-related key is reset to an empty value, so the
          merge step (which reads ``form_final_elements``) sees an empty result.
          """
          logger.info("=== 执行节点:形式元素提取 ===")
          try:
              self._ensure_component_initialized(self.form_agent)
              result = self.form_agent.process(state)
              state.update(result)
              final_elements = result.get("form_final_elements", [])
              logger.info(f"形式元素提取完成 - 最终元素数: {len(final_elements)}")
          except Exception as e:
              logger.error(f"形式元素提取失败: {e}", exc_info=True)
              state.update({
                  "concrete_element_forms": [],
                  "concrete_concept_forms": [],
                  "overall_forms": [],
                  "form_elements": [],
                  "form_analyzed_result": [],
                  "form_scored_result": {},
                  "form_weighted_result": {},
                  "form_filtered_ids": [],
                  "form_categorized_result": {},
                  "form_final_elements": []
              })
          return state

      def _merge_all_results_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
          """Node: split substance and form results into separate lists (Step 7).

          Exposes ``substance_final_elements`` as ``实质列表`` and
          ``form_final_elements`` as ``形式列表``. On failure both are set to
          empty lists.
          """
          logger.info("=== 执行节点:分离实质和形式结果 ===")
          try:
              substance_final_elements = state.get("substance_final_elements", [])
              form_final_elements = state.get("form_final_elements", [])
              state["实质列表"] = substance_final_elements
              state["形式列表"] = form_final_elements
              logger.info(f"分离完成 - 实质元素: {len(substance_final_elements)}, 形式元素: {len(form_final_elements)}")
          except Exception as e:
              logger.error(f"分离结果失败: {e}", exc_info=True)
              state["实质列表"] = []
              state["形式列表"] = []
          return state

      @staticmethod
      def _extract_selected_topic(topic_understanding: Any) -> Dict[str, Any]:
          """Return the dict that carries the topic's ``主题`` / ``描述`` fields.

          Two shapes are accepted: the fields at the top level, or nested under
          ``选题``. Anything that is not a dict (e.g. ``None`` coming from a JSON
          ``null``) yields an empty dict instead of raising.
          """
          if not isinstance(topic_understanding, dict):
              return {}
          if "选题" in topic_understanding:
              nested = topic_understanding.get("选题")
              return nested if isinstance(nested, dict) else {}
          return topic_understanding

      def _result_aggregation_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
          """Node: aggregate everything into ``state["final_result"]``.

          Builds the topic description, the script-understanding payload
          (category, sections, substance list, form list, images) and the
          pass-through inspiration/purpose/key points. If assembly fails,
          ``final_result`` holds only an ``错误`` message.
          """
          logger.info("=== 执行节点:结果汇总 ===")
          try:
              # A missing or malformed topic understanding must not wipe out the
              # rest of the result, so it is normalized to a dict first.
              selected_topic = self._extract_selected_topic(
                  state.get("topic_selection_understanding")
              )
              # Substance and form lists are emitted separately.
              final_result = {
                  "选题描述": {
                      "主题": selected_topic.get("主题", ""),
                      "描述": selected_topic.get("描述", "")
                  },
                  "脚本理解": {
                      "内容品类": state.get("内容品类", "未知"),
                      "段落列表": state.get("段落列表", []),
                      "实质列表": state.get("实质列表", []),
                      "形式列表": state.get("形式列表", []),
                      "图片列表": state.get("images", [])
                  },
                  "灵感点": state.get("inspiration_points", []),
                  "目的点": state.get("purpose_points", []),
                  "关键点": state.get("key_points", [])
              }
              state["final_result"] = final_result
              logger.info("结果汇总完成")
          except Exception as e:
              logger.error(f"结果汇总失败: {e}", exc_info=True)
              state["final_result"] = {
                  "错误": f"汇总失败: {str(e)}"
              }
          return state

      def invoke(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
          """Run the workflow (public entry point) - video-analysis version.

          Args:
              input_data: Request payload. Reads ``video``, ``channel_content_id``,
                  ``text`` (used when it is a dict; otherwise ``title`` and
                  ``body_text`` are combined into one), plus
                  ``topic_selection_understanding``, ``content_weight``,
                  ``inspiration_points``, ``purpose_points`` and ``key_points``.

          Returns:
              The final script-understanding result (``final_result`` from the
              graph's final state), or ``{}`` if it is absent.
          """
          logger.info("=== 开始执行脚本理解工作流(视频分析) ===")
          # Initialize the workflow on first use (compiled_graph is used below).
          if not self.is_initialized:
              self.initialize()

          # Build ``text``; two input styles are supported.
          raw_text = input_data.get("text")
          if isinstance(raw_text, dict):
              text = raw_text
          else:
              text = {
                  "title": input_data.get("title", ""),
                  "body": input_data.get("body_text", ""),
              }

          # Initial state, including the video info needed by the upload node
          # and the agents that follow it.
          initial_state = {
              "video": input_data.get("video", ""),
              "channel_content_id": input_data.get("channel_content_id", ""),
              "text": text,
              "topic_selection_understanding": input_data.get("topic_selection_understanding", {}),
              "content_weight": input_data.get("content_weight", {}),
              "inspiration_points": input_data.get("inspiration_points", []),
              "purpose_points": input_data.get("purpose_points", []),
              "key_points": input_data.get("key_points", [])
          }

          result = self.compiled_graph.invoke(initial_state)
          logger.info("=== 脚本理解工作流执行完成(视频分析) ===")
          return result.get("final_result", {})