pai_flow_operator.py

# -*- coding: utf-8 -*-
import functools
import os
import re
import sys
import time
import json
import pandas as pd
from alibabacloud_paistudio20210202.client import Client as PaiStudio20210202Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_paistudio20210202 import models as pai_studio_20210202_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient
from alibabacloud_eas20210701.client import Client as eas20210701Client
from alibabacloud_paiflow20210202 import models as paiflow_20210202_models
from alibabacloud_paiflow20210202.client import Client as PAIFlow20210202Client
from datetime import datetime, timedelta
from odps import ODPS
from ad_monitor_util import _monitor
import alibabacloud_oss_v2 as oss

target_names = {
    '样本shuffle',
    '模型训练-样本shufle',
    '模型训练-自定义',
    '模型增量训练',
    '模型导出-2',
    '更新EAS服务(Beta)-1',
    '虚拟起始节点',
    '二分类评估-1',
    '二分类评估-2',
    '预测结果对比'
}
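
# NOTE: these are the display names of the workflow nodes in the PAI experiment
# draft that this script drives. They are used verbatim as lookup keys against
# the draft's node list and against job RunInfo, so they must match the node
# names in the experiment exactly (including the '模型训练-样本shufle' spelling
# as it appears in the workflow).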

EXPERIMENT_ID = "draft-wqgkag89sbh9v1zvut"
# AccessKey credentials are taken from the environment; see the AccessKey/STS
# notes in the PAIClient factory methods below.
ACCESS_KEY_ID = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_ID")
ACCESS_KEY_SECRET = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_SECRET")
MAX_RETRIES = 3
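

# retry() wraps a pipeline step so that it is re-run whenever it raises an
# exception or returns False, up to MAX_RETRIES attempts. Any other return
# value is passed through immediately; False is returned once all attempts
# are exhausted.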
def retry(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        retries = 0
        while retries < MAX_RETRIES:
            try:
                result = func(*args, **kwargs)
                if result is not False:
                    return result
            except Exception as e:
                print(f"函数 {func.__name__} 执行时发生异常: {e},重试第 {retries + 1} 次")
            retries += 1
        print(f"函数 {func.__name__} 重试 {MAX_RETRIES} 次后仍失败。")
        return False
    return wrapper


def get_odps_instance(project):
    odps = ODPS(
        access_id=ACCESS_KEY_ID,
        secret_access_key=ACCESS_KEY_SECRET,
        project=project,
        endpoint='http://service.cn.maxcompute.aliyun.com/api',
    )
    return odps


def get_data_from_odps(project, table, num):
    odps = get_odps_instance(project)
    try:
        # SQL statement to run
        sql = f'select * from {table} limit {num}'
        # Execute the SQL query
        with odps.execute_sql(sql).open_reader() as reader:
            df = reader.to_pandas()
            # Return None when fewer rows than requested come back
            if len(df) < num:
                return None
            return df
    except Exception as e:
        print(f"发生错误: {e}")
        return None


def get_dict_from_odps(project, table):
    odps = get_odps_instance(project)
    try:
        # SQL statement to run
        sql = f'select * from {table}'
        # Execute the SQL query
        with odps.execute_sql(sql).open_reader() as reader:
            data = {}
            for record in reader:
                # Each record iterates as (column_name, value) pairs:
                # column 0 holds the key, column 1 holds the value.
                record_list = list(record)
                key = record_list[0][1]
                value = record_list[1][1]
                data[key] = value
            return data
    except Exception as e:
        print(f"发生错误: {e}")
        return None


def get_dates_between(start_date_str, end_date_str):
    start_date = datetime.strptime(start_date_str, '%Y%m%d')
    end_date = datetime.strptime(end_date_str, '%Y%m%d')
    dates = []
    current_date = start_date
    while current_date <= end_date:
        dates.append(current_date.strftime('%Y%m%d'))
        current_date += timedelta(days=1)
    return dates


def read_file_to_list():
    try:
        current_dir = os.getcwd()
        file_path = os.path.join(current_dir, 'ad', 'holidays.txt')
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        return content.split('\n')
    except FileNotFoundError:
        raise Exception(f"错误:未找到 {file_path} 文件。")
    except Exception as e:
        raise Exception(f"错误:发生了一个未知错误: {e}")
    return []


def get_previous_days_date(days):
    current_date = datetime.now()
    previous_date = current_date - timedelta(days=days)
    return previous_date.strftime('%Y%m%d')


def remove_elements(lst1, lst2):
    return [element for element in lst1 if element not in lst2]


def process_list(lst, append_str):
    # Prepend the same string to every element of the list
    appended_list = [append_str + element for element in lst]
    # Join the prefixed elements into a single comma-separated string
    result_str = ','.join(appended_list)
    return result_str


def get_train_data_list(date_begin):
    end_date = get_previous_days_date(2)
    date_list = get_dates_between(date_begin, end_date)
    filter_date_list = read_file_to_list()
    date_list = remove_elements(date_list, filter_date_list)
    return date_list


# Replace only the dates inside the first matched 'where dt in (...)' clause
def update_data_date_range(old_str, date_begin='20250605'):
    date_list = get_train_data_list(date_begin)
    train_list = ["'" + item + "'" for item in date_list]
    result = ','.join(train_list)
    start_index = old_str.find('where dt in (')
    if start_index != -1:
        equal_sign_index = start_index + len('where dt in (')
        # Find the closing parenthesis
        next_quote_index = old_str.find(')', equal_sign_index)
        if next_quote_index != -1:
            # Perform the replacement
            new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
            return new_value
    return None


def compare_timestamp_with_today_start(time_str):
    # Parse the time string into a datetime object
    time_obj = datetime.fromisoformat(time_str)
    # Convert it to a timestamp
    target_timestamp = time_obj.timestamp()
    # Get the start of today
    today_start = datetime.combine(datetime.now().date(), datetime.min.time())
    # Convert the start of today to a timestamp
    today_start_timestamp = today_start.timestamp()
    return target_timestamp > today_start_timestamp


def update_train_table(old_str, table):
    address = 'odps://pai_algo/tables/'
    train_table = address + table
    start_index = old_str.find('-Dtrain_tables="')
    if start_index != -1:
        # Position just past the opening quote of -Dtrain_tables="
        equal_sign_index = start_index + len('-Dtrain_tables="')
        # Find the next double quote
        next_quote_index = old_str.find('"', equal_sign_index)
        if next_quote_index != -1:
            # Perform the replacement
            new_value = old_str[:equal_sign_index] + train_table + old_str[next_quote_index:]
            return new_value
    return None
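
# Illustration of the two string-rewriting helpers above (hypothetical inputs):
#   update_train_table('... -Dtrain_tables="odps://pai_algo/tables/old_tbl" ...', 'new_tbl')
#       -> '... -Dtrain_tables="odps://pai_algo/tables/new_tbl" ...'
#   update_data_date_range("select * from t where dt in ('20250601')", '20250605')
#       rewrites the first "where dt in (...)" list with the recomputed training
#       dates (date_begin up to two days ago, holidays removed).
# Both helpers return None when the expected marker is not found in the input.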


class PAIClient:
    def __init__(self):
        pass

    @staticmethod
    def create_client() -> PaiStudio20210202Client:
        """
        Initialize the PAI Studio client with an AccessKey ID & Secret.
        @return: Client
        @throws Exception
        """
        # Leaking project code may leak the AccessKey and endanger every resource under the
        # account. This sample is for reference only; the more secure STS approach is
        # recommended. For other authentication options see:
        # https://help.aliyun.com/document_detail/378659.html
        config = open_api_models.Config(
            access_key_id=ACCESS_KEY_ID,
            access_key_secret=ACCESS_KEY_SECRET
        )
        # For the endpoint, see https://api.aliyun.com/product/PaiStudio
        config.endpoint = 'pai.cn-hangzhou.aliyuncs.com'
        return PaiStudio20210202Client(config)

    @staticmethod
    def create_eas_client() -> eas20210701Client:
        """
        Initialize the EAS client with an AccessKey ID & Secret.
        @return: Client
        @throws Exception
        """
        # Leaking project code may leak the AccessKey and endanger every resource under the
        # account. This sample is for reference only; the more secure STS approach is
        # recommended. For other authentication options see:
        # https://help.aliyun.com/document_detail/378659.html
        config = open_api_models.Config(
            access_key_id=ACCESS_KEY_ID,
            access_key_secret=ACCESS_KEY_SECRET
        )
        # For the endpoint, see https://api.aliyun.com/product/PaiStudio
        config.endpoint = 'pai-eas.cn-hangzhou.aliyuncs.com'
        return eas20210701Client(config)

    @staticmethod
    def create_flow_client() -> PAIFlow20210202Client:
        """
        Initialize the PAIFlow client with an AccessKey ID & Secret.
        @return: Client
        @throws Exception
        """
        # Leaking project code may leak the AccessKey and endanger every resource under the
        # account. This sample is for reference only; the more secure STS approach is
        # recommended. For other authentication options see:
        # https://help.aliyun.com/document_detail/378659.html
        config = open_api_models.Config(
            # Required; make sure the environment variable ALIBABA_CLOUD_ACCESS_KEY_ID is set.
            access_key_id=ACCESS_KEY_ID,
            # Required; make sure the environment variable ALIBABA_CLOUD_ACCESS_KEY_SECRET is set.
            access_key_secret=ACCESS_KEY_SECRET
        )
        # For the endpoint, see https://api.aliyun.com/product/PAIFlow
        config.endpoint = 'paiflow.cn-hangzhou.aliyuncs.com'
        return PAIFlow20210202Client(config)

    @staticmethod
    def get_work_flow_draft_list(workspace_id: str):
        client = PAIClient.create_client()
        list_experiments_request = pai_studio_20210202_models.ListExperimentsRequest(
            workspace_id=workspace_id
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            raise Exception(f"get_work_flow_draft_list error {error}")

    @staticmethod
    def get_work_flow_draft(experiment_id: str):
        client = PAIClient.create_client()
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # If you copy and run this code, print the API response yourself
            resp = client.get_experiment_with_options(experiment_id, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            raise Exception(f"get_work_flow_draft error {error}")

    @staticmethod
    def get_describe_service(service_name: str):
        client = PAIClient.create_eas_client()
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # If you copy and run this code, print the API response yourself
            resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            raise Exception(f"get_describe_service error {error}")

    @staticmethod
    def update_experiment_content(experiment_id: str, content: str, version: int):
        client = PAIClient.create_client()
        update_experiment_content_request = pai_studio_20210202_models.UpdateExperimentContentRequest(
            content=content, version=version)
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # If you copy and run this code, print the API response yourself
            resp = client.update_experiment_content_with_options(experiment_id, update_experiment_content_request,
                                                                 headers, runtime)
            print(resp.body.to_map())
        except Exception as error:
            raise Exception(f"update_experiment_content error {error}")

    @staticmethod
    def create_job(experiment_id: str, node_id: str, execute_type: str):
        client = PAIClient.create_client()
        create_job_request = pai_studio_20210202_models.CreateJobRequest()
        create_job_request.experiment_id = experiment_id
        create_job_request.node_id = node_id
        create_job_request.execute_type = execute_type
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # If you copy and run this code, print the API response yourself
            resp = client.create_job_with_options(create_job_request, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            raise Exception(f"create_job error {error}")

    @staticmethod
    def get_jobs_list(experiment_id: str, order='DESC'):
        client = PAIClient.create_client()
        list_jobs_request = pai_studio_20210202_models.ListJobsRequest(
            experiment_id=experiment_id,
            order=order
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # If you copy and run this code, print the API response yourself
            resp = client.list_jobs_with_options(list_jobs_request, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            raise Exception(f"get_jobs_list error {error}")

    @staticmethod
    def get_job_detail(job_id: str, verbose=False):
        client = PAIClient.create_client()
        get_job_request = pai_studio_20210202_models.GetJobRequest(
            verbose=verbose
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # If you copy and run this code, print the API response yourself
            resp = client.get_job_with_options(job_id, get_job_request, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            # Printed for demonstration only; handle exceptions carefully and never
            # silently ignore them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)

    @staticmethod
    def get_flow_out_put(pipeline_run_id: str, node_id: str, depth: int):
        client = PAIClient.create_flow_client()
        list_pipeline_run_node_outputs_request = paiflow_20210202_models.ListPipelineRunNodeOutputsRequest(
            depth=depth
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # If you copy and run this code, print the API response yourself
            resp = client.list_pipeline_run_node_outputs_with_options(pipeline_run_id, node_id,
                                                                      list_pipeline_run_node_outputs_request,
                                                                      headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            # Printed for demonstration only; handle exceptions carefully and never
            # silently ignore them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)
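
# The static methods above are thin wrappers: each builds a client, issues a
# single OpenAPI call, and (in most cases) returns resp.body.to_map() as a plain
# dict. Typical usage later in this file:
#   draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)  # dict with 'Content' / 'Version'
#   jobs = PAIClient.get_jobs_list(EXPERIMENT_ID)         # dict with a 'Jobs' list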


def extract_date_yyyymmdd(input_string):
    pattern = r'\d{8}'
    matches = re.findall(pattern, input_string)
    if matches:
        return matches[0]
    return None


def extract_model_name(input_string):
    pattern = r"ad_rank_dnn_([^/]+)/\d{8}"
    matches = re.findall(pattern, input_string)
    if matches:
        return matches[0]
    return None
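
# Example (hypothetical path): for a model_path such as
#   'oss://some-bucket/ad_rank_dnn_v11_easyrec/20250601/export'
# extract_date_yyyymmdd() returns '20250601' and extract_model_name() returns
# 'v11_easyrec' (the segment between 'ad_rank_dnn_' and the 8-digit date directory).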


def get_online_model_config(service_name: str):
    model_config = {}
    model_detail = PAIClient.get_describe_service(service_name)
    service_config_str = model_detail['ServiceConfig']
    service_config = json.loads(service_config_str)
    model_path = service_config['model_path']
    model_config['model_path'] = model_path
    online_date = extract_date_yyyymmdd(model_path)
    model_config['online_date'] = online_date
    model_name = extract_model_name(model_path)
    model_config['model_name'] = model_name
    return model_config
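
# get_online_model_config() describes the currently deployed EAS service, parses
# its ServiceConfig JSON, and returns a dict with 'model_path', 'online_date' and
# 'model_name' derived from the service's model_path.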


def update_shuffle_flow(table):
    draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
    print(json.dumps(draft, ensure_ascii=False))
    content = draft['Content']
    version = draft['Version']
    content_json = json.loads(content)
    nodes = content_json.get('nodes')
    for node in nodes:
        name = node['name']
        if name == '模型训练-样本shufle':
            properties = node['properties']
            for prop in properties:
                if prop['name'] == 'sql':
                    value = prop['value']
                    new_value = update_train_table(value, table)
                    if new_value is None:
                        print("error")
                    prop['value'] = new_value
    new_content = json.dumps(content_json, ensure_ascii=False)
    PAIClient.update_experiment_content(EXPERIMENT_ID, new_content, version)


def update_shuffle_flow_1():
    draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
    print(json.dumps(draft, ensure_ascii=False))
    content = draft['Content']
    version = draft['Version']
    print(content)
    content_json = json.loads(content)
    nodes = content_json.get('nodes')
    for node in nodes:
        name = node['name']
        if name == '模型训练-样本shufle':
            properties = node['properties']
            for prop in properties:
                if prop['name'] == 'sql':
                    value = prop['value']
                    new_value = update_data_date_range(value)
                    if new_value is None:
                        print("error")
                    prop['value'] = new_value
    new_content = json.dumps(content_json, ensure_ascii=False)
    PAIClient.update_experiment_content(EXPERIMENT_ID, new_content, version)


def wait_job_end(job_id: str, check_interval=300):
    while True:
        job_detail = PAIClient.get_job_detail(job_id)
        print(job_detail)
        status = job_detail['Status']
        # Non-terminal states: Initialized, Starting, WorkflowServiceStarting (about to
        # submit), Running, ReadyToSchedule (waiting on unfinished upstream nodes)
        if status in ('Initialized', 'Starting', 'WorkflowServiceStarting', 'Running', 'ReadyToSchedule'):
            time.sleep(check_interval)
            continue
        # Terminal states: Failed, Terminating, Terminated, Unknown,
        # Skipped (an upstream node failed), Succeeded
        if status in ('Failed', 'Terminating', 'Terminated', 'Unknown', 'Skipped', 'Succeeded'):
            return job_detail


def get_node_dict():
    draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
    content = draft['Content']
    content_json = json.loads(content)
    nodes = content_json.get('nodes')
    node_dict = {}
    for node in nodes:
        name = node['name']
        # Keep only nodes whose name is in the target set
        if name in target_names:
            node_dict[name] = node['id']
    return node_dict


def get_job_dict():
    job_dict = {}
    jobs_list = PAIClient.get_jobs_list(EXPERIMENT_ID)
    for job in jobs_list['Jobs']:
        # Jobs are listed newest first; stop once a job created before today is reached
        if not compare_timestamp_with_today_start(job['GmtCreateTime']):
            break
        job_id = job['JobId']
        job_detail = PAIClient.get_job_detail(job_id, verbose=True)
        for name in target_names:
            if job_detail['Status'] != 'Succeeded':
                continue
            if name in job_dict:
                continue
            if name in job_detail['RunInfo']:
                job_dict[name] = job_detail['JobId']
    return job_dict
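
# get_node_dict() maps each target node name to its node id in the current draft;
# get_job_dict() maps each target node name to the most recent job created today
# whose status is Succeeded and whose RunInfo mentions that node name.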


@retry
def update_online_flow():
    try:
        online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
        draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
        print(json.dumps(draft, ensure_ascii=False))
        content = draft['Content']
        version = draft['Version']
        print(content)
        content_json = json.loads(content)
        nodes = content_json.get('nodes')
        global_params = content_json.get('globalParams')
        bizdate = get_previous_days_date(1)
        for global_param in global_params:
            try:
                if global_param['name'] == 'bizdate':
                    global_param['value'] = bizdate
                if global_param['name'] == 'online_version_dt':
                    global_param['value'] = online_model_config['online_date']
                if global_param['name'] == 'eval_date':
                    global_param['value'] = bizdate
                if global_param['name'] == 'online_model_path':
                    global_param['value'] = online_model_config['model_path']
            except KeyError:
                raise Exception("在处理全局参数时,字典中缺少必要的键")
        for node in nodes:
            try:
                name = node['name']
                if name in ('样本shuffle',):
                    date_begin = '20250605' if name == '样本shuffle' else get_previous_days_date(10)
                    properties = node['properties']
                    for prop in properties:
                        if prop['name'] == 'sql':
                            value = prop['value']
                            new_value = update_data_date_range(value, date_begin)
                            if new_value is None:
                                print("error")
                            prop['value'] = new_value
            except KeyError:
                raise Exception("在处理节点属性时,字典中缺少必要的键")
        new_content = json.dumps(content_json, ensure_ascii=False)
        PAIClient.update_experiment_content(EXPERIMENT_ID, new_content, version)
        return True
    except json.JSONDecodeError:
        raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
    except Exception as e:
        raise Exception(f"发生未知错误: {e}")
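
# update_online_flow() syncs the draft with the model currently serving on EAS:
# it sets the bizdate / eval_date / online_version_dt / online_model_path global
# parameters and refreshes the date range in the '样本shuffle' node's SQL before
# saving a new draft version.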


@retry
def shuffle_table():
    try:
        node_dict = get_node_dict()
        train_node_id = node_dict['样本shuffle']
        execute_type = 'EXECUTE_FROM_HERE'
        validate_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
        validate_job_id = validate_res['JobId']
        validate_job_detail = wait_job_end(validate_job_id, 10)
        if validate_job_detail['Status'] == 'Succeeded':
            return True
        return False
    except Exception as e:
        error_message = f"在执行 shuffle_table 函数时发生异常: {str(e)}"
        print(error_message)
        raise Exception(error_message)


@retry
def shuffle_train_model():
    try:
        node_dict = get_node_dict()
        job_dict = get_job_dict()
        job_id = job_dict['样本shuffle']
        validate_job_detail = wait_job_end(job_id)
        if validate_job_detail['Status'] == 'Succeeded':
            pipeline_run_id = validate_job_detail['RunId']
            node_id = validate_job_detail['PaiflowNodeId']
            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
            outputs = flow_out_put_detail['Outputs']
            table = None
            for output in outputs:
                if output["Producer"] == node_dict['样本shuffle'] and output["Name"] == "outputTable":
                    value1 = json.loads(output["Info"]['value'])
                    table = value1['location']['table']
            if table is not None:
                update_shuffle_flow(table)
                node_dict = get_node_dict()
                train_node_id = node_dict['模型训练-样本shufle']
                execute_type = 'EXECUTE_ONE'
                train_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
                train_job_id = train_res['JobId']
                train_job_detail = wait_job_end(train_job_id)
                if train_job_detail['Status'] == 'Succeeded':
                    return True
        return False
    except Exception as e:
        error_message = f"在执行 shuffle_train_model 函数时发生异常: {str(e)}"
        print(error_message)
        raise Exception(error_message)


@retry
def export_model():
    try:
        node_dict = get_node_dict()
        export_node_id = node_dict['模型导出-2']
        execute_type = 'EXECUTE_ONE'
        export_res = PAIClient.create_job(EXPERIMENT_ID, export_node_id, execute_type)
        export_job_id = export_res['JobId']
        export_job_detail = wait_job_end(export_job_id)
        if export_job_detail['Status'] == 'Succeeded':
            return True
        return False
    except Exception as e:
        error_message = f"在执行 export_model 函数时发生异常: {str(e)}"
        print(error_message)
        raise Exception(error_message)


def update_online_model():
    try:
        node_dict = get_node_dict()
        train_node_id = node_dict['更新EAS服务(Beta)-1']
        execute_type = 'EXECUTE_ONE'
        train_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
        train_job_id = train_res['JobId']
        train_job_detail = wait_job_end(train_job_id)
        if train_job_detail['Status'] == 'Succeeded':
            return True
        return False
    except Exception as e:
        error_message = f"在执行 update_online_model 函数时发生异常: {str(e)}"
        print(error_message)
        raise Exception(error_message)


@retry
def get_validate_model_data():
    try:
        node_dict = get_node_dict()
        train_node_id = node_dict['虚拟起始节点']
        execute_type = 'EXECUTE_FROM_HERE'
        validate_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
        validate_job_id = validate_res['JobId']
        validate_job_detail = wait_job_end(validate_job_id)
        if validate_job_detail['Status'] == 'Succeeded':
            return True
        return False
    except Exception as e:
        error_message = f"在执行 get_validate_model_data 函数时出现异常: {e}"
        print(error_message)
        raise Exception(error_message)


def validate_model_data_accuracy():
    try:
        table_dict = {}
        node_dict = get_node_dict()
        job_dict = get_job_dict()
        job_id = job_dict['虚拟起始节点']
        validate_job_detail = wait_job_end(job_id)
        if validate_job_detail['Status'] == 'Succeeded':
            pipeline_run_id = validate_job_detail['RunId']
            node_id = validate_job_detail['PaiflowNodeId']
            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
            print(flow_out_put_detail)
            outputs = flow_out_put_detail['Outputs']
            for output in outputs:
                if output["Producer"] == node_dict['二分类评估-1'] and output["Name"] == "outputMetricTable":
                    value1 = json.loads(output["Info"]['value'])
                    table_dict['二分类评估-1'] = value1['location']['table']
                if output["Producer"] == node_dict['二分类评估-2'] and output["Name"] == "outputMetricTable":
                    value2 = json.loads(output["Info"]['value'])
                    table_dict['二分类评估-2'] = value2['location']['table']
                if output["Producer"] == node_dict['预测结果对比'] and output["Name"] == "outputTable":
                    value3 = json.loads(output["Info"]['value'])
                    table_dict['预测结果对比'] = value3['location']['table']
        num = 10
        df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], num)
        # Sum the absolute errors of each column and divide by num to get the average
        old_abs_avg = df['old_error'].abs().sum() / num
        new_abs_avg = df['new_error'].abs().sum() / num
        new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
        old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
        bizdate = get_previous_days_date(1)
        score_diff = abs(old_abs_avg - new_abs_avg)
        msg = ""
        result = False
        if new_abs_avg > 0.1:
            msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
            level = 'error'
        elif score_diff > 0.05 and new_abs_avg - old_abs_avg > 0.05:
            msg += f'两个模型评估{bizdate}的数据,两个模型分数差异为: {score_diff}, 大于0.05, 请检查'
            level = 'error'
        else:
            msg += 'DNN广告模型更新完成'
            level = 'info'
            result = True
        # Initialize the markdown table header
        top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
        top10_msg += "\n| ---- | --------- | -------- |"
        for index, row in df.iterrows():
            # Pull the relevant columns from the row
            cid = row['cid']
            old_error = row['old_error']
            new_error = row['new_error']
            top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
        print(top10_msg)
        msg += f"\n\t - 老模型AUC: {old_auc}"
        msg += f"\n\t - 新模型AUC: {new_auc}"
        msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
        msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
        return result, msg, level, top10_msg
    except Exception as e:
        error_message = f"在执行 validate_model_data_accuracy 函数时出现异常: {str(e)}"
        print(error_message)
        raise Exception(error_message)


def update_trained_cids_pointer(model_name=None, dt_version=None):
    # If both are empty, read them from the workflow draft's global parameters
    if not model_name and not dt_version:
        draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
        content = draft['Content']
        content_json = json.loads(content)
        global_params = content_json.get('globalParams', [])
        model_name = None
        dt_version = None
        for param in global_params:
            if param.get('name') == 'model_name':
                model_name = param.get('value')
            if param.get('name') == 'bizdate':
                dt_version = param.get('value')
        if not model_name or not dt_version:
            raise Exception("globalParams 中未找到 model_name 或 bizdate")
    elif not (model_name and dt_version):
        # Supplying only one of the two is not allowed
        raise Exception("model_name 和 dt_version 必须同时提供")
    model_version = {}
    model_version['modelName'] = f"model_name={model_name}"
    model_version['dtVersion'] = f"dt_version={dt_version}"
    model_version['timestamp'] = int(time.time())
    print(json.dumps(model_version, ensure_ascii=False, indent=4))
    bucket_name = "art-recommend"
    object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
    oss_config = oss.config.load_default()
    oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
        access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
    )
    oss_config.region = "cn-hangzhou"
    client = oss.Client(oss_config)
    ret = client.put_object(oss.PutObjectRequest(
        bucket=bucket_name,
        key=object_key,
        body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
    ))
    print(f'status code: {ret.status_code},'
          f' request id: {ret.request_id},'
          f' content md5: {ret.content_md5},'
          f' etag: {ret.etag},'
          f' hash crc64: {ret.hash_crc64},'
          f' version id: {ret.version_id},'
          )
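
# The pointer object uploaded to OSS looks like this (hypothetical values):
#   {
#       "modelName": "model_name=v11_easyrec",
#       "dtVersion": "dt_version=20250601",
#       "timestamp": 1750000000
#   }
# and is stored at oss://art-recommend/fengzhoutian/pai_model_trained_cids/model_version.json.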


if __name__ == '__main__':
    start_time = int(time.time())
    functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
    function_names = [func.__name__ for func in functions]
    start_function = None
    if len(sys.argv) > 1:
        start_function = sys.argv[1]
        if start_function not in function_names:
            print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
            sys.exit(1)
    start_index = 0
    if start_function:
        start_index = function_names.index(start_function)
    for func in functions[start_index:]:
        if not func():
            print(f"{func.__name__} 执行失败,后续函数不再执行。")
            step_end_time = int(time.time())
            elapsed = step_end_time - start_time
            _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
            break
    else:
        print("所有函数都成功执行,可以继续下一步操作。")
        result, msg, level, top10_msg = validate_model_data_accuracy()
        if result:
            update_online_res = update_online_model()
            if update_online_res:
                online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
                update_trained_cids_pointer(online_model_config['model_name'], online_model_config['online_date'])
                print("success")
        step_end_time = int(time.time())
        elapsed = step_end_time - start_time
        print(level, msg, start_time, elapsed, top10_msg)
        _monitor(level, msg, start_time, elapsed, top10_msg)
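
# Usage sketch: run the whole pipeline, or resume from a named step:
#   python pai_flow_operator.py
#   python pai_flow_operator.py shuffle_train_model
# Valid step names are the entries of `functions` above. After all listed steps
# succeed, the model is validated; the EAS service and the OSS pointer are only
# updated when validation passes.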