| 
					
				 | 
			
			
				@@ -19,6 +19,9 @@ import warnings 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import paddle 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import sys 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from utils.oss_client import HangZhouOSSClient 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import utils.compress as compress 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 __dir__ = os.path.dirname(os.path.abspath(__file__)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 #sys.path.append(__dir__) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -94,6 +97,9 @@ def main(args): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     paddle.seed(seed) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     use_save_data = config.get("runner.use_save_data", False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    upload_oss = config.get("runner.upload_oss", True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    oss_object_name = config.get("runner.oss_object_name", "") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     logger.info("**************common.configs**********") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     logger.info( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         "use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}". 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -207,6 +213,11 @@ def main(args): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 model_save_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 epoch_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 prefix='rec_static') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if(upload_oss): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            compress.compress_tar(model_save_path, model_save_path + ".tar.gz") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            client = HangZhouOSSClient("art-recommend") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            client.put_object_from_file(model_save_path + ".tar.gz", oss_object_name) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.info("file {} upload success".format(model_save_path + ".tar.gz")) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if use_save_data: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             save_data(fetch_batch_var, model_save_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |