| 
														
															@@ -74,7 +74,7 @@ class InferenceFetchHandler(FetchHandler): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     def handler(self, fetch_vars): 
														 | 
														
														 | 
														
															     def handler(self, fetch_vars): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         """处理每批次的推理结果""" 
														 | 
														
														 | 
														
															         """处理每批次的推理结果""" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         result_dict = {} 
														 | 
														
														 | 
														
															         result_dict = {} 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        print("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         for var_name, var_value in fetch_vars.items(): 
														 | 
														
														 | 
														
															         for var_name, var_value in fetch_vars.items(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             # 转换数据类型 
														 | 
														
														 | 
														
															             # 转换数据类型 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             if isinstance(var_value, np.ndarray): 
														 | 
														
														 | 
														
															             if isinstance(var_value, np.ndarray): 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -99,6 +99,7 @@ class InferenceFetchHandler(FetchHandler): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         self.current_batch = [] 
														 | 
														
														 | 
														
															         self.current_batch = [] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															      
														 | 
														
														 | 
														
															      
														 | 
													
												
											
												
													
														| 
														 | 
														
															     def finish(self): 
														 | 
														
														 | 
														
															     def finish(self): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        logger.info("InferenceFetchHandler finish") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         """确保所有剩余结果都被保存""" 
														 | 
														
														 | 
														
															         """确保所有剩余结果都被保存""" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         if self.current_batch: 
														 | 
														
														 | 
														
															         if self.current_batch: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             self._write_batch() 
														 | 
														
														 | 
														
															             self._write_batch() 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -233,17 +234,15 @@ class Main(object): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         init_model_path = config.get("runner.infer_load_path") 
														 | 
														
														 | 
														
															         init_model_path = config.get("runner.infer_load_path") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         model_mode = config.get("runner.model_mode", 0) 
														 | 
														
														 | 
														
															         model_mode = config.get("runner.model_mode", 0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         client = HangZhouOSSClient("art-recommend") 
														 | 
														
														 | 
														
															         client = HangZhouOSSClient("art-recommend") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        client.get_object_to_file("lqc/64.tar.gz", "64.tar.gz") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        compress.uncompress_tar("64.tar.gz", init_model_path) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        client.get_object_to_file("oss_object_name", "model.tar.gz") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        compress.uncompress_tar("model.tar.gz", init_model_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         assert os.path.exists(init_model_path) 
														 | 
														
														 | 
														
															         assert os.path.exists(init_model_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															         #if fleet.is_first_worker(): 
														 | 
														
														 | 
														
															         #if fleet.is_first_worker(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         #fleet.load_inference_model(init_model_path, mode=int(model_mode)) 
														 | 
														
														 | 
														
															         #fleet.load_inference_model(init_model_path, mode=int(model_mode)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         #fleet.barrier_worker() 
														 | 
														
														 | 
														
															         #fleet.barrier_worker() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        save_model_path = self.config.get("runner.model_save_path") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        if save_model_path and (not os.path.exists(save_model_path)): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            os.makedirs(save_model_path) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															         reader_type = self.config.get("runner.reader_type", "QueueDataset") 
														 | 
														
														 | 
														
															         reader_type = self.config.get("runner.reader_type", "QueueDataset") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         epochs = int(self.config.get("runner.epochs")) 
														 | 
														
														 | 
														
															         epochs = int(self.config.get("runner.epochs")) 
														 |