| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- from typing import Dict, Any, Optional, List
- from loguru import logger
- import sys
- import json
- import requests
- from utils.params import PatternContentParam, SceneEnum, ContentTypeEnum, CapabilityEnum, ContentParam
- from models.task import WorkflowTask
- from utils.sync_mysql_help import mysql
# Attach an extra loguru sink on stderr that only receives ERROR-and-above
# records. backtrace/diagnose are enabled; they affect how logged exceptions
# are rendered (see the loguru documentation for details).
logger.add(
    sink=sys.stderr,
    level="ERROR",
    backtrace=True,
    diagnose=True,
)

# Values placed in the "code" field of the response dicts built in this module.
ERROR_CODE_SUCCESS = 0
ERROR_CODE_FAILED = -1
ERROR_CODE_TASK_CREATE_FAILED = 2001
def _build_error_response(code: int, reason: str) -> Dict[str, Any]:
    """Build the payload returned when a request fails.

    Args:
        code: Error code to report.
        reason: Description of the failure.

    Returns:
        A dict with ``code`` and ``reason`` taken from the arguments and
        ``task_id`` fixed to ``None``.
    """
    response: Dict[str, Any] = dict(code=code, task_id=None, reason=reason)
    return response
def _build_success_response(task_id: str) -> Dict[str, Any]:
    """Build the payload returned when a request succeeds.

    Args:
        task_id: Identifier of the task that was created.

    Returns:
        A dict whose ``code`` is ``ERROR_CODE_SUCCESS``, whose ``task_id`` is
        the given id and whose ``reason`` is the empty string.
    """
    response: Dict[str, Any] = dict(
        code=ERROR_CODE_SUCCESS,
        task_id=task_id,
        reason="",
    )
    return response
def _validate_pattern_param(param: PatternContentParam) -> Optional[str]:
    """Validate the required fields of a pattern (clustering) request.

    Checks, in order: ``pattern_name`` is truthy, ``contents`` is non-empty,
    and every content item has a truthy ``channel_content_id`` and a
    ``weight_score`` that is not ``None``.

    Args:
        param: Request object to validate.

    Returns:
        The message for the first violation found, or ``None`` when the
        request passes every check.
    """
    if not param.pattern_name:
        return "pattern_name 不能为空"

    entries = param.contents
    if not entries:
        return "contents 不能为空"

    def _violations():
        # Lazily yields messages in item order; only the first one is consumed.
        for position, entry in enumerate(entries):
            if not entry.channel_content_id:
                yield f"contents[{position}].channel_content_id 不能为空"
            elif entry.weight_score is None:
                yield f"contents[{position}].weight_score 不能为空"

    return next(_violations(), None)
def _validate_decode_status(contents: List[ContentParam]) -> Optional[str]:
    """Check that every content item has a successfully finished decode task.

    For each ``channel_content_id`` the most recent row (by ``created_time``)
    in ``workflow_decode_task_result`` is looked up, and the ``status`` of the
    ``workflow_task`` row it references must equal the success status (2).

    Args:
        contents: Content items to check; only ``channel_content_id`` is read.

    Returns:
        ``None`` when every item passes (or ``contents`` is empty); otherwise
        an error message naming the first offending ``channel_content_id`` in
        input order.
    """
    STATUS_SUCCESS = 2  # workflow_task.status value treated as success

    if not contents:
        return None

    # De-duplicate while keeping input order: the IN (...) list stays minimal
    # and the id reported on failure is deterministic (a set would not be).
    channel_content_ids = list(
        dict.fromkeys(content.channel_content_id for content in contents)
    )
    placeholders = ','.join(['%s'] * len(channel_content_ids))

    # Latest decode record per channel_content_id, selected with a window
    # function. Only the placeholder list is interpolated; values are bound.
    decode_sql = f"""
        SELECT channel_content_id, task_id
        FROM (
            SELECT channel_content_id, task_id,
                   ROW_NUMBER() OVER (PARTITION BY channel_content_id ORDER BY created_time DESC) as rn
            FROM workflow_decode_task_result
            WHERE channel_content_id IN ({placeholders})
        ) t
        WHERE rn = 1
    """
    # `or []` treats a None result as "no rows"
    # (NOTE(review): confirm what mysql.fetchall returns for an empty result).
    decode_records = mysql.fetchall(decode_sql, tuple(channel_content_ids)) or []

    # channel_content_id -> task_id, ignoring records without a task_id.
    content_id_to_task_id = {
        record['channel_content_id']: record['task_id']
        for record in decode_records
        if record.get('task_id')
    }

    # Report the first id (in input order) that has no decode record.
    for channel_content_id in channel_content_ids:
        if channel_content_id not in content_id_to_task_id:
            return f"channel_content_id {channel_content_id} 找不到解构结果"

    # Fetch the status of every referenced task in one query.
    task_ids = list(dict.fromkeys(content_id_to_task_id.values()))
    task_placeholders = ','.join(['%s'] * len(task_ids))
    task_sql = f"""
        SELECT task_id, status
        FROM workflow_task
        WHERE task_id IN ({task_placeholders})
    """
    task_records = mysql.fetchall(task_sql, tuple(task_ids)) or []
    task_id_to_status = {
        record['task_id']: record['status'] for record in task_records
    }

    # A missing task row yields status None, which also fails this comparison.
    for channel_content_id in channel_content_ids:
        status = task_id_to_status.get(content_id_to_task_id[channel_content_id])
        if status != STATUS_SUCCESS:
            return f"channel_content_id {channel_content_id} 找不到解构结果"

    return None
