# -*- coding: utf-8 -*-
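"""pai_flow_operator.py: daily workflow operator for the ad-ranking DNN (EasyRec) on Alibaba Cloud PAI.

As written below, the script updates a PAI-Studio experiment draft with the latest
training-table partitions, runs the training and model-export nodes, and can then
validate the new model against the online EAS model (AUC and Top-10 error comparison
read back from MaxCompute/ODPS tables) before reporting the result through _monitor.
"""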
import os
import re
import sys
from typing import List
import time
import json
import pandas as pd
from alibabacloud_paistudio20210202.client import Client as PaiStudio20210202Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_paistudio20210202 import models as pai_studio_20210202_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient
from alibabacloud_eas20210701.client import Client as eas20210701Client
from alibabacloud_paiflow20210202 import models as paiflow_20210202_models
from alibabacloud_paiflow20210202.client import Client as PAIFlow20210202Client
from datetime import datetime, timedelta
from odps import ODPS
from ad_monitor_util import _monitor

# Names of the workflow nodes this script needs to look up by id.
target_names = {
    '样本shuffle',
    '模型训练-样本shufle',
    '模型训练-自定义',
    '模型增量训练',
    '模型导出-2',
    '更新EAS服务(Beta)-1',
    '虚拟起始节点',
    '二分类评估-1',
    '二分类评估-2',
    '预测结果对比'
}

# Draft id of the main PAI-Studio experiment operated on by this script.
experiment_id = "draft-kbezr8f0q3cpee9eqc"


def get_odps_instance(project):
    odps = ODPS(
        access_id='LTAIWYUujJAm7CbH',
        secret_access_key='RfSjdiWwED1sGFlsjXv0DlfTnZTG1P',
        project=project,
        endpoint='http://service.cn.maxcompute.aliyun.com/api',
    )
    return odps


def get_data_from_odps(project, table, num):
    odps = get_odps_instance(project)
    try:
        # SQL statement to run
        sql = f'select * from {table} limit {num}'
        # Execute the SQL query
        with odps.execute_sql(sql).open_reader() as reader:
            # Return None when the result has fewer rows than requested
            if reader.count < num:
                return None
            # Column names
            column_names = reader.schema.names
            # Collect the result rows
            data = []
            for record in reader:
                record_list = list(record)
                numbers = []
                for item in record_list:
                    numbers.append(item[1])
                data.append(numbers)
            # Combine rows and column names into a DataFrame
            df = pd.DataFrame(data, columns=column_names)
            return df
    except Exception as e:
        print(f"发生错误: {e}")


def get_dict_from_odps(project, table):
    odps = get_odps_instance(project)
    try:
        # SQL statement to run
        sql = f'select * from {table}'
        # Execute the SQL query and build a {first_column: second_column} dict
        with odps.execute_sql(sql).open_reader() as reader:
            data = {}
            for record in reader:
                record_list = list(record)
                key = record_list[0][1]
                value = record_list[1][1]
                data[key] = value
            return data
    except Exception as e:
        print(f"发生错误: {e}")


def get_dates_between(start_date_str, end_date_str):
    start_date = datetime.strptime(start_date_str, '%Y%m%d')
    end_date = datetime.strptime(end_date_str, '%Y%m%d')
    dates = []
    current_date = start_date
    while current_date <= end_date:
        dates.append(current_date.strftime('%Y%m%d'))
        current_date += timedelta(days=1)
    return dates


def read_file_to_list():
    try:
        current_dir = os.getcwd()
        file_path = os.path.join(current_dir, 'holidays.txt')
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        return content.split('\n')
    except FileNotFoundError:
        print(f"错误:未找到 {file_path} 文件。")
    except Exception as e:
        print(f"错误:发生了一个未知错误: {e}")
    return []


def get_previous_days_date(days):
    current_date = datetime.now()
    previous_date = current_date - timedelta(days=days)
    return previous_date.strftime('%Y%m%d')


def remove_elements(lst1, lst2):
    return [element for element in lst1 if element not in lst2]


def process_list(lst, append_str):
    # Prepend the same string to every element of the list
    appended_list = [append_str + element for element in lst]
    # Join the results into one comma-separated string
    result_str = ','.join(appended_list)
    return result_str


def get_train_data_list():
    start_date = '20250223'
    end_date = get_previous_days_date(2)
    date_list = get_dates_between(start_date, end_date)
    filter_date_list = read_file_to_list()
    date_list = remove_elements(date_list, filter_date_list)
    return date_list


def update_train_tables(old_str):
    date_list = get_train_data_list()
    address = 'odps://loghubods/tables/ad_easyrec_train_data_v3_sampled/dt='
    train_tables = process_list(date_list, address)
    start_index = old_str.find('-Dtrain_tables="')
    if start_index != -1:
        # Position right after the opening double quote
        equal_sign_index = start_index + len('-Dtrain_tables="')
        # Position of the closing double quote
        next_quote_index = old_str.find('"', equal_sign_index)
        if next_quote_index != -1:
            # Splice the refreshed table list in between the quotes
            new_value = old_str[:equal_sign_index] + train_tables + old_str[next_quote_index:]
            return new_value
    return None


def new_update_train_tables(old_str):
    date_list = get_train_data_list()
    train_list = ["'" + item + "'" for item in date_list]
    result = ','.join(train_list)
    start_index = old_str.find('where dt in (')
    if start_index != -1:
        equal_sign_index = start_index + len('where dt in (')
        # Position of the closing parenthesis
        next_quote_index = old_str.find(')', equal_sign_index)
        if next_quote_index != -1:
            # Splice the refreshed date list in between the parentheses
            new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
            return new_value
    return None
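

# Illustrative example (not executed): with the dates produced by get_train_data_list(),
# update_train_tables() rewrites the training command's table argument to something like
#   -Dtrain_tables="odps://loghubods/tables/ad_easyrec_train_data_v3_sampled/dt=20250223,odps://loghubods/tables/ad_easyrec_train_data_v3_sampled/dt=20250224,..."
# while new_update_train_tables() rewrites the sample-shuffle SQL's partition filter to
#   where dt in ('20250223','20250224',...)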


class PAIClient:
    def __init__(self):
        pass

    @staticmethod
    def create_client() -> PaiStudio20210202Client:
        """
        Initialize the PAI-Studio client with an AccessKey ID & Secret.
        @return: Client
        @throws Exception
        """
        # Leaking project code may leak the AccessKey and endanger every resource under the
        # account; the code below is for reference only.
        # The STS approach is more secure; for other authentication options see
        # https://help.aliyun.com/document_detail/378659.html.
        config = open_api_models.Config(
            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
        )
        # Endpoint reference: https://api.aliyun.com/product/PaiStudio
        config.endpoint = 'pai.cn-hangzhou.aliyuncs.com'
        return PaiStudio20210202Client(config)

    @staticmethod
    def create_eas_client() -> eas20210701Client:
        """
        Initialize the EAS client with an AccessKey ID & Secret.
        @return: Client
        @throws Exception
        """
        # Leaking project code may leak the AccessKey and endanger every resource under the
        # account; the code below is for reference only.
        # The STS approach is more secure; for other authentication options see
        # https://help.aliyun.com/document_detail/378659.html.
        config = open_api_models.Config(
            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
        )
        # Endpoint reference: https://api.aliyun.com/product/PaiStudio
        config.endpoint = 'pai-eas.cn-hangzhou.aliyuncs.com'
        return eas20210701Client(config)

    @staticmethod
    def create_flow_client() -> PAIFlow20210202Client:
        """
        Initialize the PAIFlow client with an AccessKey ID & Secret.
        @return: Client
        @throws Exception
        """
        # Leaking project code may leak the AccessKey and endanger every resource under the
        # account; the code below is for reference only.
        # The STS approach is more secure; for other authentication options see
        # https://help.aliyun.com/document_detail/378659.html.
        config = open_api_models.Config(
            # Required; preferably read from the ALIBABA_CLOUD_ACCESS_KEY_ID environment variable.
            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
            # Required; preferably read from the ALIBABA_CLOUD_ACCESS_KEY_SECRET environment variable.
            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
        )
        # Endpoint reference: https://api.aliyun.com/product/PAIFlow
        config.endpoint = 'paiflow.cn-hangzhou.aliyuncs.com'
        return PAIFlow20210202Client(config)

    @staticmethod
    def get_work_flow_draft_list(workspace_id: str):
        client = PAIClient.create_client()
        list_experiments_request = pai_studio_20210202_models.ListExperimentsRequest(
            workspace_id=workspace_id
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            print(error.message)
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)

    @staticmethod
    def get_work_flow_draft(experiment_id: str):
        client = PAIClient.create_client()
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # Print the API response yourself when running a copy of this code.
            resp = client.get_experiment_with_options(experiment_id, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            # Printing here is only for demonstration; handle exceptions carefully and
            # never simply swallow them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)

    @staticmethod
    def get_describe_service(service_name: str):
        client = PAIClient.create_eas_client()
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # Print the API response yourself when running a copy of this code.
            resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            # Printing here is only for demonstration; handle exceptions carefully and
            # never simply swallow them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)

    @staticmethod
    def update_experiment_content(experiment_id: str, content: str, version: int):
        client = PAIClient.create_client()
        update_experiment_content_request = pai_studio_20210202_models.UpdateExperimentContentRequest(
            content=content, version=version)
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # Print the API response yourself when running a copy of this code.
            resp = client.update_experiment_content_with_options(experiment_id, update_experiment_content_request,
                                                                 headers, runtime)
            print(resp.body.to_map())
        except Exception as error:
            # Printing here is only for demonstration; handle exceptions carefully and
            # never simply swallow them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)

    @staticmethod
    def create_job(experiment_id: str, node_id: str, execute_type: str):
        client = PAIClient.create_client()
        create_job_request = pai_studio_20210202_models.CreateJobRequest()
        create_job_request.experiment_id = experiment_id
        create_job_request.node_id = node_id
        create_job_request.execute_type = execute_type
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # Print the API response yourself when running a copy of this code.
            resp = client.create_job_with_options(create_job_request, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            # Printing here is only for demonstration; handle exceptions carefully and
            # never simply swallow them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)

    @staticmethod
    def get_job_detail(job_id: str):
        client = PAIClient.create_client()
        get_job_request = pai_studio_20210202_models.GetJobRequest(
            verbose=False
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # Print the API response yourself when running a copy of this code.
            resp = client.get_job_with_options(job_id, get_job_request, headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            # Printing here is only for demonstration; handle exceptions carefully and
            # never simply swallow them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)

    @staticmethod
    def get_flow_out_put(pipeline_run_id: str, node_id: str, depth: int):
        client = PAIClient.create_flow_client()
        list_pipeline_run_node_outputs_request = paiflow_20210202_models.ListPipelineRunNodeOutputsRequest(
            depth=depth
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            # Print the API response yourself when running a copy of this code.
            resp = client.list_pipeline_run_node_outputs_with_options(pipeline_run_id, node_id,
                                                                      list_pipeline_run_node_outputs_request,
                                                                      headers, runtime)
            return resp.body.to_map()
        except Exception as error:
            # Printing here is only for demonstration; handle exceptions carefully and
            # never simply swallow them in production code.
            # Error message
            print(error.message)
            # Diagnostic URL
            print(error.data.get("Recommend"))
            UtilClient.assert_as_string(error.message)
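

# Minimal usage sketch (assumes the credentials and draft ids above are valid); this mirrors
# what train_model() below already does:
#   node_dict = get_node_dict()                                   # node name -> node id
#   res = PAIClient.create_job(experiment_id, node_dict['模型训练-自定义'], 'EXECUTE_ONE')
#   detail = wait_job_end(res['JobId'])                           # poll until a terminal status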


def extract_date_yyyymmdd(input_string):
    pattern = r'\d{8}'
    matches = re.findall(pattern, input_string)
    if matches:
        return matches[0]
    return None


def get_online_version_dt(service_name: str):
    model_detail = PAIClient.get_describe_service(service_name)
    service_config_str = model_detail['ServiceConfig']
    service_config = json.loads(service_config_str)
    model_path = service_config['model_path']
    online_date = extract_date_yyyymmdd(model_path)
    return online_date


def update_online_flow():
    online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
    draft = PAIClient.get_work_flow_draft(experiment_id)
    print(json.dumps(draft, ensure_ascii=False))
    content = draft['Content']
    version = draft['Version']
    print(content)
    content_json = json.loads(content)
    nodes = content_json.get('nodes')
    global_params = content_json.get('globalParams')
    bizdate = get_previous_days_date(1)
    for global_param in global_params:
        if global_param['name'] == 'bizdate':
            global_param['value'] = bizdate
        if global_param['name'] == 'online_version_dt':
            global_param['value'] = online_version_dt
        if global_param['name'] == 'eval_date':
            global_param['value'] = bizdate
    for node in nodes:
        name = node['name']
        if name == '模型训练-自定义':
            properties = node['properties']
            for property in properties:
                if property['name'] == 'sql':
                    value = property['value']
                    new_value = update_train_tables(value)
                    if new_value is None:
                        print("error")
                    property['value'] = new_value
    new_content = json.dumps(content_json, ensure_ascii=False)
    PAIClient.update_experiment_content(experiment_id, new_content, version)


def update_online_new_flow():
    online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
    draft = PAIClient.get_work_flow_draft(experiment_id)
    print(json.dumps(draft, ensure_ascii=False))
    content = draft['Content']
    version = draft['Version']
    print(content)
    content_json = json.loads(content)
    nodes = content_json.get('nodes')
    global_params = content_json.get('globalParams')
    bizdate = get_previous_days_date(1)
    for global_param in global_params:
        if global_param['name'] == 'bizdate':
            global_param['value'] = bizdate
        if global_param['name'] == 'online_version_dt':
            global_param['value'] = online_version_dt
        if global_param['name'] == 'eval_date':
            global_param['value'] = bizdate
    for node in nodes:
        name = node['name']
        if name == '样本shuffle':
            properties = node['properties']
            for property in properties:
                if property['name'] == 'sql':
                    value = property['value']
                    new_value = new_update_train_tables(value)
                    if new_value is None:
                        print("error")
                    property['value'] = new_value
    new_content = json.dumps(content_json, ensure_ascii=False)
    PAIClient.update_experiment_content(experiment_id, new_content, version)


def wait_job_end(job_id: str):
    while True:
        job_detail = PAIClient.get_job_detail(job_id)
        print(job_detail)
        status = job_detail['Status']
        # In progress: Initialized, Starting, WorkflowServiceStarting (about to submit),
        # Running, ReadyToSchedule (waiting on unfinished upstream nodes) -> keep polling.
        if (status == 'Initialized' or status == 'Starting' or status == 'WorkflowServiceStarting'
                or status == 'Running' or status == 'ReadyToSchedule'):
            # Sleep 300 s before the next poll
            time.sleep(300)
            continue
        # Terminal: Failed, Terminating, Terminated, Unknown, Skipped (upstream failure), Succeeded.
        if (status == 'Failed' or status == 'Terminating' or status == 'Terminated'
                or status == 'Unknown' or status == 'Skipped' or status == 'Succeeded'):
            return job_detail


def get_node_dict():
    experiment_id = "draft-7u3e9v1uc5pohjl0t6"
    draft = PAIClient.get_work_flow_draft(experiment_id)
    content = draft['Content']
    content_json = json.loads(content)
    nodes = content_json.get('nodes')
    node_dict = {}
    for node in nodes:
        name = node['name']
        # Only keep nodes whose name is in the target set
        if name in target_names:
            node_dict[name] = node['id']
    return node_dict


def train_model():
    node_dict = get_node_dict()
    train_node_id = node_dict['模型训练-自定义']
    execute_type = 'EXECUTE_ONE'
    train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
    train_job_id = train_res['JobId']
    train_job_detail = wait_job_end(train_job_id)
    if train_job_detail['Status'] == 'Succeeded':
        export_node_id = node_dict['模型导出-2']
        export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
        export_job_id = export_res['JobId']
        export_job_detail = wait_job_end(export_job_id)
        if export_job_detail['Status'] == 'Succeeded':
            return True
    return False


def update_online_model():
    node_dict = get_node_dict()
    experiment_id = "draft-7u3e9v1uc5pohjl0t6"
    train_node_id = node_dict['更新EAS服务(Beta)-1']
    execute_type = 'EXECUTE_ONE'
    train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
    train_job_id = train_res['JobId']
    train_job_detail = wait_job_end(train_job_id)
    if train_job_detail['Status'] == 'Succeeded':
        return True
    return False


def validate_model_data_accuracy():
    node_dict = get_node_dict()
    train_node_id = node_dict['虚拟起始节点']
    execute_type = 'EXECUTE_FROM_HERE'
    validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
    print(validate_res)
    validate_job_id = validate_res['JobId']
    print(validate_job_id)
    validate_job_detail = wait_job_end(validate_job_id)
    print('res')
    print(validate_job_detail)
    if validate_job_detail['Status'] == 'Succeeded':
        pipeline_run_id = validate_job_detail['RunId']
        node_id = validate_job_detail['PaiflowNodeId']
        flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
        print(flow_out_put_detail)
        table_dict = {}
        out_puts = flow_out_put_detail['Outputs']
        for out_put in out_puts:
            if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
                value1 = json.loads(out_put["Info"]['value'])
                table_dict['二分类评估-1'] = value1['location']['table']
            if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
                value2 = json.loads(out_put["Info"]['value'])
                table_dict['二分类评估-2'] = value2['location']['table']
            if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
                value3 = json.loads(out_put["Info"]['value'])
                table_dict['预测结果对比'] = value3['location']['table']
        num = 10
        df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], num)
        # Sum of the absolute errors divided by num (mean absolute error)
        old_abs_avg = df['old_error'].abs().sum() / num
        new_abs_avg = df['new_error'].abs().sum() / num
        new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
        old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
        bizdate = get_previous_days_date(1)
        score_diff = abs(old_abs_avg - new_abs_avg)
        msg = ""
        level = ""
        if new_abs_avg > 0.1:
            msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
            level = 'error'
        elif score_diff > 0.05:
            msg += f'两个模型评估{bizdate}的数据,两个模型分数差异为: {score_diff}, 大于0.05, 请检查'
            level = 'error'
        else:
            # update_online_model()
            msg += 'DNN广告模型更新完成'
            level = 'info'
        step_end_time = int(time.time())
        elapsed = step_end_time - start_time
        # Header of the Top-10 comparison table
        top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
        top10_msg += "\n| ---- | --------- | -------- |"
        for index, row in df.iterrows():
            # Pull the needed columns from each row
            cid = row['cid']
            old_error = row['old_error']
            new_error = row['new_error']
            top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
        print(top10_msg)
        msg += f"\n\t - 老模型AUC: {old_auc}"
        msg += f"\n\t - 新模型AUC: {new_auc}"
        msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
        msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
        _monitor(level, msg, start_time, elapsed, top10_msg)


if __name__ == '__main__':
    start_time = int(time.time())
    # 1. Update the workflow draft
    update_online_new_flow()
    # 2. Train the model
    train_model()
    # 3. Validate the model data & push the model online
    # validate_model_data_accuracy()

    # start_time = int(time.time())
    # node_dict = get_node_dict()
    # str = '{"Creator": "204034041838504386", "ExecuteType": "EXECUTE_FROM_HERE", "ExperimentId": "draft-7u3e9v1uc5pohjl0t6", "GmtCreateTime": "2025-04-01T03:17:42.000+00:00", "JobId": "job-8u3ev2uf5ncoexj9p9", "PaiflowNodeId": "node-9wtveoz1tu89tqfoox", "RequestId": "6ED5FFB1-346B-5075-ACC9-029EB77E9F09", "RunId": "flow-lchat027733ttstdc0", "Status": "Succeeded", "WorkspaceId": "96094"}'
    # validate_job_detail = json.loads(str)
    # if validate_job_detail['Status'] == 'Succeeded':
    #     pipeline_run_id = validate_job_detail['RunId']
    #     node_id = validate_job_detail['PaiflowNodeId']
    #     flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
    #     print(flow_out_put_detail)
    #     table_dict = {}
    #     out_puts = flow_out_put_detail['Outputs']
    #     for out_put in out_puts:
    #         if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
    #             value1 = json.loads(out_put["Info"]['value'])
    #             table_dict['二分类评估-1'] = value1['location']['table']
    #         if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
    #             value2 = json.loads(out_put["Info"]['value'])
    #             table_dict['二分类评估-2'] = value2['location']['table']
    #         if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
    #             value3 = json.loads(out_put["Info"]['value'])
    #             table_dict['预测结果对比'] = value3['location']['table']
    #
    #     num = 10
    #     df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)
    #     # Sum of the absolute errors divided by num (mean absolute error)
    #     old_abs_avg = df['old_error'].abs().sum() / num
    #     new_abs_avg = df['new_error'].abs().sum() / num
    #     new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
    #     old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
    #     bizdate = get_previous_days_date(1)
    #     score_diff = abs(old_abs_avg - new_abs_avg)
    #     msg = ""
    #     level = ""
    #     if new_abs_avg > 0.1:
    #         msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
    #         level = 'error'
    #     elif score_diff > 0.05:
    #         msg += f'两个模型评估{bizdate}的数据,两个模型分数差异为: {score_diff}, 大于0.05, 请检查'
    #         level = 'error'
    #     else:
    #         # TODO push the new model online
    #         msg += 'DNN广告模型更新完成'
    #         level = 'info'
    #     step_end_time = int(time.time())
    #     elapsed = step_end_time - start_time
    #
    #     # Header of the Top-10 comparison table
    #     top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
    #     top10_msg += "\n| ---- | --------- | -------- |"
    #
    #     for index, row in df.iterrows():
    #         # Pull the needed columns from each row
    #         cid = row['cid']
    #         old_error = row['old_error']
    #         new_error = row['new_error']
    #         top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
    #     print(top10_msg)
    #     msg += f"\n\t - 老模型AUC: {old_auc}"
    #     msg += f"\n\t - 新模型AUC: {new_auc}"
    #     msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
    #     msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
    #     _monitor('info', msg, start_time, elapsed, top10_msg)