pai_flow_operator.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793
  1. # -*- coding: utf-8 -*-
  2. import functools
  3. import os
  4. import re
  5. import sys
  6. import time
  7. import json
  8. import pandas as pd
  9. from alibabacloud_paistudio20210202.client import Client as PaiStudio20210202Client
  10. from alibabacloud_tea_openapi import models as open_api_models
  11. from alibabacloud_paistudio20210202 import models as pai_studio_20210202_models
  12. from alibabacloud_tea_util import models as util_models
  13. from alibabacloud_tea_util.client import Client as UtilClient
  14. from alibabacloud_eas20210701.client import Client as eas20210701Client
  15. from alibabacloud_paiflow20210202 import models as paiflow_20210202_models
  16. from alibabacloud_paiflow20210202.client import Client as PAIFlow20210202Client
  17. from datetime import datetime, timedelta
  18. from odps import ODPS
  19. from ad_monitor_util import _monitor
  20. target_names = {
  21. '样本shuffle',
  22. '评估样本重组',
  23. '模型训练-样本shufle',
  24. '模型训练-自定义',
  25. '模型增量训练',
  26. '模型导出-2',
  27. '更新EAS服务(Beta)-1',
  28. '虚拟起始节点',
  29. '二分类评估-1',
  30. '二分类评估-2',
  31. '预测结果对比'
  32. }
  33. experiment_id = "draft-wqgkag89sbh9v1zvut"
  34. MAX_RETRIES = 3
  35. def retry(func):
  36. @functools.wraps(func)
  37. def wrapper(*args, **kwargs):
  38. retries = 0
  39. while retries < MAX_RETRIES:
  40. try:
  41. result = func(*args, **kwargs)
  42. if result is not False:
  43. return result
  44. except Exception as e:
  45. print(f"函数 {func.__name__} 执行时发生异常: {e},重试第 {retries + 1} 次")
  46. retries += 1
  47. print(f"函数 {func.__name__} 重试 {MAX_RETRIES} 次后仍失败。")
  48. return False
  49. return wrapper
  50. def get_odps_instance(project):
  51. odps = ODPS(
  52. access_id='LTAIWYUujJAm7CbH',
  53. secret_access_key='RfSjdiWwED1sGFlsjXv0DlfTnZTG1P',
  54. project=project,
  55. endpoint='http://service.cn.maxcompute.aliyun.com/api',
  56. )
  57. return odps
  58. def get_data_from_odps(project, table, num):
  59. odps = get_odps_instance(project)
  60. try:
  61. # 要查询的 SQL 语句
  62. sql = f'select * from {table} limit {num}'
  63. # 执行 SQL 查询
  64. with odps.execute_sql(sql).open_reader() as reader:
  65. df = reader.to_pandas()
  66. # 查询数量小于目标数量时 返回空
  67. if len(df) < num:
  68. return None
  69. return df
  70. except Exception as e:
  71. print(f"发生错误: {e}")
  72. def get_dict_from_odps(project, table):
  73. odps = get_odps_instance(project)
  74. try:
  75. # 要查询的 SQL 语句
  76. sql = f'select * from {table}'
  77. # 执行 SQL 查询
  78. with odps.execute_sql(sql).open_reader() as reader:
  79. data = {}
  80. for record in reader:
  81. record_list = list(record)
  82. key = record_list[0][1]
  83. value = record_list[1][1]
  84. data[key] = value
  85. return data
  86. except Exception as e:
  87. print(f"发生错误: {e}")
  88. def get_dates_between(start_date_str, end_date_str):
  89. start_date = datetime.strptime(start_date_str, '%Y%m%d')
  90. end_date = datetime.strptime(end_date_str, '%Y%m%d')
  91. dates = []
  92. current_date = start_date
  93. while current_date <= end_date:
  94. dates.append(current_date.strftime('%Y%m%d'))
  95. current_date += timedelta(days=1)
  96. return dates
  97. def read_file_to_list():
  98. try:
  99. current_dir = os.getcwd()
  100. file_path = os.path.join(current_dir, 'ad', 'holidays.txt')
  101. with open(file_path, 'r', encoding='utf-8') as file:
  102. content = file.read()
  103. return content.split('\n')
  104. except FileNotFoundError:
  105. raise Exception(f"错误:未找到 {file_path} 文件。")
  106. except Exception as e:
  107. raise Exception(f"错误:发生了一个未知错误: {e}")
  108. return []
  109. def get_previous_days_date(days):
  110. current_date = datetime.now()
  111. previous_date = current_date - timedelta(days=days)
  112. return previous_date.strftime('%Y%m%d')
  113. def remove_elements(lst1, lst2):
  114. return [element for element in lst1 if element not in lst2]
  115. def process_list(lst, append_str):
  116. # 给列表中每个元素拼接相同的字符串
  117. appended_list = [append_str + element for element in lst]
  118. # 将拼接后的列表元素用逗号拼接成一个字符串
  119. result_str = ','.join(appended_list)
  120. return result_str
  121. def get_train_data_list():
  122. start_date = '20250320'
  123. end_date = get_previous_days_date(1)
  124. date_list = get_dates_between(start_date, end_date)
  125. filter_date_list = read_file_to_list()
  126. date_list = remove_elements(date_list, filter_date_list)
  127. return date_list
  128. def update_data_date_range(old_str):
  129. date_list = get_train_data_list()
  130. train_list = ["'" + item + "'" for item in date_list]
  131. result = ','.join(train_list)
  132. start_index = old_str.find('where dt in (')
  133. if start_index != -1:
  134. equal_sign_index = start_index + len('where dt in (')
  135. # 找到下一个双引号的位置
  136. next_quote_index = old_str.find(')', equal_sign_index)
  137. if next_quote_index != -1:
  138. # 进行替换
  139. new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
  140. return new_value
  141. return None
  142. def compare_timestamp_with_today_start(time_str):
  143. # 解析时间字符串为 datetime 对象
  144. time_obj = datetime.fromisoformat(time_str)
  145. # 将其转换为时间戳
  146. target_timestamp = time_obj.timestamp()
  147. # 获取今天开始的时间
  148. today_start = datetime.combine(datetime.now().date(), datetime.min.time())
  149. # 将今天开始时间转换为时间戳
  150. today_start_timestamp = today_start.timestamp()
  151. return target_timestamp > today_start_timestamp
  152. def update_train_table(old_str, table):
  153. address = 'odps://pai_algo/tables/'
  154. train_table = address + table
  155. start_index = old_str.find('-Dtrain_tables="')
  156. if start_index != -1:
  157. # 确定等号的位置
  158. equal_sign_index = start_index + len('-Dtrain_tables="')
  159. # 找到下一个双引号的位置
  160. next_quote_index = old_str.find('"', equal_sign_index)
  161. if next_quote_index != -1:
  162. # 进行替换
  163. new_value = old_str[:equal_sign_index] + train_table + old_str[next_quote_index:]
  164. return new_value
  165. return None
  166. class PAIClient:
  167. def __init__(self):
  168. pass
  169. @staticmethod
  170. def create_client() -> PaiStudio20210202Client:
  171. """
  172. 使用AK&SK初始化账号Client
  173. @return: Client
  174. @throws Exception
  175. """
  176. # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
  177. # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
  178. config = open_api_models.Config(
  179. access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
  180. access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  181. )
  182. # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
  183. config.endpoint = f'pai.cn-hangzhou.aliyuncs.com'
  184. return PaiStudio20210202Client(config)
  185. @staticmethod
  186. def create_eas_client() -> eas20210701Client:
  187. """
  188. 使用AK&SK初始化账号Client
  189. @return: Client
  190. @throws Exception
  191. """
  192. # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
  193. # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
  194. config = open_api_models.Config(
  195. access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
  196. access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  197. )
  198. # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
  199. config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
  200. return eas20210701Client(config)
  201. @staticmethod
  202. def create_flow_client() -> PAIFlow20210202Client:
  203. """
  204. 使用AK&SK初始化账号Client
  205. @return: Client
  206. @throws Exception
  207. """
  208. # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
  209. # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
  210. config = open_api_models.Config(
  211. # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_ID。,
  212. access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
  213. # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_SECRET。,
  214. access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  215. )
  216. # Endpoint 请参考 https://api.aliyun.com/product/PAIFlow
  217. config.endpoint = f'paiflow.cn-hangzhou.aliyuncs.com'
  218. return PAIFlow20210202Client(config)
  219. @staticmethod
  220. def get_work_flow_draft_list(workspace_id: str):
  221. client = PAIClient.create_client()
  222. list_experiments_request = pai_studio_20210202_models.ListExperimentsRequest(
  223. workspace_id=workspace_id
  224. )
  225. runtime = util_models.RuntimeOptions()
  226. headers = {}
  227. try:
  228. resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
  229. return resp.body.to_map()
  230. except Exception as error:
  231. raise Exception(f"get_work_flow_draft_list error {error}")
  232. @staticmethod
  233. def get_work_flow_draft(experiment_id: str):
  234. client = PAIClient.create_client()
  235. runtime = util_models.RuntimeOptions()
  236. headers = {}
  237. try:
  238. # 复制代码运行请自行打印 API 的返回值
  239. resp = client.get_experiment_with_options(experiment_id, headers, runtime)
  240. return resp.body.to_map()
  241. except Exception as error:
  242. raise Exception(f"get_work_flow_draft error {error}")
  243. @staticmethod
  244. def get_describe_service(service_name: str):
  245. client = PAIClient.create_eas_client()
  246. runtime = util_models.RuntimeOptions()
  247. headers = {}
  248. try:
  249. # 复制代码运行请自行打印 API 的返回值
  250. resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
  251. return resp.body.to_map()
  252. except Exception as error:
  253. raise Exception(f"get_describe_service error {error}")
  254. @staticmethod
  255. def update_experiment_content(experiment_id: str, content: str, version: int):
  256. client = PAIClient.create_client()
  257. update_experiment_content_request = pai_studio_20210202_models.UpdateExperimentContentRequest(content=content,
  258. version=version)
  259. runtime = util_models.RuntimeOptions()
  260. headers = {}
  261. try:
  262. # 复制代码运行请自行打印 API 的返回值
  263. resp = client.update_experiment_content_with_options(experiment_id, update_experiment_content_request,
  264. headers, runtime)
  265. print(resp.body.to_map())
  266. except Exception as error:
  267. raise Exception(f"update_experiment_content error {error}")
  268. @staticmethod
  269. def create_job(experiment_id: str, node_id: str, execute_type: str):
  270. client = PAIClient.create_client()
  271. create_job_request = pai_studio_20210202_models.CreateJobRequest()
  272. create_job_request.experiment_id = experiment_id
  273. create_job_request.node_id = node_id
  274. create_job_request.execute_type = execute_type
  275. runtime = util_models.RuntimeOptions()
  276. headers = {}
  277. try:
  278. # 复制代码运行请自行打印 API 的返回值
  279. resp = client.create_job_with_options(create_job_request, headers, runtime)
  280. return resp.body.to_map()
  281. except Exception as error:
  282. raise Exception(f"create_job error {error}")
  283. @staticmethod
  284. def get_jobs_list(experiment_id: str, order='DESC'):
  285. client = PAIClient.create_client()
  286. list_jobs_request = pai_studio_20210202_models.ListJobsRequest(
  287. experiment_id=experiment_id,
  288. order=order
  289. )
  290. runtime = util_models.RuntimeOptions()
  291. headers = {}
  292. try:
  293. # 复制代码运行请自行打印 API 的返回值
  294. resp = client.list_jobs_with_options(list_jobs_request, headers, runtime)
  295. return resp.body.to_map()
  296. except Exception as error:
  297. raise Exception(f"get_jobs_list error {error}")
  298. @staticmethod
  299. def get_job_detail(job_id: str, verbose=False):
  300. client = PAIClient.create_client()
  301. get_job_request = pai_studio_20210202_models.GetJobRequest(
  302. verbose=verbose
  303. )
  304. runtime = util_models.RuntimeOptions()
  305. headers = {}
  306. try:
  307. # 复制代码运行请自行打印 API 的返回值
  308. resp = client.get_job_with_options(job_id, get_job_request, headers, runtime)
  309. return resp.body.to_map()
  310. except Exception as error:
  311. # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
  312. # 错误 message
  313. print(error.message)
  314. # 诊断地址
  315. print(error.data.get("Recommend"))
  316. UtilClient.assert_as_string(error.message)
  317. @staticmethod
  318. def get_flow_out_put(pipeline_run_id: str, node_id: str, depth: int):
  319. client = PAIClient.create_flow_client()
  320. list_pipeline_run_node_outputs_request = paiflow_20210202_models.ListPipelineRunNodeOutputsRequest(
  321. depth=depth
  322. )
  323. runtime = util_models.RuntimeOptions()
  324. headers = {}
  325. try:
  326. # 复制代码运行请自行打印 API 的返回值
  327. resp = client.list_pipeline_run_node_outputs_with_options(pipeline_run_id, node_id,
  328. list_pipeline_run_node_outputs_request, headers,
  329. runtime)
  330. return resp.body.to_map()
  331. except Exception as error:
  332. # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
  333. # 错误 message
  334. print(error.message)
  335. # 诊断地址
  336. print(error.data.get("Recommend"))
  337. UtilClient.assert_as_string(error.message)
  338. def extract_date_yyyymmdd(input_string):
  339. pattern = r'\d{8}'
  340. matches = re.findall(pattern, input_string)
  341. if matches:
  342. return matches[0]
  343. return None
  344. def get_online_model_config(service_name: str):
  345. model_config = {}
  346. model_detail = PAIClient.get_describe_service(service_name)
  347. service_config_str = model_detail['ServiceConfig']
  348. service_config = json.loads(service_config_str)
  349. model_path = service_config['model_path']
  350. model_config['model_path'] = model_path
  351. online_date = extract_date_yyyymmdd(model_path)
  352. model_config['online_date'] = online_date
  353. return model_config
  354. def update_shuffle_flow(table):
  355. draft = PAIClient.get_work_flow_draft(experiment_id)
  356. print(json.dumps(draft, ensure_ascii=False))
  357. content = draft['Content']
  358. version = draft['Version']
  359. content_json = json.loads(content)
  360. nodes = content_json.get('nodes')
  361. for node in nodes:
  362. name = node['name']
  363. if name == '模型训练-样本shufle':
  364. properties = node['properties']
  365. for property in properties:
  366. if property['name'] == 'sql':
  367. value = property['value']
  368. new_value = update_train_table(value, table)
  369. if new_value is None:
  370. print("error")
  371. property['value'] = new_value
  372. new_content = json.dumps(content_json, ensure_ascii=False)
  373. PAIClient.update_experiment_content(experiment_id, new_content, version)
  374. def update_shuffle_flow_1():
  375. draft = PAIClient.get_work_flow_draft(experiment_id)
  376. print(json.dumps(draft, ensure_ascii=False))
  377. content = draft['Content']
  378. version = draft['Version']
  379. print(content)
  380. content_json = json.loads(content)
  381. nodes = content_json.get('nodes')
  382. for node in nodes:
  383. name = node['name']
  384. if name == '模型训练-样本shufle':
  385. properties = node['properties']
  386. for property in properties:
  387. if property['name'] == 'sql':
  388. value = property['value']
  389. new_value = update_data_date_range(value)
  390. if new_value is None:
  391. print("error")
  392. property['value'] = new_value
  393. new_content = json.dumps(content_json, ensure_ascii=False)
  394. PAIClient.update_experiment_content(experiment_id, new_content, version)
  395. def wait_job_end(job_id: str, check_interval=300):
  396. while True:
  397. job_detail = PAIClient.get_job_detail(job_id)
  398. print(job_detail)
  399. statue = job_detail['Status']
  400. # Initialized: 初始化完成 Starting:开始 WorkflowServiceStarting:准备提交 Running:运行中 ReadyToSchedule:准备运行(前序节点未完成导致)
  401. if (statue == 'Initialized' or statue == 'Starting' or statue == 'WorkflowServiceStarting'
  402. or statue == 'Running' or statue == 'ReadyToSchedule'):
  403. time.sleep(check_interval)
  404. continue
  405. # Failed:运行失败 Terminating:终止中 Terminated:已终止 Unknown:未知 Skipped:跳过(前序节点失败导致) Succeeded:运行成功
  406. if statue == 'Failed' or statue == 'Terminating' or statue == 'Unknown' or statue == 'Skipped' or statue == 'Succeeded':
  407. return job_detail
  408. def get_node_dict():
  409. draft = PAIClient.get_work_flow_draft(experiment_id)
  410. content = draft['Content']
  411. content_json = json.loads(content)
  412. nodes = content_json.get('nodes')
  413. node_dict = {}
  414. for node in nodes:
  415. name = node['name']
  416. # 检查名称是否在目标名称集合中
  417. if name in target_names:
  418. node_dict[name] = node['id']
  419. return node_dict
  420. def get_job_dict():
  421. job_dict = {}
  422. jobs_list = PAIClient.get_jobs_list(experiment_id)
  423. for job in jobs_list['Jobs']:
  424. # 解析时间字符串为 datetime 对象
  425. if not compare_timestamp_with_today_start(job['GmtCreateTime']):
  426. break
  427. job_id = job['JobId']
  428. job_detail = PAIClient.get_job_detail(job_id, verbose=True)
  429. for name in target_names:
  430. if job_detail['Status'] != 'Succeeded':
  431. continue
  432. if name in job_dict:
  433. continue
  434. if name in job_detail['RunInfo']:
  435. job_dict[name] = job_detail['JobId']
  436. return job_dict
  437. @retry
  438. def update_online_flow():
  439. try:
  440. online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
  441. draft = PAIClient.get_work_flow_draft(experiment_id)
  442. print(json.dumps(draft, ensure_ascii=False))
  443. content = draft['Content']
  444. version = draft['Version']
  445. print(content)
  446. content_json = json.loads(content)
  447. nodes = content_json.get('nodes')
  448. global_params = content_json.get('globalParams')
  449. bizdate = get_previous_days_date(1)
  450. for global_param in global_params:
  451. try:
  452. if global_param['name'] == 'bizdate':
  453. global_param['value'] = bizdate
  454. if global_param['name'] == 'online_version_dt':
  455. global_param['value'] = online_model_config['online_date']
  456. if global_param['name'] == 'eval_date':
  457. global_param['value'] = bizdate
  458. if global_param['name'] == 'online_model_path':
  459. global_param['value'] = online_model_config['model_path']
  460. except KeyError:
  461. raise Exception("在处理全局参数时,字典中缺少必要的键")
  462. for node in nodes:
  463. try:
  464. name = node['name']
  465. if name in ('样本shuffle', '评估样本重组'):
  466. properties = node['properties']
  467. for property in properties:
  468. if property['name'] == 'sql':
  469. value = property['value']
  470. new_value = update_data_date_range(value)
  471. if new_value is None:
  472. print("error")
  473. property['value'] = new_value
  474. except KeyError:
  475. raise Exception("在处理节点属性时,字典中缺少必要的键")
  476. new_content = json.dumps(content_json, ensure_ascii=False)
  477. PAIClient.update_experiment_content(experiment_id, new_content, version)
  478. return True
  479. except json.JSONDecodeError:
  480. raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
  481. except Exception as e:
  482. raise Exception(f"发生未知错误: {e}")
  483. def update_global_param(params):
  484. try:
  485. draft = PAIClient.get_work_flow_draft(experiment_id)
  486. content = draft['Content']
  487. version = draft['Version']
  488. content_json = json.loads(content)
  489. nodes = content_json.get('nodes')
  490. global_params = content_json.get('globalParams')
  491. for global_param in global_params:
  492. if global_param['name'] in params:
  493. value = params[global_param['name']]
  494. print(f"update global param {global_param['name']}: {value}")
  495. global_param['value'] = value
  496. new_content = json.dumps(content_json, ensure_ascii=False)
  497. PAIClient.update_experiment_content(experiment_id, new_content, version)
  498. return True
  499. except json.JSONDecodeError:
  500. raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
  501. except Exception as e:
  502. raise Exception(f"发生未知错误: {e}")
  503. @retry
  504. def shuffle_table():
  505. try:
  506. node_dict = get_node_dict()
  507. train_node_id = node_dict['样本shuffle']
  508. execute_type = 'EXECUTE_FROM_HERE'
  509. validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  510. validate_job_id = validate_res['JobId']
  511. validate_job_detail = wait_job_end(validate_job_id, 10)
  512. if validate_job_detail['Status'] == 'Succeeded':
  513. return True
  514. return False
  515. except Exception as e:
  516. error_message = f"在执行 shuffle_table 函数时发生异常: {str(e)}"
  517. print(error_message)
  518. raise Exception(error_message)
  519. @retry
  520. def shuffle_train_model():
  521. try:
  522. node_dict = get_node_dict()
  523. job_dict = get_job_dict()
  524. job_id = job_dict['样本shuffle']
  525. validate_job_detail = wait_job_end(job_id)
  526. if validate_job_detail['Status'] == 'Succeeded':
  527. pipeline_run_id = validate_job_detail['RunId']
  528. node_id = validate_job_detail['PaiflowNodeId']
  529. flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
  530. outputs = flow_out_put_detail['Outputs']
  531. table = None
  532. for output in outputs:
  533. if output["Producer"] == node_dict['样本shuffle'] and output["Name"] == "outputTable":
  534. value1 = json.loads(output["Info"]['value'])
  535. table = value1['location']['table']
  536. if table is not None:
  537. update_shuffle_flow(table)
  538. node_dict = get_node_dict()
  539. train_node_id = node_dict['模型训练-样本shufle']
  540. execute_type = 'EXECUTE_ONE'
  541. train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  542. train_job_id = train_res['JobId']
  543. train_job_detail = wait_job_end(train_job_id)
  544. if train_job_detail['Status'] == 'Succeeded':
  545. return True
  546. return False
  547. except Exception as e:
  548. error_message = f"在执行 shuffle_train_model 函数时发生异常: {str(e)}"
  549. print(error_message)
  550. raise Exception(error_message)
  551. @retry
  552. def export_model():
  553. try:
  554. node_dict = get_node_dict()
  555. export_node_id = node_dict['模型导出-2']
  556. execute_type = 'EXECUTE_ONE'
  557. export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
  558. export_job_id = export_res['JobId']
  559. export_job_detail = wait_job_end(export_job_id)
  560. if export_job_detail['Status'] == 'Succeeded':
  561. return True
  562. return False
  563. except Exception as e:
  564. error_message = f"在执行 export_model 函数时发生异常: {str(e)}"
  565. print(error_message)
  566. raise Exception(error_message)
  567. def update_online_model():
  568. try:
  569. node_dict = get_node_dict()
  570. train_node_id = node_dict['更新EAS服务(Beta)-1']
  571. execute_type = 'EXECUTE_ONE'
  572. train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  573. train_job_id = train_res['JobId']
  574. train_job_detail = wait_job_end(train_job_id)
  575. if train_job_detail['Status'] == 'Succeeded':
  576. return True
  577. return False
  578. except Exception as e:
  579. error_message = f"在执行 update_online_model 函数时发生异常: {str(e)}"
  580. print(error_message)
  581. raise Exception(error_message)
  582. def update_validation_config():
  583. try:
  584. job_dict = get_job_dict()
  585. node_dict = get_node_dict()
  586. print(node_dict)
  587. job_id = job_dict['样本shuffle']
  588. validate_job_detail = wait_job_end(job_id)
  589. table = None
  590. if validate_job_detail['Status'] == 'Succeeded':
  591. pipeline_run_id = validate_job_detail['RunId']
  592. node_id = validate_job_detail['PaiflowNodeId']
  593. flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
  594. outputs = flow_out_put_detail['Outputs']
  595. for output in outputs:
  596. if output["Producer"] == node_dict['评估样本重组'] and output["Name"] == "outputTable":
  597. value1 = json.loads(output["Info"]['value'])
  598. table = value1['location']['table']
  599. if not table:
  600. raise Exception("table not available")
  601. update_global_param({'eval_table_name': table})
  602. except Exception as e:
  603. error_message = f"在执行 update_validation_config 函数时发生异常: {str(e)}"
  604. print(error_message)
  605. raise Exception(error_message)
  606. @retry
  607. def get_validate_model_data():
  608. update_validation_config()
  609. try:
  610. node_dict = get_node_dict()
  611. train_node_id = node_dict['虚拟起始节点']
  612. execute_type = 'EXECUTE_FROM_HERE'
  613. validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  614. validate_job_id = validate_res['JobId']
  615. validate_job_detail = wait_job_end(validate_job_id)
  616. if validate_job_detail['Status'] == 'Succeeded':
  617. return True
  618. return False
  619. except Exception as e:
  620. error_message = f"在执行 get_validate_model_data 函数时出现异常: {e}"
  621. print(error_message)
  622. raise Exception(error_message)
  623. def validate_model_data_accuracy():
  624. try:
  625. table_dict = {}
  626. node_dict = get_node_dict()
  627. job_dict = get_job_dict()
  628. job_id = job_dict['虚拟起始节点']
  629. validate_job_detail = wait_job_end(job_id)
  630. if validate_job_detail['Status'] == 'Succeeded':
  631. pipeline_run_id = validate_job_detail['RunId']
  632. node_id = validate_job_detail['PaiflowNodeId']
  633. flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
  634. print(flow_out_put_detail)
  635. outputs = flow_out_put_detail['Outputs']
  636. for output in outputs:
  637. if output["Producer"] == node_dict['二分类评估-1'] and output["Name"] == "outputMetricTable":
  638. value1 = json.loads(output["Info"]['value'])
  639. table_dict['二分类评估-1'] = value1['location']['table']
  640. if output["Producer"] == node_dict['二分类评估-2'] and output["Name"] == "outputMetricTable":
  641. value2 = json.loads(output["Info"]['value'])
  642. table_dict['二分类评估-2'] = value2['location']['table']
  643. if output["Producer"] == node_dict['预测结果对比'] and output["Name"] == "outputTable":
  644. value3 = json.loads(output["Info"]['value'])
  645. table_dict['预测结果对比'] = value3['location']['table']
  646. num = 10
  647. df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)
  648. # 对指定列取绝对值再求和
  649. old_abs_avg = df['old_error'].abs().sum() / num
  650. new_abs_avg = df['new_error'].abs().sum() / num
  651. new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
  652. old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
  653. bizdate = get_previous_days_date(1)
  654. score_diff = abs(old_abs_avg - new_abs_avg)
  655. msg = ""
  656. result = False
  657. if new_abs_avg > 0.1:
  658. msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
  659. level = 'error'
  660. elif score_diff > 0.05:
  661. msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
  662. level = 'error'
  663. else:
  664. msg += 'DNN广告模型更新完成'
  665. level = 'info'
  666. result = True
  667. # 初始化表格头部
  668. top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
  669. top10_msg += "\n| ---- | --------- | -------- |"
  670. for index, row in df.iterrows():
  671. # 获取指定列的元素
  672. cid = row['cid']
  673. old_error = row['old_error']
  674. new_error = row['new_error']
  675. top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
  676. print(top10_msg)
  677. msg += f"\n\t - 老模型AUC: {old_auc}"
  678. msg += f"\n\t - 新模型AUC: {new_auc}"
  679. msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
  680. msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
  681. return result, msg, level, top10_msg
  682. except Exception as e:
  683. error_message = f"在执行 validate_model_data_accuracy 函数时出现异常: {str(e)}"
  684. print(error_message)
  685. raise Exception(error_message)
  686. if __name__ == '__main__':
  687. start_time = int(time.time())
  688. functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
  689. function_names = [func.__name__ for func in functions]
  690. start_function = None
  691. if len(sys.argv) > 1:
  692. start_function = sys.argv[1]
  693. if start_function not in function_names:
  694. print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
  695. sys.exit(1)
  696. start_index = 0
  697. if start_function:
  698. start_index = function_names.index(start_function)
  699. for func in functions[start_index:]:
  700. if not func():
  701. print(f"{func.__name__} 执行失败,后续函数不再执行。")
  702. step_end_time = int(time.time())
  703. elapsed = step_end_time - start_time
  704. _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
  705. break
  706. else:
  707. print("所有函数都成功执行,可以继续下一步操作。")
  708. result, msg, level, top10_msg = validate_model_data_accuracy()
  709. if result:
  710. # update_online_model()
  711. print("success")
  712. step_end_time = int(time.time())
  713. elapsed = step_end_time - start_time
  714. print(level, msg, start_time, elapsed, top10_msg)
  715. _monitor(level, msg, start_time, elapsed, top10_msg)