浏览代码

Merge branch 'dev-xym-update-pai' of algorithm/recommend-emr-dataprocess into feature/20250104-zt-update

fengzhoutian 15 小时之前
父节点
当前提交
a50376b6cd
共有 1 个文件被更改,包括 11 次插入7 次删除
  1. 11 7
      ad/pai_flow_operator.py

+ 11 - 7
ad/pai_flow_operator.py

@@ -388,7 +388,6 @@ def extract_date_yyyymmdd(input_string):
         return matches[0]
     return None
 
-
 def get_online_model_config(service_name: str):
     model_config = {}
     model_detail = PAIClient.get_describe_service(service_name)
@@ -732,10 +731,13 @@ def update_trained_cids_pointer(model_name=None, dt_version=None):
     elif not (model_name and dt_version):
         # 不允许其中一个为空
         raise Exception("model_name 和 dt_version 必须同时提供")
-    oss_content = f"oss://art-recommend/fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt_version}"
-
+    model_version = {}
+    model_version['modelName'] = f"model_name={model_name}"
+    model_version['dtVersion'] = f"dt_version={dt_version}"
+    model_version['timestamp'] = int(time.time())
+    print(json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8'))
     bucket_name = "art-recommend"
-    object_key = "fengzhoutian/ad_engine_files/pai_model_trained_cids_pointer"
+    object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
 
     oss_config = oss.config.load_default()
     oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
@@ -746,7 +748,7 @@ def update_trained_cids_pointer(model_name=None, dt_version=None):
     ret = client.put_object(oss.PutObjectRequest(
         bucket=bucket_name,
         key=object_key,
-        body=oss_content.encode('utf-8')
+        body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
     ))
     print(f'status code: {ret.status_code},'
           f' request id: {ret.request_id},'
@@ -784,8 +786,10 @@ if __name__ == '__main__':
         print("所有函数都成功执行,可以继续下一步操作。")
         result, msg, level, top10_msg = validate_model_data_accuracy()
         if result:
-            update_online_model()
-            print("success")
+            update_online_res = update_online_model()
+            if update_online_res:
+                update_trained_cids_pointer()
+                print("success")
         step_end_time = int(time.time())
         elapsed = step_end_time - start_time
         print(level, msg, start_time, elapsed, top10_msg)