pattern.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. from typing import Dict, Any, Optional, List
  2. from loguru import logger
  3. import sys
  4. import json
  5. import requests
  6. from utils.params import PatternContentParam, SceneEnum, ContentTypeEnum, CapabilityEnum, ContentParam
  7. from models.task import WorkflowTask
  8. from utils.sync_mysql_help import mysql
  9. logger.add(sink=sys.stderr, level="ERROR", backtrace=True, diagnose=True)
  10. ERROR_CODE_SUCCESS = 0
  11. ERROR_CODE_FAILED = -1
  12. ERROR_CODE_TASK_CREATE_FAILED = 2001
  13. def _build_error_response(code: int, reason: str) -> Dict[str, Any]:
  14. return {
  15. "code": code,
  16. "task_id": None,
  17. "reason": reason,
  18. }
  19. def _build_success_response(task_id: str) -> Dict[str, Any]:
  20. return {
  21. "code": ERROR_CODE_SUCCESS,
  22. "task_id": task_id,
  23. "reason": "",
  24. }
  25. def _validate_pattern_param(param: PatternContentParam) -> Optional[str]:
  26. """校验聚类入参的必填项"""
  27. if not param.pattern_name:
  28. return "pattern_name 不能为空"
  29. if not param.contents:
  30. return "contents 不能为空"
  31. for idx, content in enumerate(param.contents):
  32. if not content.channel_content_id:
  33. return f"contents[{idx}].channel_content_id 不能为空"
  34. if content.weight_score is None:
  35. return f"contents[{idx}].weight_score 不能为空"
  36. return None
  37. def _validate_decode_status(contents: List[ContentParam]) -> Optional[str]:
  38. """校验每个channel_content_id的解构状态"""
  39. STATUS_SUCCESS = 2 # 成功状态
  40. if not contents:
  41. return None
  42. # 收集所有的channel_content_id
  43. channel_content_ids = [content.channel_content_id for content in contents]
  44. placeholders = ','.join(['%s'] * len(channel_content_ids))
  45. # 批量查询所有channel_content_id对应的最新task_id
  46. # 使用窗口函数获取每个channel_content_id的最新记录
  47. decode_sql = f"""
  48. SELECT channel_content_id, task_id
  49. FROM (
  50. SELECT channel_content_id, task_id,
  51. ROW_NUMBER() OVER (PARTITION BY channel_content_id ORDER BY created_time DESC) as rn
  52. FROM workflow_decode_task_result
  53. WHERE channel_content_id IN ({placeholders})
  54. ) t
  55. WHERE rn = 1
  56. """
  57. decode_params = tuple(channel_content_ids)
  58. decode_records = mysql.fetchall(decode_sql, decode_params)
  59. # 构建channel_content_id到task_id的映射
  60. content_id_to_task_id = {record['channel_content_id']: record['task_id']
  61. for record in decode_records if record.get('task_id')}
  62. # 检查是否有缺失的channel_content_id
  63. missing_ids = set(channel_content_ids) - set(content_id_to_task_id.keys())
  64. if missing_ids:
  65. missing_id = list(missing_ids)[0]
  66. return f"channel_content_id {missing_id} 找不到解构结果"
  67. # 批量查询所有task_id对应的状态
  68. task_ids = list(content_id_to_task_id.values())
  69. task_placeholders = ','.join(['%s'] * len(task_ids))
  70. task_sql = f"""
  71. SELECT task_id, status
  72. FROM workflow_task
  73. WHERE task_id IN ({task_placeholders})
  74. """
  75. task_records = mysql.fetchall(task_sql, tuple(task_ids))
  76. # 构建task_id到status的映射
  77. task_id_to_status = {record['task_id']: record['status']
  78. for record in task_records}
  79. # 验证每个channel_content_id的状态
  80. for content in contents:
  81. channel_content_id = content.channel_content_id
  82. task_id = content_id_to_task_id.get(channel_content_id)
  83. if not task_id:
  84. return f"channel_content_id {channel_content_id} 找不到解构结果"
  85. status = task_id_to_status.get(task_id)
  86. if status is None:
  87. return f"channel_content_id {channel_content_id} 找不到解构结果"
  88. if status != STATUS_SUCCESS:
  89. return f"channel_content_id {channel_content_id} 找不到解构结果"
  90. return None
  91. def _create_pattern_task(scene: SceneEnum, content_type: ContentTypeEnum) -> Optional[WorkflowTask]:
  92. """创建聚类 workflow_task 任务"""
  93. try:
  94. task = WorkflowTask.create_task(
  95. scene=scene,
  96. capability=CapabilityEnum.PATTERN,
  97. content_type=content_type,
  98. root_task_id="",
  99. )
  100. logger.info(f"创建聚类任务成功,task_id: {task.task_id}")
  101. return task
  102. except Exception as e:
  103. logger.error(f"创建聚类任务失败: {str(e)}")
  104. return None
  105. def _save_pattern_contents(task_id: str, pattern_name: str, contents: List[ContentParam]) -> bool:
  106. """将聚类内容写入 workflow_pattern_task_content 表"""
  107. if not contents:
  108. return True
  109. # 准备所有数据
  110. values_list = []
  111. params_list = []
  112. for content in contents:
  113. images_str = json.dumps(content.images or []) if isinstance(content.images, list) else ""
  114. values_list.append("(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)")
  115. params_list.extend([
  116. task_id,
  117. pattern_name,
  118. content.channel_content_id,
  119. images_str,
  120. content.title,
  121. content.channel_account_id,
  122. content.channel_account_name,
  123. content.body_text,
  124. content.video_url,
  125. content.weight_score,
  126. ])
  127. # 构建批量插入 SQL
  128. sql = f"""
  129. INSERT INTO workflow_pattern_task_content (
  130. task_id,
  131. pattern_name,
  132. channel_content_id,
  133. images,
  134. title,
  135. channel_account_id,
  136. channel_account_name,
  137. body_text,
  138. video_url,
  139. weight_score
  140. ) VALUES {', '.join(values_list)}
  141. """
  142. try:
  143. mysql.execute(sql, tuple(params_list))
  144. return True
  145. except Exception as e:
  146. logger.error(f"批量写入聚类内容失败,task_id={task_id}, error={str(e)}")
  147. return False
  148. def _trigger_pattern_workflow(task_id: str) -> Dict[str, Any]:
  149. """发起真正的聚类请求,只携带 task_id"""
  150. try:
  151. url = "http://supply-content-deconstruction-workflow.piaoquantv.com/pattern/workflow/topic/pattern"
  152. payload = {
  153. "task_id": task_id
  154. }
  155. resp = requests.post(url, json=payload, timeout=10)
  156. if resp.status_code != 200:
  157. logger.error(
  158. f"发起聚类任务失败,HTTP 状态码异常,status={resp.status_code}, task_id={task_id}"
  159. )
  160. return {
  161. "code": ERROR_CODE_FAILED,
  162. "reason": f"错误: {resp.status_code}",
  163. }
  164. try:
  165. data = resp.json()
  166. except Exception as e:
  167. logger.error(f"发起聚类任务失败,返回非JSON,task_id={task_id}, error={str(e)}")
  168. return {
  169. "code": ERROR_CODE_FAILED,
  170. "reason": "聚类工作流接口返回非JSON格式",
  171. }
  172. code = data.get("code", ERROR_CODE_FAILED)
  173. msg = data.get("msg", "")
  174. if code == 0:
  175. return {
  176. "code": ERROR_CODE_SUCCESS,
  177. "reason": "",
  178. }
  179. logger.error(
  180. f"发起聚类任务失败,上游返回错误,task_id={task_id}, code={code}, msg={msg}"
  181. )
  182. return {
  183. "code": ERROR_CODE_FAILED,
  184. "reason": f"工作流接口失败: code={code}, msg={msg}",
  185. }
  186. except requests.RequestException as e:
  187. logger.error(f"发起聚类任务失败,请求异常,task_id={task_id}, error={str(e)}")
  188. return {
  189. "code": ERROR_CODE_FAILED,
  190. "reason": f"聚类工作流接口请求异常: {str(e)}",
  191. }
  192. except Exception as e:
  193. logger.error(f"发起聚类任务失败,task_id={task_id}, error={str(e)}")
  194. return {
  195. "code": ERROR_CODE_FAILED,
  196. "reason": f"聚类任务执行失败: {str(e)}",
  197. }
  198. def begin_pattern_task(param: PatternContentParam) -> Dict[str, Any]:
  199. """创建聚类任务"""
  200. try:
  201. # 1. 校验必填项
  202. error_msg = _validate_pattern_param(param)
  203. if error_msg:
  204. return _build_error_response(ERROR_CODE_FAILED, error_msg)
  205. # 1.1 校验解构状态
  206. error_msg = _validate_decode_status(param.contents)
  207. if error_msg:
  208. return _build_error_response(ERROR_CODE_FAILED, error_msg)
  209. # 2. 创建 workflow_task 任务
  210. task = _create_pattern_task(param.scene, param.content_type)
  211. if not task or not task.task_id:
  212. return _build_error_response(
  213. ERROR_CODE_TASK_CREATE_FAILED,
  214. "创建聚类任务失败",
  215. )
  216. # 3. 将内容写入 workflow_pattern_task_content 表
  217. if not _save_pattern_contents(task.task_id, param.pattern_name, param.contents):
  218. return _build_error_response(
  219. ERROR_CODE_FAILED,
  220. "写入聚类内容失败",
  221. )
  222. # 4. 发起真正的聚类请求
  223. # trigger_result = _trigger_pattern_workflow(task.task_id)
  224. # if trigger_result.get("code") != ERROR_CODE_SUCCESS:
  225. # return _build_error_response(
  226. # ERROR_CODE_FAILED,
  227. # trigger_result.get("reason") or "发起聚类任务失败",
  228. # )
  229. # 全部成功
  230. return _build_success_response(task.task_id)
  231. except Exception as e:
  232. logger.error(f"聚类任务创建失败: {str(e)}")
  233. return _build_error_response(
  234. ERROR_CODE_TASK_CREATE_FAILED,
  235. f"聚类任务创建失败: {str(e)}",
  236. )