Просмотр исходного кода

feat(pattern): implement clustering task creation and result fetching

- Added functions to create and manage clustering tasks, including parameter validation and content saving.
- Introduced a new method to fetch clustering task results.
- Enhanced error handling and logging for better debugging and user feedback.
- Updated the PatternContentParam model to include required fields for clustering tasks.
jihuaqiang 1 месяц назад
Родитель
Коммит
a35ab2b1fd
3 измененных файлов с 273 добавлено и 29 удалено
  1. 54 22
      tasks/detail.py
  2. 217 7
      tasks/pattern.py
  3. 2 0
      utils/params.py

+ 54 - 22
tasks/detail.py

@@ -65,6 +65,21 @@ def _fetch_decode_result(task_id: str) -> Optional[Dict[str, Any]]:
     }
     }
 
 
 
 
+def _fetch_pattern_result(task_id: str) -> Optional[Dict[str, Any]]:
+    """Load the stored result row of a clustering (pattern) task.
+
+    Queries workflow_pattern_task_result by task_id. Returns None when no
+    row is found; otherwise a dict with the keys "result" (parsed
+    result_payload), "error_message" and "url" (parsed web_url).
+    """
+    row = mysql.fetchone(
+        "SELECT result_payload, error_message, web_url FROM workflow_pattern_task_result WHERE task_id = %s",
+        (task_id,),
+    )
+    if not row:
+        return None
+
+    parsed_payload = _parse_result_payload(row.get("result_payload"))
+    failure_reason = row.get("error_message")
+    parsed_url = _parse_web_url(row.get("web_url"))
+    return {
+        "result": parsed_payload,
+        "error_message": failure_reason,
+        "url": parsed_url,
+    }
+
+
 def _build_result_data(
 def _build_result_data(
     task_id: str,
     task_id: str,
     status: int,
     status: int,
@@ -80,7 +95,7 @@ def _build_result_data(
         "reason": reason
         "reason": reason
     }
     }
 
 
-    # 对于解构任务,增加 url 字段(data.url.pointUrl / data.url.weightUrl)
+    # 对于有可视化页面的任务,增加 url 字段
     if url is not None:
     if url is not None:
         data["url"] = url
         data["url"] = url
 
 
@@ -89,26 +104,38 @@ def _build_result_data(
 
 
 def _handle_success_status(task_id: str, capability: int) -> Dict[str, Any]:
 def _handle_success_status(task_id: str, capability: int) -> Dict[str, Any]:
     """处理成功状态(status=2)"""
     """处理成功状态(status=2)"""
-    # 只有解构任务需要查询结果表
-    if capability != CapabilityEnum.DECODE.value:
-        return _build_response(_build_result_data(task_id, STATUS_SUCCESS))
-    
-    # 查询解构结果
-    decode_result = _fetch_decode_result(task_id)
-    if not decode_result:
-        return _build_response(_build_result_data(task_id, STATUS_SUCCESS))
-    
-    result_data = _build_result_data(
-        task_id=task_id,
-        status=STATUS_SUCCESS,
-        result=decode_result.get("result"),
-        reason=decode_result.get("error_message"),
-        url=decode_result.get("url")
-    )
-    
-    return _build_response(
-        result_data
-    )
+    # 解构任务
+    if capability == CapabilityEnum.DECODE.value:
+        decode_result = _fetch_decode_result(task_id)
+        if not decode_result:
+            return _build_response(_build_result_data(task_id, STATUS_SUCCESS))
+
+        result_data = _build_result_data(
+            task_id=task_id,
+            status=STATUS_SUCCESS,
+            result=decode_result.get("result"),
+            reason=decode_result.get("error_message"),
+            url=decode_result.get("url")
+        )
+        return _build_response(result_data)
+
+    # 聚类任务
+    if capability == CapabilityEnum.PATTERN.value:
+        pattern_result = _fetch_pattern_result(task_id)
+        if not pattern_result:
+            return _build_response(_build_result_data(task_id, STATUS_SUCCESS))
+
+        result_data = _build_result_data(
+            task_id=task_id,
+            status=STATUS_SUCCESS,
+            result=pattern_result.get("result"),
+            reason=pattern_result.get("error_message"),
+            url=pattern_result.get("url")
+        )
+        return _build_response(result_data)
+
+    # 其他能力:只返回基础任务信息
+    return _build_response(_build_result_data(task_id, STATUS_SUCCESS))
 
 
 
 
 def get_decode_detail_by_task_id(task_id: str) -> Optional[Dict[str, Any]]:
 def get_decode_detail_by_task_id(task_id: str) -> Optional[Dict[str, Any]]:
@@ -137,11 +164,16 @@ def get_decode_detail_by_task_id(task_id: str) -> Optional[Dict[str, Any]]:
     # 失败状态,需要返回 error_message 到 reason 字段,result 固定为 "[]"
     # 失败状态,需要返回 error_message 到 reason 字段,result 固定为 "[]"
     if status == STATUS_FAILED:
     if status == STATUS_FAILED:
         error_message: Optional[str] = None
         error_message: Optional[str] = None
-        # 仅对解构任务从结果表中查询失败原因
+        # 解构任务失败原因
         if capability == CapabilityEnum.DECODE.value:
         if capability == CapabilityEnum.DECODE.value:
             decode_result = _fetch_decode_result(task_id_value)
             decode_result = _fetch_decode_result(task_id_value)
             if decode_result:
             if decode_result:
                 error_message = decode_result.get("error_message")
                 error_message = decode_result.get("error_message")
+        # 聚类任务失败原因
+        elif capability == CapabilityEnum.PATTERN.value:
+            pattern_result = _fetch_pattern_result(task_id_value)
+            if pattern_result:
+                error_message = pattern_result.get("error_message")
         
         
         result_data = _build_result_data(
         result_data = _build_result_data(
             task_id=task_id_value,
             task_id=task_id_value,

+ 217 - 7
tasks/pattern.py

@@ -1,15 +1,225 @@
-from typing import Dict, Any
+from typing import Dict, Any, Optional, List
 
 
-from utils.params import PatternContentParam
+from loguru import logger
+import sys
+import json
+import requests
 
 
+from utils.params import PatternContentParam, SceneEnum, ContentTypeEnum, CapabilityEnum, ContentParam
+from models.task import WorkflowTask
+from utils.sync_mysql_help import mysql
 
 
+
+logger.add(sink=sys.stderr, level="ERROR", backtrace=True, diagnose=True)
+
+ERROR_CODE_SUCCESS = 0
 ERROR_CODE_FAILED = -1
 ERROR_CODE_FAILED = -1
+ERROR_CODE_TASK_CREATE_FAILED = 2001
 
 
 
 
-def begin_pattern_task(param: PatternContentParam) -> Dict[str, Any]:
-    """聚类任务暂未实现的占位实现"""
+def _build_error_response(code: int, reason: str) -> Dict[str, Any]:
     return {
     return {
-        "code": ERROR_CODE_FAILED,
+        "code": code,
         "task_id": None,
         "task_id": None,
-        "reason": "聚类任务暂未实现"
-    }
+        "reason": reason,
+    }
+
+
+def _build_success_response(task_id: str) -> Dict[str, Any]:
+    """Build the success payload: success code, the task id, empty reason."""
+    response: Dict[str, Any] = dict(
+        code=ERROR_CODE_SUCCESS,
+        task_id=task_id,
+        reason="",
+    )
+    return response
+
+
+def _validate_pattern_param(param: PatternContentParam) -> Optional[str]:
+    """Check the required fields of a clustering request.
+
+    Returns the message for the first violation found, or None when
+    pattern_name, contents, and every item's channel_content_id and
+    weight_score are present.
+    """
+    if not param.pattern_name:
+        return "pattern_name 不能为空"
+    if not param.contents:
+        return "contents 不能为空"
+
+    def _item_error(position: int, item: Any) -> Optional[str]:
+        # The id is checked before the score, so the id message wins when both are missing.
+        if not item.channel_content_id:
+            return f"contents[{position}].channel_content_id 不能为空"
+        if item.weight_score is None:
+            return f"contents[{position}].weight_score 不能为空"
+        return None
+
+    # Lazy scan: stops at the first item that yields a message.
+    per_item = (_item_error(pos, entry) for pos, entry in enumerate(param.contents))
+    return next((message for message in per_item if message is not None), None)
+
+
+def _create_pattern_task(scene: SceneEnum, content_type: ContentTypeEnum) -> Optional[WorkflowTask]:
+    """Create the workflow task row for a clustering job.
+
+    Calls WorkflowTask.create_task with the PATTERN capability and an empty
+    root_task_id. Returns the created task, or None when creation (or the
+    success log line that reads task.task_id) raises; the error is logged.
+    """
+    create_kwargs = {
+        "scene": scene,
+        "capability": CapabilityEnum.PATTERN,
+        "content_type": content_type,
+        "root_task_id": "",
+    }
+    try:
+        created = WorkflowTask.create_task(**create_kwargs)
+        logger.info(f"创建聚类任务成功,task_id: {created.task_id}")
+    except Exception as exc:
+        logger.error(f"创建聚类任务失败: {str(exc)}")
+        return None
+    return created
+
+
+def _save_pattern_contents(task_id: str, contents: List[ContentParam]) -> bool:
+    """Insert every content item into workflow_pattern_task_content.
+
+    Items are written one statement at a time. On the first failure the
+    error is logged and False is returned; items written before it are not
+    undone by this function. Returns True when all inserts succeed.
+    """
+    insert_sql = """
+        INSERT INTO workflow_pattern_task_content (
+            task_id,
+            channel_content_id,
+            images,
+            title,
+            channel_account_id,
+            channel_account_name,
+            body_text,
+            video_url,
+            weight_score
+        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
+    """
+
+    for item in contents:
+        try:
+            # images is stored as a JSON array; anything that is not a list becomes "".
+            if isinstance(item.images, list):
+                images_json = json.dumps(item.images or [])
+            else:
+                images_json = ""
+            row_values = (
+                task_id,
+                item.channel_content_id,
+                images_json,
+                item.title,
+                item.channel_account_id,
+                item.channel_account_name,
+                item.body_text,
+                item.video_url,
+                item.weight_score,
+            )
+            mysql.execute(insert_sql, row_values)
+        except Exception as exc:
+            logger.error(f"写入聚类内容失败,task_id={task_id}, content_id={item.channel_content_id}, error={str(exc)}")
+            return False
+
+    return True
+
+
+def _trigger_pattern_workflow(task_id: str, pattern_name: str, contents: List[ContentParam]) -> Dict[str, Any]:
+    """Submit the clustering request to the pattern workflow service.
+
+    Posts the task id, the pattern name, a channel_content_id -> weight_score
+    map and the id/title/body text of every content item as JSON, then maps
+    the outcome to a small status dict.
+
+    Args:
+        task_id: Identifier of the workflow task the request belongs to.
+        pattern_name: Clustering name, forwarded unchanged.
+        contents: Content items. Items whose weight_score is None are left
+            out of weight_score_map but are still sent in "contents".
+
+    Returns:
+        {"code": ERROR_CODE_SUCCESS, "reason": ""} when the service answers
+        HTTP 200 with a JSON body whose "code" is 0; otherwise
+        {"code": ERROR_CODE_FAILED, "reason": <description>}. Exceptions
+        raised while building, sending or decoding the request are logged
+        and converted into a failure dict.
+    """
+    try:
+        # NOTE(review): endpoint is hard-coded to a local service; confirm
+        # whether it should come from configuration.
+        url = "http://localhost:8100/pattern/workflow/topic/pattern"
+
+        weight_score_map = {
+            c.channel_content_id: c.weight_score for c in contents if c.weight_score is not None
+        }
+        request_contents = [
+            {
+                "channel_content_id": c.channel_content_id,
+                "title": c.title or "",
+                "body_text": c.body_text or "",
+            }
+            for c in contents
+        ]
+
+        payload = {
+            "task_id": task_id,
+            "pattern_name": pattern_name,
+            "weight_score_map": weight_score_map,
+            "contents": request_contents,
+        }
+
+        resp = requests.post(url, json=payload, timeout=10)
+
+        # An HTTP status code is never 0, so comparing against 0 rejected
+        # every response; success at the transport level is HTTP 200.
+        if resp.status_code != 200:
+            logger.error(
+                f"发起聚类任务失败,HTTP 状态码异常,status={resp.status_code}, task_id={task_id}"
+            )
+            return {
+                "code": ERROR_CODE_FAILED,
+                "reason": f"错误: {resp.status_code}",
+            }
+
+        try:
+            data = resp.json()
+        except Exception as e:
+            logger.error(f"发起聚类任务失败,返回非JSON,task_id={task_id}, error={str(e)}")
+            return {
+                "code": ERROR_CODE_FAILED,
+                "reason": "聚类工作流接口返回非JSON格式",
+            }
+
+        # A body without "code" is treated as a failure.
+        code = data.get("code", ERROR_CODE_FAILED)
+        msg = data.get("msg", "")
+
+        if code == 0:
+            return {
+                "code": ERROR_CODE_SUCCESS,
+                "reason": "",
+            }
+
+        logger.error(
+            f"发起聚类任务失败,上游返回错误,task_id={task_id}, code={code}, msg={msg}"
+        )
+        return {
+            "code": ERROR_CODE_FAILED,
+            "reason": f"工作流接口失败: code={code}, msg={msg}",
+        }
+
+    except requests.RequestException as e:
+        # Connection errors, timeouts and other transport-level failures.
+        logger.error(f"发起聚类任务失败,请求异常,task_id={task_id}, error={str(e)}")
+        return {
+            "code": ERROR_CODE_FAILED,
+            "reason": f"聚类工作流接口请求异常: {str(e)}",
+        }
+    except Exception as e:
+        # Anything else, e.g. a JSON body that is not an object.
+        logger.error(f"发起聚类任务失败,task_id={task_id}, error={str(e)}")
+        return {
+            "code": ERROR_CODE_FAILED,
+            "reason": f"聚类任务执行失败: {str(e)}",
+        }
+
+
+def begin_pattern_task(param: PatternContentParam) -> Dict[str, Any]:
+    """Create a clustering task and start its workflow.
+
+    Steps, each aborting with an error response on failure: validate the
+    request, create the workflow task, store the content items, trigger the
+    clustering workflow. Returns the success response carrying the new
+    task id when every step passes. Any unexpected exception is logged and
+    reported with ERROR_CODE_TASK_CREATE_FAILED.
+    """
+    try:
+        # Step 1: required-field validation.
+        validation_error = _validate_pattern_param(param)
+        if validation_error:
+            return _build_error_response(ERROR_CODE_FAILED, validation_error)
+
+        # Step 2: create the workflow task.
+        created = _create_pattern_task(param.scene, param.content_type)
+        if not (created and created.task_id):
+            return _build_error_response(ERROR_CODE_TASK_CREATE_FAILED, "创建聚类任务失败")
+
+        # Step 3: store the content items for the task.
+        stored = _save_pattern_contents(created.task_id, param.contents)
+        if not stored:
+            return _build_error_response(ERROR_CODE_FAILED, "写入聚类内容失败")
+
+        # Step 4: send the actual clustering request.
+        outcome = _trigger_pattern_workflow(
+            created.task_id,
+            param.pattern_name,
+            param.contents,
+        )
+        if outcome.get("code") == ERROR_CODE_SUCCESS:
+            return _build_success_response(created.task_id)
+
+        failure_reason = outcome.get("reason") or "发起聚类任务失败"
+        return _build_error_response(ERROR_CODE_FAILED, failure_reason)
+
+    except Exception as exc:
+        logger.error(f"聚类任务创建失败: {str(exc)}")
+        return _build_error_response(
+            ERROR_CODE_TASK_CREATE_FAILED,
+            f"聚类任务创建失败: {str(exc)}",
+        )

+ 2 - 0
utils/params.py

@@ -28,6 +28,7 @@ class ContentParam(BaseModel):
     title: Optional[str] = None
     title: Optional[str] = None
     channel_account_id: Optional[str] = None
     channel_account_id: Optional[str] = None
     channel_account_name: Optional[str] = None
     channel_account_name: Optional[str] = None
+    weight_score: Optional[float] = None  # 表现力分数,聚类必传
 
 
 
 
 class DecodeContentParam(BaseModel):
 class DecodeContentParam(BaseModel):
@@ -38,4 +39,5 @@ class DecodeContentParam(BaseModel):
 class PatternContentParam(BaseModel):
 class PatternContentParam(BaseModel):
     scene: SceneEnum  # 业务场景:0选题 1创作 2制作
     scene: SceneEnum  # 业务场景:0选题 1创作 2制作
     content_type: ContentTypeEnum  # 1 文本 2图片 3 视频
     content_type: ContentTypeEnum  # 1 文本 2图片 3 视频
+    pattern_name: str  # 聚类名称
     contents: List[ContentParam]
     contents: List[ContentParam]