| 
														
															@@ -21,22 +21,25 @@ class StaticModel(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     def create_feeds(self, is_infer=False): 
														 | 
														
														 | 
														
															     def create_feeds(self, is_infer=False): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         # 定义输入数据占位符 
														 | 
														
														 | 
														
															         # 定义输入数据占位符 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # sample_id = paddle.static.data( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        #    name="sample_id", shape=[-1, 1], dtype='int64') 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         feeds_list = [] 
														 | 
														
														 | 
														
															         feeds_list = [] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         if not is_infer: 
														 | 
														
														 | 
														
															         if not is_infer: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             label = paddle.static.data( 
														 | 
														
														 | 
														
															             label = paddle.static.data( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 name="label", shape=[-1, 1], dtype='float32') 
														 | 
														
														 | 
														
															                 name="label", shape=[-1, 1], dtype='float32') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            feeds_list.append(label) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                     
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        left_features = paddle.static.data( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            name="left_features", shape=[-1, self.feature_num], dtype='float32') 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        feeds_list.append(left_features) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        right_features = paddle.static.data( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            name="right_features", shape=[-1, self.feature_num], dtype='float32') 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        feeds_list.append(right_features) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-         
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-         
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            feeds_list.append(label)      
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            left_features = paddle.static.data( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                name="left_features", shape=[-1, self.feature_num], dtype='float32') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            feeds_list.append(left_features) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            right_features = paddle.static.data( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                name="right_features", shape=[-1, self.feature_num], dtype='float32') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            feeds_list.append(right_features) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        else: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            sample_id = paddle.static.data( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                name="sample_id", shape=[-1, 1], dtype='int64') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            feeds_list.append(label)      
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            left_features = paddle.static.data( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                name="left_features", shape=[-1, self.feature_num], dtype='float32') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            feeds_list.append(left_features) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															         return feeds_list 
														 | 
														
														 | 
														
															         return feeds_list 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -52,10 +55,11 @@ class StaticModel(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ) 
														 | 
														
														 | 
														
															         ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															         if is_infer: 
														 | 
														
														 | 
														
															         if is_infer: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            left_features = input 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            sample_id,left_features = input 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             left_vec = dssm_model(left_features,None,is_infer=True) 
														 | 
														
														 | 
														
															             left_vec = dssm_model(left_features,None,is_infer=True) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             self.inference_target_var = left_vec 
														 | 
														
														 | 
														
															             self.inference_target_var = left_vec 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             fetch_dict = { 
														 | 
														
														 | 
														
															             fetch_dict = { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                'sample_id': sample_id 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 'left_vector': left_vec 
														 | 
														
														 | 
														
															                 'left_vector': left_vec 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             } 
														 | 
														
														 | 
														
															             } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             return fetch_dict 
														 | 
														
														 | 
														
															             return fetch_dict 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -68,8 +72,6 @@ class StaticModel(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             self.left_vector = left_vec 
														 | 
														
														 | 
														
															             self.left_vector = left_vec 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             self.right_vector = right_vec 
														 | 
														
														 | 
														
															             self.right_vector = right_vec 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             # 计算损失 
														 | 
														
														 | 
														
															             # 计算损失 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             # 使用带margin的二元交叉熵损失 
														 | 
														
														 | 
														
															             # 使用带margin的二元交叉熵损失 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             pos_mask = paddle.cast(label > 0.5, 'float32') 
														 | 
														
														 | 
														
															             pos_mask = paddle.cast(label > 0.5, 'float32') 
														 |