|
@@ -19,6 +19,9 @@ import warnings
|
|
|
import logging
|
|
|
import paddle
|
|
|
import sys
|
|
|
+from utils.oss_client import HangZhouOSSClient
|
|
|
+import utils.compress as compress
|
|
|
+
|
|
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
|
#sys.path.append(__dir__)
|
|
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
|
@@ -94,6 +97,9 @@ def main(args):
|
|
|
paddle.seed(seed)
|
|
|
use_save_data = config.get("runner.use_save_data", False)
|
|
|
os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))
|
|
|
+ upload_oss = config.get("runner.upload_oss", True)
|
|
|
+ oss_object_name = config.get("runner.oss_object_name", "")
|
|
|
+
|
|
|
logger.info("**************common.configs**********")
|
|
|
logger.info(
|
|
|
"use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
|
|
@@ -207,6 +213,11 @@ def main(args):
|
|
|
model_save_path,
|
|
|
epoch_id,
|
|
|
prefix='rec_static')
|
|
|
+ if(upload_oss):
|
|
|
+ compress.compress_tar(model_save_path, model_save_path + ".tar.gz")
|
|
|
+ client = HangZhouOSSClient("art-recommend")
|
|
|
+ client.put_object_from_file(model_save_path + ".tar.gz", oss_object_name)
|
|
|
+ logger.info("file {} upload success".format(model_save_path + ".tar.gz"))
|
|
|
if use_save_data:
|
|
|
save_data(fetch_batch_var, model_save_path)
|
|
|
|