|
|
@@ -2,7 +2,7 @@
|
|
|
Decode Workflow.
|
|
|
|
|
|
解码工作流:合并 What 解构工作流和脚本理解工作流的完整流程。
|
|
|
-流程:视频上传 → 灵感点提取 → 目的点提取 → 关键点提取 → 选题理解 →
|
|
|
+流程:初始化数据库记录 → 视频上传 → 灵感点提取 → 目的点提取 → 关键点提取 → 选题理解 →
|
|
|
段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总
|
|
|
"""
|
|
|
|
|
|
@@ -21,6 +21,7 @@ from src.components.agents.key_points_agent import KeyPointsAgent
|
|
|
from src.components.agents.script_section_division_agent import ScriptSectionDivisionAgent
|
|
|
from src.components.agents.script_substance_extraction_agent import ScriptSubstanceExtractionAgent
|
|
|
from src.components.agents.script_form_extraction_agent import ScriptFormExtractionAgent
|
|
|
+from src.models import get_db, DecodeVideo, DecodeStatus
|
|
|
from src.utils.logger import get_logger
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
@@ -31,7 +32,7 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
|
|
|
功能:
|
|
|
- 编排完整的解码流程(视频分析)
|
|
|
- - 流程:视频上传 → 灵感点提取 → 目的点提取 → 关键点提取 → 选题理解 →
|
|
|
+ - 流程:初始化数据库记录 → 视频上传 → 灵感点提取 → 目的点提取 → 关键点提取 → 选题理解 →
|
|
|
段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总
|
|
|
- 管理状态传递
|
|
|
- 仅支持单视频输入
|
|
|
@@ -92,12 +93,13 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
"""构建工作流图
|
|
|
|
|
|
完整流程:
|
|
|
- START → 视频上传 → 灵感点提取 → 目的点提取 → 关键点提取 → 选题理解 →
|
|
|
+ START → 初始化数据库记录 → 视频上传 → 灵感点提取 → 目的点提取 → 关键点提取 → 选题理解 →
|
|
|
段落划分 → 实质提取 → 形式提取 → 分离结果 → 结果汇总 → END
|
|
|
"""
|
|
|
workflow = StateGraph(dict) # 使用dict作为状态类型
|
|
|
|
|
|
# 添加所有节点
|
|
|
+ workflow.add_node("init_db_record", self._init_db_record_node)
|
|
|
workflow.add_node("video_upload", self._video_upload_node)
|
|
|
# What解构节点
|
|
|
workflow.add_node("inspiration_points_extraction", self._inspiration_points_node)
|
|
|
@@ -112,7 +114,9 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
workflow.add_node("result_aggregation", self._result_aggregation_node)
|
|
|
|
|
|
# 定义流程的边
|
|
|
- workflow.set_entry_point("video_upload")
|
|
|
+ workflow.set_entry_point("init_db_record")
|
|
|
+ # 数据库记录初始化后进入视频上传
|
|
|
+ workflow.add_edge("init_db_record", "video_upload")
|
|
|
# 视频上传后使用条件边:成功则继续,失败则终止
|
|
|
workflow.add_conditional_edges(
|
|
|
"video_upload",
|
|
|
@@ -183,7 +187,7 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
workflow.add_edge("merge_all_results", "result_aggregation")
|
|
|
workflow.add_edge("result_aggregation", END)
|
|
|
|
|
|
- logger.info("工作流图构建完成 - 完整流程:视频上传 → What解构 → 脚本理解 → 结果汇总")
|
|
|
+ logger.info("工作流图构建完成 - 完整流程:初始化数据库记录 → 视频上传 → What解构 → 脚本理解 → 结果汇总")
|
|
|
|
|
|
return workflow
|
|
|
|
|
|
@@ -206,6 +210,8 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
# 设置失败信息到状态中
|
|
|
state["workflow_failed"] = True
|
|
|
state["workflow_error"] = error_msg
|
|
|
+ # 更新数据库记录为失败状态
|
|
|
+ self._update_db_record_after_workflow(state, success=False, error_msg=error_msg)
|
|
|
return "failure"
|
|
|
|
|
|
def _check_critical_error(self, state: Dict[str, Any], error_source: str = "") -> bool:
|
|
|
@@ -326,6 +332,68 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
|
|
|
return state
|
|
|
|
|
|
+ def _init_db_record_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
+ """节点:初始化数据库记录
|
|
|
+
|
|
|
+ 根据 video_id 查询 decode_videos 表:
|
|
|
+ - 如果存在记录,则不新建,使用现有记录
|
|
|
+ - 如果不存在,则新建记录(状态为 EXECUTING)
|
|
|
+ """
|
|
|
+ logger.info("=== 执行节点:初始化数据库记录 ===")
|
|
|
+
|
|
|
+ try:
|
|
|
+ video_id = state.get("video_id", "")
|
|
|
+ task_id = state.get("task_id")
|
|
|
+
|
|
|
+ if not video_id:
|
|
|
+ logger.warning("未提供 video_id,跳过数据库记录初始化")
|
|
|
+ return state
|
|
|
+
|
|
|
+ db = next(get_db())
|
|
|
+ try:
|
|
|
+ # 根据 video_id 查询是否已有记录
|
|
|
+ existing_record = db.query(DecodeVideo).filter_by(video_id=video_id).first()
|
|
|
+
|
|
|
+ if existing_record:
|
|
|
+ # 如果存在记录,使用现有的 task_id
|
|
|
+ logger.info(f"找到已存在的数据库记录: task_id={existing_record.task_id}, video_id={video_id}")
|
|
|
+ state["db_task_id"] = existing_record.task_id
|
|
|
+ state["db_record_exists"] = True
|
|
|
+ # 更新状态为执行中
|
|
|
+ existing_record.update_status(DecodeStatus.EXECUTING)
|
|
|
+ db.commit()
|
|
|
+ else:
|
|
|
+ # 如果不存在,创建新记录
|
|
|
+ # 如果没有提供 task_id,使用 video_id 的 hash 值作为 task_id
|
|
|
+ if not task_id:
|
|
|
+ import hashlib
|
|
|
+ task_id = int(hashlib.md5(video_id.encode()).hexdigest()[:15], 16) % (10 ** 15)
|
|
|
+ logger.info(f"未提供 task_id,自动生成: {task_id}")
|
|
|
+
|
|
|
+ new_record = DecodeVideo.create(
|
|
|
+ task_id=task_id,
|
|
|
+ video_id=video_id,
|
|
|
+ status=DecodeStatus.EXECUTING
|
|
|
+ )
|
|
|
+ db.add(new_record)
|
|
|
+ db.commit()
|
|
|
+ logger.info(f"创建新的数据库记录: task_id={task_id}, video_id={video_id}")
|
|
|
+ state["db_task_id"] = task_id
|
|
|
+ state["db_record_exists"] = False
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"数据库操作失败: {e}", exc_info=True)
|
|
|
+ db.rollback()
|
|
|
+ # 数据库操作失败不影响 workflow 继续执行
|
|
|
+ finally:
|
|
|
+ db.close()
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"初始化数据库记录节点执行失败: {e}", exc_info=True)
|
|
|
+ # 数据库操作失败不影响 workflow 继续执行
|
|
|
+
|
|
|
+ return state
|
|
|
+
|
|
|
def _inspiration_points_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
"""节点:灵感点提取(What解构)"""
|
|
|
logger.info("=== 执行节点:灵感点提取 ===")
|
|
|
@@ -756,6 +824,57 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
|
|
|
return state
|
|
|
|
|
|
+ def _update_db_record_after_workflow(
|
|
|
+ self,
|
|
|
+ state: Dict[str, Any],
|
|
|
+ success: bool,
|
|
|
+ final_result: Dict[str, Any] = None,
|
|
|
+ error_msg: str = None
|
|
|
+ ):
|
|
|
+ """工作流执行完毕后更新数据库记录
|
|
|
+
|
|
|
+ Args:
|
|
|
+ state: 工作流执行后的状态
|
|
|
+ success: 是否成功
|
|
|
+ final_result: 最终结果(成功时使用)
|
|
|
+ error_msg: 错误信息(失败时使用)
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ task_id = state.get("db_task_id")
|
|
|
+ if not task_id:
|
|
|
+ logger.warning("未找到 db_task_id,跳过数据库记录更新")
|
|
|
+ return
|
|
|
+
|
|
|
+ db = next(get_db())
|
|
|
+ try:
|
|
|
+ record = db.query(DecodeVideo).filter_by(task_id=task_id).first()
|
|
|
+ if not record:
|
|
|
+ logger.warning(f"未找到 task_id={task_id} 的数据库记录,跳过更新")
|
|
|
+ return
|
|
|
+
|
|
|
+ if success:
|
|
|
+ # 更新为成功状态
|
|
|
+ import json
|
|
|
+ result_json = json.dumps(final_result, ensure_ascii=False) if final_result else None
|
|
|
+ record.update_status(DecodeStatus.SUCCESS)
|
|
|
+ record.update_result(result_json)
|
|
|
+ logger.info(f"更新数据库记录为成功: task_id={task_id}")
|
|
|
+ else:
|
|
|
+ # 更新为失败状态
|
|
|
+ record.update_status(DecodeStatus.FAILED, error_reason=error_msg)
|
|
|
+ logger.info(f"更新数据库记录为失败: task_id={task_id}, error={error_msg}")
|
|
|
+
|
|
|
+ db.commit()
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"更新数据库记录失败: {e}", exc_info=True)
|
|
|
+ db.rollback()
|
|
|
+ finally:
|
|
|
+ db.close()
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"更新数据库记录节点执行失败: {e}", exc_info=True)
|
|
|
+
|
|
|
def invoke(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
"""执行工作流(公共接口)- 视频分析版本
|
|
|
|
|
|
@@ -763,13 +882,13 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
最终解码结果
|
|
|
"""
|
|
|
logger.info("=== 开始执行解码工作流(视频分析) ===")
|
|
|
-
|
|
|
+ logger.info(f"input_data: {input_data}")
|
|
|
# 确保工作流已初始化
|
|
|
if not self.is_initialized:
|
|
|
self.initialize()
|
|
|
|
|
|
# 验证输入参数
|
|
|
- video_url = input_data.get("video", "")
|
|
|
+ video_url = input_data.get("video_url", "")
|
|
|
if not video_url:
|
|
|
error_msg = "未提供视频URL,无法执行工作流"
|
|
|
logger.error(error_msg)
|
|
|
@@ -782,7 +901,7 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
# 初始化状态(包含视频信息,供视频上传和后续Agent使用)
|
|
|
initial_state = {
|
|
|
"video": video_url,
|
|
|
- "channel_content_id": input_data.get("channel_content_id", ""),
|
|
|
+ "video_id": input_data.get("video_id", ""),
|
|
|
"title": input_data.get("title", ""),
|
|
|
"current_depth": 0,
|
|
|
"max_depth": self.max_depth,
|
|
|
@@ -790,11 +909,16 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
}
|
|
|
|
|
|
# 执行工作流
|
|
|
+ result = None
|
|
|
try:
|
|
|
result = self.compiled_graph.invoke(initial_state)
|
|
|
except Exception as e:
|
|
|
error_msg = f"工作流执行异常: {str(e)}"
|
|
|
logger.error(error_msg, exc_info=True)
|
|
|
+ # 更新数据库记录为失败状态(使用 initial_state 作为 fallback)
|
|
|
+ if result is None:
|
|
|
+ result = initial_state
|
|
|
+ self._update_db_record_after_workflow(result, success=False, error_msg=error_msg)
|
|
|
return {
|
|
|
"error": error_msg,
|
|
|
"workflow_status": "failed",
|
|
|
@@ -805,6 +929,8 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
if result.get("workflow_failed"):
|
|
|
error_msg = result.get("workflow_error", "工作流执行失败")
|
|
|
logger.error(f"工作流因错误而终止: {error_msg}")
|
|
|
+ # 更新数据库记录为失败状态
|
|
|
+ self._update_db_record_after_workflow(result, success=False, error_msg=error_msg)
|
|
|
return {
|
|
|
"error": error_msg,
|
|
|
"video_upload_error": result.get("video_upload_error"),
|
|
|
@@ -819,18 +945,25 @@ class DecodeWorkflow(BaseGraphAgent):
|
|
|
if result.get("workflow_error"):
|
|
|
error_msg = result.get("workflow_error", "工作流执行失败,未生成结果")
|
|
|
logger.error(f"工作流执行失败: {error_msg}")
|
|
|
+ # 更新数据库记录为失败状态
|
|
|
+ self._update_db_record_after_workflow(result, success=False, error_msg=error_msg)
|
|
|
return {
|
|
|
"error": error_msg,
|
|
|
"workflow_status": "failed"
|
|
|
}
|
|
|
else:
|
|
|
logger.warning("工作流执行完成,但未生成最终结果")
|
|
|
+ # 更新数据库记录为失败状态
|
|
|
+ self._update_db_record_after_workflow(result, success=False, error_msg="工作流执行完成,但未生成最终结果")
|
|
|
return {
|
|
|
"error": "工作流执行完成,但未生成最终结果",
|
|
|
"workflow_status": "incomplete",
|
|
|
"state": result
|
|
|
}
|
|
|
|
|
|
+ # 工作流执行成功,更新数据库记录
|
|
|
+ self._update_db_record_after_workflow(result, success=True, final_result=final_result)
|
|
|
+
|
|
|
logger.info("=== 解码工作流执行完成(视频分析) ===")
|
|
|
|
|
|
return final_result
|