丁云鹏 vor 6 Monaten
Ursprung
Commit
df896c0bf7

+ 3 - 0
recommend-model-produce/src/main/python/models/dnn/config.yaml

@@ -42,6 +42,9 @@ runner:
   sync_mode: "async"
   split_file_list: False
   thread_num: 1
+  upload_oss: true
+  oss_filename: dnn.tar.gz
+  oss_path: /dyp
 
 
 # hyper parameters of user-defined network

+ 11 - 0
recommend-model-produce/src/main/python/tools/static_trainer.py

@@ -19,6 +19,9 @@ import warnings
 import logging
 import paddle
 import sys
+from utils.oss_client import HangZhouOSSClient
+import utils.compress as compress
+
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 #sys.path.append(__dir__)
 sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
@@ -94,6 +97,9 @@ def main(args):
     paddle.seed(seed)
     use_save_data = config.get("runner.use_save_data", False)
     os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))
+    upload_oss = config.get("runner.upload_oss", True)
+    oss_object_name = config.get("runner.oss_object_name", "")
+
     logger.info("**************common.configs**********")
     logger.info(
         "use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
@@ -207,6 +213,11 @@ def main(args):
                 model_save_path,
                 epoch_id,
                 prefix='rec_static')
+        if(upload_oss):
+            compress.compress_tar(model_save_path, model_save_path + ".tar.gz")
+            client = HangZhouOSSClient("art-recommend")
+            client.put_object_from_file(model_save_path + ".tar.gz", oss_object_name)
+            logger.info("file {} upload success".format(model_save_path + ".tar.gz"))
         if use_save_data:
             save_data(fetch_batch_var, model_save_path)