| 
					
				 | 
			
			
				@@ -74,7 +74,7 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def handler(self, fetch_vars): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """处理每批次的推理结果""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         result_dict = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for var_name, var_value in fetch_vars.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if isinstance(var_value, np.ndarray): 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -99,6 +99,7 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.current_batch = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def finish(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.info("InferenceFetchHandler finish") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """确保所有剩余结果都被保存""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if self.current_batch: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self._write_batch() 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -233,17 +234,15 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         init_model_path = config.get("runner.infer_load_path") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model_mode = config.get("runner.model_mode", 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         client = HangZhouOSSClient("art-recommend") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        client.get_object_to_file("lqc/64.tar.gz", "64.tar.gz") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        compress.uncompress_tar("64.tar.gz", init_model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        client.get_object_to_file("oss_object_name", "model.tar.gz") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        compress.uncompress_tar("model.tar.gz", init_model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         assert os.path.exists(init_model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         #if fleet.is_first_worker(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         #fleet.load_inference_model(init_model_path, mode=int(model_mode)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         #fleet.barrier_worker() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        save_model_path = self.config.get("runner.model_save_path") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if save_model_path and (not os.path.exists(save_model_path)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            os.makedirs(save_model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         reader_type = self.config.get("runner.reader_type", "QueueDataset") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         epochs = int(self.config.get("runner.epochs")) 
			 |