def _create_pattern_task(scene: SceneEnum, content_type: ContentTypeEnum) -> Optional[WorkflowTask]:
    """Create the ``workflow_task`` record for a pattern (clustering) job.

    Args:
        scene: Scene forwarded to ``WorkflowTask.create_task``.
        content_type: Content type forwarded to ``WorkflowTask.create_task``.

    Returns:
        The created task, or ``None`` if creating it (or logging its id)
        raised. The exception is logged and not propagated.
    """
    try:
        task = WorkflowTask.create_task(
            scene=scene,
            capability=CapabilityEnum.PATTERN,
            content_type=content_type,
            root_task_id="",
        )
        logger.info(f"创建聚类任务成功,task_id: {task.task_id}")
        return task
    except Exception as e:
        # The caller only sees None, so the log is the sole record of the
        # cause: logger.exception keeps the traceback alongside the message.
        logger.exception(f"创建聚类任务失败: {str(e)}")
        return None
def _save_pattern_contents(task_id: str, pattern_name: str, contents: List[ContentParam]) -> bool:
    """Insert the content items of a pattern task into ``workflow_pattern_task_content``.

    All rows are written with a single multi-row ``INSERT``. ``images`` is
    stored as a JSON string when it is a list, otherwise as an empty string.

    Args:
        task_id: Task id stored on every row.
        pattern_name: Pattern name stored on every row.
        contents: Content items to persist.

    Returns:
        ``True`` when there is nothing to write or the insert succeeded,
        ``False`` when ``mysql.execute`` raised (the error is logged).
    """
    if not contents:
        return True

    row_placeholder = "(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"

    # One list of column values per content item, in column order.
    rows = []
    for item in contents:
        if isinstance(item.images, list):
            serialized_images = json.dumps(item.images or [])
        else:
            serialized_images = ""
        rows.append([
            task_id,
            pattern_name,
            item.channel_content_id,
            serialized_images,
            item.title,
            item.channel_account_id,
            item.channel_account_name,
            item.body_text,
            item.video_url,
            item.weight_score,
        ])

    values_clause = ', '.join(row_placeholder for _ in rows)
    flat_params = tuple(value for row in rows for value in row)

    # Multi-row INSERT; only the placeholder groups are interpolated.
    sql = f"""
        INSERT INTO workflow_pattern_task_content (
            task_id,
            pattern_name,
            channel_content_id,
            images,
            title,
            channel_account_id,
            channel_account_name,
            body_text,
            video_url,
            weight_score
        ) VALUES {values_clause}
    """

    try:
        mysql.execute(sql, flat_params)
    except Exception as e:
        logger.error(f"批量写入聚类内容失败,task_id={task_id}, error={str(e)}")
        return False
    return True
