# pai_flow_operator2.py
# -*- coding: utf-8 -*-
import functools
import json
import os
import re
import sys
import time
from datetime import datetime, timedelta

import pandas as pd
from alibabacloud_eas20210701.client import Client as eas20210701Client
from alibabacloud_paiflow20210202 import models as paiflow_20210202_models
from alibabacloud_paiflow20210202.client import Client as PAIFlow20210202Client
from alibabacloud_paistudio20210202 import models as pai_studio_20210202_models
from alibabacloud_paistudio20210202.client import Client as PaiStudio20210202Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient
from odps import ODPS

from ad_monitor_util import _monitor
  19. target_names = {
  20. '样本shuffle',
  21. '模型训练-样本shufle',
  22. '模型训练-自定义',
  23. '模型增量训练',
  24. '模型导出-2',
  25. '更新EAS服务(Beta)-1',
  26. '虚拟起始节点',
  27. '二分类评估-1',
  28. '二分类评估-2',
  29. '预测结果对比'
  30. }
  31. experiment_id = "draft-wqgkag89sbh9v1zvut"
  32. MAX_RETRIES = 3
  33. def retry(func):
  34. def wrapper(*args, **kwargs):
  35. retries = 0
  36. while retries < MAX_RETRIES:
  37. try:
  38. result = func(*args, **kwargs)
  39. if result is not False:
  40. return result
  41. except Exception as e:
  42. print(f"函数 {func.__name__} 执行时发生异常: {e},重试第 {retries + 1} 次")
  43. retries += 1
  44. print(f"函数 {func.__name__} 重试 {MAX_RETRIES} 次后仍失败。")
  45. return False
  46. return wrapper
  47. def get_odps_instance(project):
  48. odps = ODPS(
  49. access_id='LTAIWYUujJAm7CbH',
  50. secret_access_key='RfSjdiWwED1sGFlsjXv0DlfTnZTG1P',
  51. project=project,
  52. endpoint='http://service.cn.maxcompute.aliyun.com/api',
  53. )
  54. return odps
  55. def get_data_from_odps(project, table, num):
  56. odps = get_odps_instance(project)
  57. try:
  58. # 要查询的 SQL 语句
  59. sql = f'select * from {table} limit {num}'
  60. # 执行 SQL 查询
  61. with odps.execute_sql(sql).open_reader() as reader:
  62. df = reader.to_pandas()
  63. # 查询数量小于目标数量时 返回空
  64. if len(df) < num:
  65. return None
  66. return df
  67. except Exception as e:
  68. print(f"发生错误: {e}")
  69. def get_dict_from_odps(project, table):
  70. odps = get_odps_instance(project)
  71. try:
  72. # 要查询的 SQL 语句
  73. sql = f'select * from {table}'
  74. # 执行 SQL 查询
  75. with odps.execute_sql(sql).open_reader() as reader:
  76. data = {}
  77. for record in reader:
  78. record_list = list(record)
  79. key = record_list[0][1]
  80. value = record_list[1][1]
  81. data[key] = value
  82. return data
  83. except Exception as e:
  84. print(f"发生错误: {e}")
  85. def get_dates_between(start_date_str, end_date_str):
  86. start_date = datetime.strptime(start_date_str, '%Y%m%d')
  87. end_date = datetime.strptime(end_date_str, '%Y%m%d')
  88. dates = []
  89. current_date = start_date
  90. while current_date <= end_date:
  91. dates.append(current_date.strftime('%Y%m%d'))
  92. current_date += timedelta(days=1)
  93. return dates
  94. def read_file_to_list():
  95. try:
  96. current_dir = os.getcwd()
  97. file_path = os.path.join(current_dir, 'ad', 'holidays.txt')
  98. with open(file_path, 'r', encoding='utf-8') as file:
  99. content = file.read()
  100. return content.split('\n')
  101. except FileNotFoundError:
  102. raise Exception(f"错误:未找到 {file_path} 文件。")
  103. except Exception as e:
  104. raise Exception(f"错误:发生了一个未知错误: {e}")
  105. return []
  106. def get_previous_days_date(days):
  107. current_date = datetime.now()
  108. previous_date = current_date - timedelta(days=days)
  109. return previous_date.strftime('%Y%m%d')
  110. def remove_elements(lst1, lst2):
  111. return [element for element in lst1 if element not in lst2]
  112. def process_list(lst, append_str):
  113. # 给列表中每个元素拼接相同的字符串
  114. appended_list = [append_str + element for element in lst]
  115. # 将拼接后的列表元素用逗号拼接成一个字符串
  116. result_str = ','.join(appended_list)
  117. return result_str
  118. def get_train_data_list():
  119. start_date = '20250223'
  120. end_date = get_previous_days_date(2)
  121. date_list = get_dates_between(start_date, end_date)
  122. filter_date_list = read_file_to_list()
  123. date_list = remove_elements(date_list, filter_date_list)
  124. return date_list
  125. def update_train_tables(old_str):
  126. date_list = get_train_data_list()
  127. train_list = ["'" + item + "'" for item in date_list]
  128. result = ','.join(train_list)
  129. start_index = old_str.find('where dt in (')
  130. if start_index != -1:
  131. equal_sign_index = start_index + len('where dt in (')
  132. # 找到下一个双引号的位置
  133. next_quote_index = old_str.find(')', equal_sign_index)
  134. if next_quote_index != -1:
  135. # 进行替换
  136. new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
  137. return new_value
  138. return None
  139. def compare_timestamp_with_today_start(time_str):
  140. # 解析时间字符串为 datetime 对象
  141. time_obj = datetime.fromisoformat(time_str)
  142. # 将其转换为时间戳
  143. target_timestamp = time_obj.timestamp()
  144. # 获取今天开始的时间
  145. today_start = datetime.combine(datetime.now().date(), datetime.min.time())
  146. # 将今天开始时间转换为时间戳
  147. today_start_timestamp = today_start.timestamp()
  148. return target_timestamp > today_start_timestamp
  149. def update_train_table(old_str, table):
  150. address = 'odps://pai_algo/tables/'
  151. train_table = address + table
  152. start_index = old_str.find('-Dtrain_tables="')
  153. if start_index != -1:
  154. # 确定等号的位置
  155. equal_sign_index = start_index + len('-Dtrain_tables="')
  156. # 找到下一个双引号的位置
  157. next_quote_index = old_str.find('"', equal_sign_index)
  158. if next_quote_index != -1:
  159. # 进行替换
  160. new_value = old_str[:equal_sign_index] + train_table + old_str[next_quote_index:]
  161. return new_value
  162. return None
  163. class PAIClient:
  164. def __init__(self):
  165. pass
  166. @staticmethod
  167. def create_client() -> PaiStudio20210202Client:
  168. """
  169. 使用AK&SK初始化账号Client
  170. @return: Client
  171. @throws Exception
  172. """
  173. # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
  174. # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
  175. config = open_api_models.Config(
  176. access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
  177. access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  178. )
  179. # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
  180. config.endpoint = f'pai.cn-hangzhou.aliyuncs.com'
  181. return PaiStudio20210202Client(config)
  182. @staticmethod
  183. def create_eas_client() -> eas20210701Client:
  184. """
  185. 使用AK&SK初始化账号Client
  186. @return: Client
  187. @throws Exception
  188. """
  189. # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
  190. # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
  191. config = open_api_models.Config(
  192. access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
  193. access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  194. )
  195. # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
  196. config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
  197. return eas20210701Client(config)
  198. @staticmethod
  199. def create_flow_client() -> PAIFlow20210202Client:
  200. """
  201. 使用AK&SK初始化账号Client
  202. @return: Client
  203. @throws Exception
  204. """
  205. # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
  206. # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
  207. config = open_api_models.Config(
  208. # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_ID。,
  209. access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
  210. # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_SECRET。,
  211. access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  212. )
  213. # Endpoint 请参考 https://api.aliyun.com/product/PAIFlow
  214. config.endpoint = f'paiflow.cn-hangzhou.aliyuncs.com'
  215. return PAIFlow20210202Client(config)
  216. @staticmethod
  217. def get_work_flow_draft_list(workspace_id: str):
  218. client = PAIClient.create_client()
  219. list_experiments_request = pai_studio_20210202_models.ListExperimentsRequest(
  220. workspace_id=workspace_id
  221. )
  222. runtime = util_models.RuntimeOptions()
  223. headers = {}
  224. try:
  225. resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
  226. return resp.body.to_map()
  227. except Exception as error:
  228. raise Exception(f"get_work_flow_draft_list error {error}")
  229. @staticmethod
  230. def get_work_flow_draft(experiment_id: str):
  231. client = PAIClient.create_client()
  232. runtime = util_models.RuntimeOptions()
  233. headers = {}
  234. try:
  235. # 复制代码运行请自行打印 API 的返回值
  236. resp = client.get_experiment_with_options(experiment_id, headers, runtime)
  237. return resp.body.to_map()
  238. except Exception as error:
  239. raise Exception(f"get_work_flow_draft error {error}")
  240. @staticmethod
  241. def get_describe_service(service_name: str):
  242. client = PAIClient.create_eas_client()
  243. runtime = util_models.RuntimeOptions()
  244. headers = {}
  245. try:
  246. # 复制代码运行请自行打印 API 的返回值
  247. resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
  248. return resp.body.to_map()
  249. except Exception as error:
  250. raise Exception(f"get_describe_service error {error}")
  251. @staticmethod
  252. def update_experiment_content(experiment_id: str, content: str, version: int):
  253. client = PAIClient.create_client()
  254. update_experiment_content_request = pai_studio_20210202_models.UpdateExperimentContentRequest(content=content,
  255. version=version)
  256. runtime = util_models.RuntimeOptions()
  257. headers = {}
  258. try:
  259. # 复制代码运行请自行打印 API 的返回值
  260. resp = client.update_experiment_content_with_options(experiment_id, update_experiment_content_request,
  261. headers, runtime)
  262. print(resp.body.to_map())
  263. except Exception as error:
  264. raise Exception(f"update_experiment_content error {error}")
  265. @staticmethod
  266. def create_job(experiment_id: str, node_id: str, execute_type: str):
  267. client = PAIClient.create_client()
  268. create_job_request = pai_studio_20210202_models.CreateJobRequest()
  269. create_job_request.experiment_id = experiment_id
  270. create_job_request.node_id = node_id
  271. create_job_request.execute_type = execute_type
  272. runtime = util_models.RuntimeOptions()
  273. headers = {}
  274. try:
  275. # 复制代码运行请自行打印 API 的返回值
  276. resp = client.create_job_with_options(create_job_request, headers, runtime)
  277. return resp.body.to_map()
  278. except Exception as error:
  279. raise Exception(f"create_job error {error}")
  280. @staticmethod
  281. def get_jobs_list(experiment_id: str, order='DESC'):
  282. client = PAIClient.create_client()
  283. list_jobs_request = pai_studio_20210202_models.ListJobsRequest(
  284. experiment_id=experiment_id,
  285. order=order
  286. )
  287. runtime = util_models.RuntimeOptions()
  288. headers = {}
  289. try:
  290. # 复制代码运行请自行打印 API 的返回值
  291. resp = client.list_jobs_with_options(list_jobs_request, headers, runtime)
  292. return resp.body.to_map()
  293. except Exception as error:
  294. raise Exception(f"get_jobs_list error {error}")
  295. @staticmethod
  296. def get_job_detail(job_id: str, verbose=False):
  297. client = PAIClient.create_client()
  298. get_job_request = pai_studio_20210202_models.GetJobRequest(
  299. verbose=verbose
  300. )
  301. runtime = util_models.RuntimeOptions()
  302. headers = {}
  303. try:
  304. # 复制代码运行请自行打印 API 的返回值
  305. resp = client.get_job_with_options(job_id, get_job_request, headers, runtime)
  306. return resp.body.to_map()
  307. except Exception as error:
  308. # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
  309. # 错误 message
  310. print(error.message)
  311. # 诊断地址
  312. print(error.data.get("Recommend"))
  313. UtilClient.assert_as_string(error.message)
  314. @staticmethod
  315. def get_flow_out_put(pipeline_run_id: str, node_id: str, depth: int):
  316. client = PAIClient.create_flow_client()
  317. list_pipeline_run_node_outputs_request = paiflow_20210202_models.ListPipelineRunNodeOutputsRequest(
  318. depth=depth
  319. )
  320. runtime = util_models.RuntimeOptions()
  321. headers = {}
  322. try:
  323. # 复制代码运行请自行打印 API 的返回值
  324. resp = client.list_pipeline_run_node_outputs_with_options(pipeline_run_id, node_id,
  325. list_pipeline_run_node_outputs_request, headers,
  326. runtime)
  327. return resp.body.to_map()
  328. except Exception as error:
  329. # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
  330. # 错误 message
  331. print(error.message)
  332. # 诊断地址
  333. print(error.data.get("Recommend"))
  334. UtilClient.assert_as_string(error.message)
  335. def extract_date_yyyymmdd(input_string):
  336. pattern = r'\d{8}'
  337. matches = re.findall(pattern, input_string)
  338. if matches:
  339. return matches[0]
  340. return None
  341. def get_online_version_dt(service_name: str):
  342. model_detail = PAIClient.get_describe_service(service_name)
  343. service_config_str = model_detail['ServiceConfig']
  344. service_config = json.loads(service_config_str)
  345. model_path = service_config['model_path']
  346. online_date = extract_date_yyyymmdd(model_path)
  347. return online_date
  348. def update_shuffle_flow(table):
  349. draft = PAIClient.get_work_flow_draft(experiment_id)
  350. print(json.dumps(draft, ensure_ascii=False))
  351. content = draft['Content']
  352. version = draft['Version']
  353. content_json = json.loads(content)
  354. nodes = content_json.get('nodes')
  355. for node in nodes:
  356. name = node['name']
  357. if name == '模型训练-样本shufle':
  358. properties = node['properties']
  359. for property in properties:
  360. if property['name'] == 'sql':
  361. value = property['value']
  362. new_value = update_train_table(value, table)
  363. if new_value is None:
  364. print("error")
  365. property['value'] = new_value
  366. new_content = json.dumps(content_json, ensure_ascii=False)
  367. PAIClient.update_experiment_content(experiment_id, new_content, version)
  368. def update_shuffle_flow_1():
  369. draft = PAIClient.get_work_flow_draft(experiment_id)
  370. print(json.dumps(draft, ensure_ascii=False))
  371. content = draft['Content']
  372. version = draft['Version']
  373. print(content)
  374. content_json = json.loads(content)
  375. nodes = content_json.get('nodes')
  376. for node in nodes:
  377. name = node['name']
  378. if name == '模型训练-样本shufle':
  379. properties = node['properties']
  380. for property in properties:
  381. if property['name'] == 'sql':
  382. value = property['value']
  383. new_value = update_train_tables(value)
  384. if new_value is None:
  385. print("error")
  386. property['value'] = new_value
  387. new_content = json.dumps(content_json, ensure_ascii=False)
  388. PAIClient.update_experiment_content(experiment_id, new_content, version)
  389. def wait_job_end(job_id: str):
  390. while True:
  391. job_detail = PAIClient.get_job_detail(job_id)
  392. print(job_detail)
  393. statue = job_detail['Status']
  394. # Initialized: 初始化完成 Starting:开始 WorkflowServiceStarting:准备提交 Running:运行中 ReadyToSchedule:准备运行(前序节点未完成导致)
  395. if (statue == 'Initialized' or statue == 'Starting' or statue == 'WorkflowServiceStarting'
  396. or statue == 'Running' or statue == 'ReadyToSchedule'):
  397. # 睡眠300s 等待下次获取
  398. time.sleep(300)
  399. continue
  400. # Failed:运行失败 Terminating:终止中 Terminated:已终止 Unknown:未知 Skipped:跳过(前序节点失败导致) Succeeded:运行成功
  401. if statue == 'Failed' or statue == 'Terminating' or statue == 'Unknown' or statue == 'Skipped' or statue == 'Succeeded':
  402. return job_detail
  403. def get_node_dict():
  404. draft = PAIClient.get_work_flow_draft(experiment_id)
  405. content = draft['Content']
  406. content_json = json.loads(content)
  407. nodes = content_json.get('nodes')
  408. node_dict = {}
  409. for node in nodes:
  410. name = node['name']
  411. # 检查名称是否在目标名称集合中
  412. if name in target_names:
  413. node_dict[name] = node['id']
  414. return node_dict
  415. def get_job_dict():
  416. job_dict = {}
  417. jobs_list = PAIClient.get_jobs_list(experiment_id)
  418. for job in jobs_list['Jobs']:
  419. # 解析时间字符串为 datetime 对象
  420. if not compare_timestamp_with_today_start(job['GmtCreateTime']):
  421. break
  422. job_id = job['JobId']
  423. job_detail = PAIClient.get_job_detail(job_id, verbose=True)
  424. for name in target_names:
  425. if job_detail['Status'] != 'Succeeded':
  426. continue
  427. if name in job_dict:
  428. continue
  429. if name in job_detail['RunInfo']:
  430. job_dict[name] = job_detail['JobId']
  431. return job_dict
  432. @retry
  433. def update_online_flow():
  434. try:
  435. online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
  436. draft = PAIClient.get_work_flow_draft(experiment_id)
  437. print(json.dumps(draft, ensure_ascii=False))
  438. content = draft['Content']
  439. version = draft['Version']
  440. print(content)
  441. content_json = json.loads(content)
  442. nodes = content_json.get('nodes')
  443. global_params = content_json.get('globalParams')
  444. bizdate = get_previous_days_date(1)
  445. for global_param in global_params:
  446. try:
  447. if global_param['name'] == 'bizdate':
  448. global_param['value'] = bizdate
  449. if global_param['name'] == 'online_version_dt':
  450. global_param['value'] = online_version_dt
  451. if global_param['name'] == 'eval_date':
  452. global_param['value'] = bizdate
  453. except KeyError:
  454. raise Exception("在处理全局参数时,字典中缺少必要的键")
  455. for node in nodes:
  456. try:
  457. name = node['name']
  458. if name == '样本shuffle':
  459. properties = node['properties']
  460. for property in properties:
  461. if property['name'] == 'sql':
  462. value = property['value']
  463. new_value = update_train_tables(value)
  464. if new_value is None:
  465. print("error")
  466. property['value'] = new_value
  467. except KeyError:
  468. raise Exception("在处理节点属性时,字典中缺少必要的键")
  469. new_content = json.dumps(content_json, ensure_ascii=False)
  470. PAIClient.update_experiment_content(experiment_id, new_content, version)
  471. return True
  472. except json.JSONDecodeError:
  473. raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
  474. except Exception as e:
  475. raise Exception(f"发生未知错误: {e}")
  476. @retry
  477. def shuffle_table():
  478. try:
  479. node_dict = get_node_dict()
  480. train_node_id = node_dict['样本shuffle']
  481. execute_type = 'EXECUTE_ONE'
  482. validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  483. validate_job_id = validate_res['JobId']
  484. validate_job_detail = wait_job_end(validate_job_id)
  485. if validate_job_detail['Status'] == 'Succeeded':
  486. return True
  487. return False
  488. except Exception as e:
  489. error_message = f"在执行 shuffle_table 函数时发生异常: {str(e)}"
  490. print(error_message)
  491. raise Exception(error_message)
  492. @retry
  493. def shuffle_train_model():
  494. try:
  495. node_dict = get_node_dict()
  496. job_dict = get_job_dict()
  497. job_id = job_dict['样本shuffle']
  498. validate_job_detail = wait_job_end(job_id)
  499. if validate_job_detail['Status'] == 'Succeeded':
  500. pipeline_run_id = validate_job_detail['RunId']
  501. node_id = validate_job_detail['PaiflowNodeId']
  502. flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
  503. out_puts = flow_out_put_detail['Outputs']
  504. table = None
  505. for out_put in out_puts:
  506. if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
  507. value1 = json.loads(out_put["Info"]['value'])
  508. table = value1['location']['table']
  509. if table is not None:
  510. update_shuffle_flow(table)
  511. node_dict = get_node_dict()
  512. train_node_id = node_dict['模型训练-样本shufle']
  513. execute_type = 'EXECUTE_ONE'
  514. train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  515. train_job_id = train_res['JobId']
  516. train_job_detail = wait_job_end(train_job_id)
  517. if train_job_detail['Status'] == 'Succeeded':
  518. return True
  519. return False
  520. except Exception as e:
  521. error_message = f"在执行 shuffle_train_model 函数时发生异常: {str(e)}"
  522. print(error_message)
  523. raise Exception(error_message)
  524. @retry
  525. def export_model():
  526. try:
  527. node_dict = get_node_dict()
  528. export_node_id = node_dict['模型导出-2']
  529. execute_type = 'EXECUTE_ONE'
  530. export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
  531. export_job_id = export_res['JobId']
  532. export_job_detail = wait_job_end(export_job_id)
  533. if export_job_detail['Status'] == 'Succeeded':
  534. return True
  535. return False
  536. except Exception as e:
  537. error_message = f"在执行 export_model 函数时发生异常: {str(e)}"
  538. print(error_message)
  539. raise Exception(error_message)
  540. def update_online_model():
  541. try:
  542. node_dict = get_node_dict()
  543. train_node_id = node_dict['更新EAS服务(Beta)-1']
  544. execute_type = 'EXECUTE_ONE'
  545. train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  546. train_job_id = train_res['JobId']
  547. train_job_detail = wait_job_end(train_job_id)
  548. if train_job_detail['Status'] == 'Succeeded':
  549. return True
  550. return False
  551. except Exception as e:
  552. error_message = f"在执行 update_online_model 函数时发生异常: {str(e)}"
  553. print(error_message)
  554. raise Exception(error_message)
  555. @retry
  556. def get_validate_model_data():
  557. try:
  558. node_dict = get_node_dict()
  559. train_node_id = node_dict['虚拟起始节点']
  560. execute_type = 'EXECUTE_FROM_HERE'
  561. validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
  562. validate_job_id = validate_res['JobId']
  563. validate_job_detail = wait_job_end(validate_job_id)
  564. if validate_job_detail['Status'] == 'Succeeded':
  565. return True
  566. return False
  567. except Exception as e:
  568. error_message = f"在执行 get_validate_model_data 函数时出现异常: {e}"
  569. print(error_message)
  570. raise Exception(error_message)
  571. def validate_model_data_accuracy():
  572. try:
  573. table_dict = {}
  574. node_dict = get_node_dict()
  575. job_dict = get_job_dict()
  576. job_id = job_dict['虚拟起始节点']
  577. validate_job_detail = wait_job_end(job_id)
  578. if validate_job_detail['Status'] == 'Succeeded':
  579. pipeline_run_id = validate_job_detail['RunId']
  580. node_id = validate_job_detail['PaiflowNodeId']
  581. flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
  582. print(flow_out_put_detail)
  583. out_puts = flow_out_put_detail['Outputs']
  584. for out_put in out_puts:
  585. if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
  586. value1 = json.loads(out_put["Info"]['value'])
  587. table_dict['二分类评估-1'] = value1['location']['table']
  588. if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
  589. value2 = json.loads(out_put["Info"]['value'])
  590. table_dict['二分类评估-2'] = value2['location']['table']
  591. if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
  592. value3 = json.loads(out_put["Info"]['value'])
  593. table_dict['预测结果对比'] = value3['location']['table']
  594. num = 10
  595. df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)
  596. # 对指定列取绝对值再求和
  597. old_abs_avg = df['old_error'].abs().sum() / num
  598. new_abs_avg = df['new_error'].abs().sum() / num
  599. new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
  600. old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
  601. bizdate = get_previous_days_date(1)
  602. score_diff = abs(old_abs_avg - new_abs_avg)
  603. msg = ""
  604. result = False
  605. if new_abs_avg > 0.1:
  606. msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
  607. level = 'error'
  608. elif score_diff > 0.05:
  609. msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
  610. level = 'error'
  611. else:
  612. msg += 'DNN广告模型更新完成'
  613. level = 'info'
  614. result = True
  615. # 初始化表格头部
  616. top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
  617. top10_msg += "\n| ---- | --------- | -------- |"
  618. for index, row in df.iterrows():
  619. # 获取指定列的元素
  620. cid = row['cid']
  621. old_error = row['old_error']
  622. new_error = row['new_error']
  623. top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
  624. print(top10_msg)
  625. msg += f"\n\t - 老模型AUC: {old_auc}"
  626. msg += f"\n\t - 新模型AUC: {new_auc}"
  627. msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
  628. msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
  629. return result, msg, level, top10_msg
  630. except Exception as e:
  631. error_message = f"在执行 validate_model_data_accuracy 函数时出现异常: {str(e)}"
  632. print(error_message)
  633. raise Exception(error_message)
  634. if __name__ == '__main__':
  635. start_time = int(time.time())
  636. functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
  637. function_names = [func.__name__ for func in functions]
  638. start_function = None
  639. if len(sys.argv) > 1:
  640. start_function = sys.argv[1]
  641. if start_function not in function_names:
  642. print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
  643. sys.exit(1)
  644. start_index = 0
  645. if start_function:
  646. start_index = function_names.index(start_function)
  647. for func in functions[start_index:]:
  648. if not func():
  649. print(f"{func.__name__} 执行失败,后续函数不再执行。")
  650. step_end_time = int(time.time())
  651. elapsed = step_end_time - start_time
  652. _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
  653. break
  654. else:
  655. print("所有函数都成功执行,可以继续下一步操作。")
  656. result, msg, level, top10_msg = validate_model_data_accuracy()
  657. if result:
  658. # update_online_model()
  659. print("success")
  660. step_end_time = int(time.time())
  661. elapsed = step_end_time - start_time
  662. print(level, msg, start_time, elapsed, top10_msg)
  663. _monitor(level, msg, start_time, elapsed, top10_msg)