| 
														
															@@ -19,6 +19,9 @@ import warnings 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import logging 
														 | 
														
														 | 
														
															 import logging 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import paddle 
														 | 
														
														 | 
														
															 import paddle 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import sys 
														 | 
														
														 | 
														
															 import sys 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+from utils.oss_client import HangZhouOSSClient 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import utils.compress as compress 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 __dir__ = os.path.dirname(os.path.abspath(__file__)) 
														 | 
														
														 | 
														
															 __dir__ = os.path.dirname(os.path.abspath(__file__)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 #sys.path.append(__dir__) 
														 | 
														
														 | 
														
															 #sys.path.append(__dir__) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) 
														 | 
														
														 | 
														
															 sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -94,6 +97,9 @@ def main(args): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     paddle.seed(seed) 
														 | 
														
														 | 
														
															     paddle.seed(seed) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     use_save_data = config.get("runner.use_save_data", False) 
														 | 
														
														 | 
														
															     use_save_data = config.get("runner.use_save_data", False) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1)) 
														 | 
														
														 | 
														
															     os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    upload_oss = config.get("runner.upload_oss", True) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    oss_object_name = config.get("runner.oss_object_name", "") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     logger.info("**************common.configs**********") 
														 | 
														
														 | 
														
															     logger.info("**************common.configs**********") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     logger.info( 
														 | 
														
														 | 
														
															     logger.info( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         "use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}". 
														 | 
														
														 | 
														
															         "use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}". 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -207,6 +213,11 @@ def main(args): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 model_save_path, 
														 | 
														
														 | 
														
															                 model_save_path, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 epoch_id, 
														 | 
														
														 | 
														
															                 epoch_id, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 prefix='rec_static') 
														 | 
														
														 | 
														
															                 prefix='rec_static') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if(upload_oss): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            compress.compress_tar(model_save_path, model_save_path + ".tar.gz") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            client = HangZhouOSSClient("art-recommend") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            client.put_object_from_file(model_save_path + ".tar.gz", oss_object_name) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            logger.info("file {} upload success".format(model_save_path + ".tar.gz")) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         if use_save_data: 
														 | 
														
														 | 
														
															         if use_save_data: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             save_data(fetch_batch_var, model_save_path) 
														 | 
														
														 | 
														
															             save_data(fetch_batch_var, model_save_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 |