def _trigger_pattern_workflow(task_id: str) -> Dict[str, Any]:
    """POST the task id to the pattern workflow endpoint to start clustering.

    Args:
        task_id: Id of the task to start; it is the only field in the payload.

    Returns:
        ``{"code": ERROR_CODE_SUCCESS, "reason": ""}`` when the endpoint
        answers HTTP 200 with a JSON body whose ``code`` equals 0; otherwise
        ``{"code": ERROR_CODE_FAILED, "reason": <description>}``. No exception
        escapes: every failure is logged and turned into a failure dict.
    """
    def _failed(reason: str) -> Dict[str, Any]:
        return {
            "code": ERROR_CODE_FAILED,
            "reason": reason,
        }

    endpoint = "http://supply-content-deconstruction-workflow.piaoquantv.com/pattern/workflow/topic/pattern"

    try:
        response = requests.post(endpoint, json={"task_id": task_id}, timeout=10)

        if response.status_code != 200:
            logger.error(
                f"发起聚类任务失败,HTTP 状态码异常,status={response.status_code}, task_id={task_id}"
            )
            return _failed(f"错误: {response.status_code}")

        try:
            body = response.json()
        except Exception as parse_error:
            logger.error(f"发起聚类任务失败,返回非JSON,task_id={task_id}, error={str(parse_error)}")
            return _failed("聚类工作流接口返回非JSON格式")

        # A body without "code" is treated as a failure.
        upstream_code = body.get("code", ERROR_CODE_FAILED)
        upstream_msg = body.get("msg", "")

        if upstream_code != 0:
            logger.error(
                f"发起聚类任务失败,上游返回错误,task_id={task_id}, code={upstream_code}, msg={upstream_msg}"
            )
            return _failed(f"工作流接口失败: code={upstream_code}, msg={upstream_msg}")

        return {
            "code": ERROR_CODE_SUCCESS,
            "reason": "",
        }
    except requests.RequestException as request_error:
        logger.error(f"发起聚类任务失败,请求异常,task_id={task_id}, error={str(request_error)}")
        return _failed(f"聚类工作流接口请求异常: {str(request_error)}")
    except Exception as unexpected_error:
        logger.error(f"发起聚类任务失败,task_id={task_id}, error={str(unexpected_error)}")
        return _failed(f"聚类任务执行失败: {str(unexpected_error)}")
def begin_pattern_task(param: PatternContentParam) -> Dict[str, Any]:
    """Validate a pattern (clustering) request, create its task and store its contents.

    Steps: validate required fields, validate the decode status of every
    content item, create the ``workflow_task`` record, then write the content
    items to ``workflow_pattern_task_content``. Triggering the downstream
    workflow is currently disabled (see the commented-out step below).

    Args:
        param: The pattern request.

    Returns:
        A response dict with ``code``, ``task_id`` and ``reason``:
        success carries the new task id; validation or content-save failures
        use ``ERROR_CODE_FAILED``; task-creation failures and any unexpected
        exception use ``ERROR_CODE_TASK_CREATE_FAILED``. No exception escapes.
    """
    try:
        # 1. Validate required fields.
        error_msg = _validate_pattern_param(param)
        if error_msg:
            return _build_error_response(ERROR_CODE_FAILED, error_msg)

        # 1.1 Validate the decode status of every content item.
        error_msg = _validate_decode_status(param.contents)
        if error_msg:
            return _build_error_response(ERROR_CODE_FAILED, error_msg)

        # 2. Create the workflow_task record.
        task = _create_pattern_task(param.scene, param.content_type)
        if not task or not task.task_id:
            return _build_error_response(
                ERROR_CODE_TASK_CREATE_FAILED,
                "创建聚类任务失败",
            )

        # 3. Write the contents to workflow_pattern_task_content.
        if not _save_pattern_contents(task.task_id, param.pattern_name, param.contents):
            return _build_error_response(
                ERROR_CODE_FAILED,
                "写入聚类内容失败",
            )

        # 4. Trigger the actual clustering request (currently disabled).
        # trigger_result = _trigger_pattern_workflow(task.task_id)
        # if trigger_result.get("code") != ERROR_CODE_SUCCESS:
        #     return _build_error_response(
        #         ERROR_CODE_FAILED,
        #         trigger_result.get("reason") or "发起聚类任务失败",
        #     )

        # Every step succeeded.
        return _build_success_response(task.task_id)
    except Exception as e:
        # Top-level boundary: unexpected errors (e.g. DB failures raised during
        # validation) surface only here, so keep the traceback in the log.
        logger.exception(f"聚类任务创建失败: {str(e)}")
        return _build_error_response(
            ERROR_CODE_TASK_CREATE_FAILED,
            f"聚类任务创建失败: {str(e)}",
        )
|