|
@@ -388,6 +388,13 @@ def extract_date_yyyymmdd(input_string):
|
|
return matches[0]
|
|
return matches[0]
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
+def extract_model_version(input_string):
|
|
|
|
+ pattern = r"ad_rank_dnn_([^/]+)/\d{8}"
|
|
|
|
+ matches = re.findall(pattern, input_string)
|
|
|
|
+ if matches:
|
|
|
|
+ return matches[0]
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
|
|
def get_online_model_config(service_name: str):
|
|
def get_online_model_config(service_name: str):
|
|
model_config = {}
|
|
model_config = {}
|
|
@@ -398,6 +405,8 @@ def get_online_model_config(service_name: str):
|
|
model_config['model_path'] = model_path
|
|
model_config['model_path'] = model_path
|
|
online_date = extract_date_yyyymmdd(model_path)
|
|
online_date = extract_date_yyyymmdd(model_path)
|
|
model_config['online_date'] = online_date
|
|
model_config['online_date'] = online_date
|
|
|
|
+ model_version = extract_model_version(model_path)
|
|
|
|
+ model_config['model_version'] = model_version
|
|
return model_config
|
|
return model_config
|
|
|
|
|
|
|
|
|
|
@@ -732,10 +741,13 @@ def update_trained_cids_pointer(model_name=None, dt_version=None):
|
|
elif not (model_name and dt_version):
|
|
elif not (model_name and dt_version):
|
|
# 不允许其中一个为空
|
|
# 不允许其中一个为空
|
|
raise Exception("model_name 和 dt_version 必须同时提供")
|
|
raise Exception("model_name 和 dt_version 必须同时提供")
|
|
- oss_content = f"oss://art-recommend/fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt_version}"
|
|
|
|
-
|
|
|
|
|
|
+ model_version = {}
|
|
|
|
+ model_version['modelName'] = f"model_name={model_name}"
|
|
|
|
+ model_version['dtVersion'] = f"dt_version={dt_version}"
|
|
|
|
+ model_version['timestamp'] = int(time.time())
|
|
|
|
+ print(json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8'))
|
|
bucket_name = "art-recommend"
|
|
bucket_name = "art-recommend"
|
|
- object_key = "fengzhoutian/ad_engine_files/pai_model_trained_cids_pointer"
|
|
|
|
|
|
+ object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
|
|
|
|
|
|
oss_config = oss.config.load_default()
|
|
oss_config = oss.config.load_default()
|
|
oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
|
|
oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
|
|
@@ -746,7 +758,7 @@ def update_trained_cids_pointer(model_name=None, dt_version=None):
|
|
ret = client.put_object(oss.PutObjectRequest(
|
|
ret = client.put_object(oss.PutObjectRequest(
|
|
bucket=bucket_name,
|
|
bucket=bucket_name,
|
|
key=object_key,
|
|
key=object_key,
|
|
- body=oss_content.encode('utf-8')
|
|
|
|
|
|
+ body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
|
|
))
|
|
))
|
|
print(f'status code: {ret.status_code},'
|
|
print(f'status code: {ret.status_code},'
|
|
f' request id: {ret.request_id},'
|
|
f' request id: {ret.request_id},'
|
|
@@ -784,8 +796,11 @@ if __name__ == '__main__':
|
|
print("所有函数都成功执行,可以继续下一步操作。")
|
|
print("所有函数都成功执行,可以继续下一步操作。")
|
|
result, msg, level, top10_msg = validate_model_data_accuracy()
|
|
result, msg, level, top10_msg = validate_model_data_accuracy()
|
|
if result:
|
|
if result:
|
|
- update_online_model()
|
|
|
|
- print("success")
|
|
|
|
|
|
+ update_online_res = update_online_model()
|
|
|
|
+ if update_online_res:
|
|
|
|
+ online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
|
|
|
|
+ update_trained_cids_pointer(online_model_config['model_version'], online_model_config['online_date'])
|
|
|
|
+ print("success")
|
|
step_end_time = int(time.time())
|
|
step_end_time = int(time.time())
|
|
elapsed = step_end_time - start_time
|
|
elapsed = step_end_time - start_time
|
|
print(level, msg, start_time, elapsed, top10_msg)
|
|
print(level, msg, start_time, elapsed, top10_msg)
|