| 
					
				 | 
			
			
				@@ -21,22 +21,25 @@ class StaticModel(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def create_feeds(self, is_infer=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # 定义输入数据占位符 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # sample_id = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #    name="sample_id", shape=[-1, 1], dtype='int64') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         feeds_list = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if not is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             label = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 name="label", shape=[-1, 1], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            feeds_list.append(label) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        left_features = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            name="left_features", shape=[-1, self.feature_num], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        feeds_list.append(left_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        right_features = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            name="right_features", shape=[-1, self.feature_num], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        feeds_list.append(right_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            feeds_list.append(label)      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            left_features = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                name="left_features", shape=[-1, self.feature_num], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            feeds_list.append(left_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            right_features = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                name="right_features", shape=[-1, self.feature_num], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            feeds_list.append(right_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sample_id = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                name="sample_id", shape=[-1, 1], dtype='int64') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            feeds_list.append(label)      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            left_features = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                name="left_features", shape=[-1, self.feature_num], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            feeds_list.append(left_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return feeds_list 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -52,10 +55,11 @@ class StaticModel(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            left_features = input 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sample_id,left_features = input 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             left_vec = dssm_model(left_features,None,is_infer=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.inference_target_var = left_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             fetch_dict = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'sample_id': sample_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 'left_vector': left_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             return fetch_dict 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -68,8 +72,6 @@ class StaticModel(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.left_vector = left_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.right_vector = right_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # 计算损失 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # 使用带margin的二元交叉熵损失 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             pos_mask = paddle.cast(label > 0.5, 'float32') 
			 |