Selaa lähdekoodia

模型更新后,更新cid列表版本

xueyiming 2 päivää sitten
vanhempi
commit
5be9866c3b
1 muutettua tiedostoa jossa 21 lisäystä ja 6 poistoa
  1. 21 6
      ad/pai_flow_operator.py

+ 21 - 6
ad/pai_flow_operator.py

@@ -388,6 +388,13 @@ def extract_date_yyyymmdd(input_string):
         return matches[0]
     return None
 
+def extract_model_version(input_string):
+    pattern = r"ad_rank_dnn_([^/]+)/\d{8}"
+    matches = re.findall(pattern, input_string)
+    if matches:
+        return matches[0]
+    return None
+
 
 def get_online_model_config(service_name: str):
     model_config = {}
@@ -398,6 +405,8 @@ def get_online_model_config(service_name: str):
     model_config['model_path'] = model_path
     online_date = extract_date_yyyymmdd(model_path)
     model_config['online_date'] = online_date
+    model_version = extract_model_version(model_path)
+    model_config['model_version'] = model_version
     return model_config
 
 
@@ -732,10 +741,13 @@ def update_trained_cids_pointer(model_name=None, dt_version=None):
     elif not (model_name and dt_version):
         # 不允许其中一个为空
         raise Exception("model_name 和 dt_version 必须同时提供")
-    oss_content = f"oss://art-recommend/fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt_version}"
-
+    model_version = {}
+    model_version['modelName'] = f"model_name={model_name}"
+    model_version['dtVersion'] = f"dt_version={dt_version}"
+    model_version['timestamp'] = int(time.time())
+    print(json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8'))
     bucket_name = "art-recommend"
-    object_key = "fengzhoutian/ad_engine_files/pai_model_trained_cids_pointer"
+    object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
 
     oss_config = oss.config.load_default()
     oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
@@ -746,7 +758,7 @@ def update_trained_cids_pointer(model_name=None, dt_version=None):
     ret = client.put_object(oss.PutObjectRequest(
         bucket=bucket_name,
         key=object_key,
-        body=oss_content.encode('utf-8')
+        body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
     ))
     print(f'status code: {ret.status_code},'
           f' request id: {ret.request_id},'
@@ -784,8 +796,11 @@ if __name__ == '__main__':
         print("所有函数都成功执行,可以继续下一步操作。")
         result, msg, level, top10_msg = validate_model_data_accuracy()
         if result:
-            update_online_model()
-            print("success")
+            update_online_res = update_online_model()
+            if update_online_res:
+                online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
+                update_trained_cids_pointer(online_model_config['model_version'], online_model_config['online_date'])
+                print("success")
         step_end_time = int(time.time())
         elapsed = step_end_time - start_time
         print(level, msg, start_time, elapsed, top10_msg)