#24 聚类接口耗时优化

Gabung
jihuaqiang menggabungkan 1 komit dari weapp/dev_api_init menjadi weapp/master (3 minggu lalu)
1 mengubah file dengan 83 tambahan dan 44 penghapusan
  1. 83 44
      tasks/pattern.py

+ 83 - 44
tasks/pattern.py

def _validate_decode_status(contents: List[ContentParam]) -> Optional[str]:
    """Validate the decode (解构) status of every channel_content_id.

    Batched replacement for the original one-query-per-content version:
    two SQL round-trips total instead of 2*N.

    Args:
        contents: contents to validate; an empty list passes trivially.

    Returns:
        None when the latest decode task of every content succeeded,
        otherwise an error message for the first offending content
        *in input order* (matching the pre-optimization behavior).
    """
    STATUS_SUCCESS = 2  # workflow_task.status value meaning success

    if not contents:
        return None

    # Deduplicate ids so repeated contents don't inflate the IN clause.
    channel_content_ids = list({content.channel_content_id for content in contents})
    placeholders = ','.join(['%s'] * len(channel_content_ids))

    # One query replaces the per-content lookups: the window function keeps
    # only the newest decode record per channel_content_id.
    decode_sql = f"""
        SELECT channel_content_id, task_id
        FROM (
            SELECT channel_content_id, task_id,
                   ROW_NUMBER() OVER (PARTITION BY channel_content_id ORDER BY created_time DESC) as rn
            FROM workflow_decode_task_result
            WHERE channel_content_id IN ({placeholders})
        ) t
        WHERE rn = 1
    """
    decode_records = mysql.fetchall(decode_sql, tuple(channel_content_ids))

    # channel_content_id -> latest task_id. Rows with an empty/NULL task_id
    # are dropped here and treated as "no decode result" by the loop below,
    # mirroring the original `if not task_id` check.
    content_id_to_task_id = {
        record['channel_content_id']: record['task_id']
        for record in decode_records
        if record.get('task_id')
    }

    # Batch-fetch the status of every referenced task in a single query.
    # Guard against an empty id list: `IN ()` is invalid SQL.
    task_id_to_status = {}
    task_ids = list(set(content_id_to_task_id.values()))
    if task_ids:
        task_placeholders = ','.join(['%s'] * len(task_ids))
        task_sql = f"""
            SELECT task_id, status
            FROM workflow_task
            WHERE task_id IN ({task_placeholders})
        """
        task_id_to_status = {
            record['task_id']: record['status']
            for record in mysql.fetchall(task_sql, tuple(task_ids))
        }

    # Validate in input order so the reported content is deterministic and
    # identical to what the original sequential implementation returned
    # (a separate `missing_ids` set pre-check would report an arbitrary id).
    for content in contents:
        channel_content_id = content.channel_content_id
        task_id = content_id_to_task_id.get(channel_content_id)

        if not task_id:
            return f"channel_content_id {channel_content_id} 找不到解构结果"

        status = task_id_to_status.get(task_id)
        if status is None:
            return f"channel_content_id {channel_content_id} 找不到解构结果"

        if status != STATUS_SUCCESS:
            return f"channel_content_id {channel_content_id} 找不到解构结果"

    return None
@@ -106,7 +136,31 @@ def _create_pattern_task(scene: SceneEnum, content_type: ContentTypeEnum) -> Opt
 
 def _save_pattern_contents(task_id: str, pattern_name: str, contents: List[ContentParam]) -> bool:
     """将聚类内容写入 workflow_pattern_task_content 表"""
-    sql = """
+    if not contents:
+        return True
+    
+    # 准备所有数据
+    values_list = []
+    params_list = []
+    
+    for content in contents:
+        images_str = json.dumps(content.images or []) if isinstance(content.images, list) else ""
+        values_list.append("(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)")
+        params_list.extend([
+            task_id,
+            pattern_name,
+            content.channel_content_id,
+            images_str,
+            content.title,
+            content.channel_account_id,
+            content.channel_account_name,
+            content.body_text,
+            content.video_url,
+            content.weight_score,
+        ])
+    
+    # 构建批量插入 SQL
+    sql = f"""
         INSERT INTO workflow_pattern_task_content (
             task_id,
             pattern_name,
@@ -118,30 +172,15 @@ def _save_pattern_contents(task_id: str, pattern_name: str, contents: List[Conte
             body_text,
             video_url,
             weight_score
-        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
+        ) VALUES {', '.join(values_list)}
     """
-
-    for content in contents:
-        try:
-            images_str = json.dumps(content.images or []) if isinstance(content.images, list) else ""
-            params = (
-                task_id,
-                pattern_name,
-                content.channel_content_id,
-                images_str,
-                content.title,
-                content.channel_account_id,
-                content.channel_account_name,
-                content.body_text,
-                content.video_url,
-                content.weight_score,
-            )
-            mysql.execute(sql, params)
-        except Exception as e:
-            logger.error(f"写入聚类内容失败,task_id={task_id}, content_id={content.channel_content_id}, error={str(e)}")
-            return False
-
-    return True
+    
+    try:
+        mysql.execute(sql, tuple(params_list))
+        return True
+    except Exception as e:
+        logger.error(f"批量写入聚类内容失败,task_id={task_id}, error={str(e)}")
+        return False
 
 
 def _trigger_pattern_workflow(task_id: str) -> Dict[str, Any]: