| 
					
				 | 
			
			
				@@ -1,22 +1,7 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# Licensed under the Apache License, Version 2.0 (the "License"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# you may not use this file except in compliance with the License. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# You may obtain a copy of the License at 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     http://www.apache.org/licenses/LICENSE-2.0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# Unless required by applicable law or agreed to in writing, software 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# distributed under the License is distributed on an "AS IS" BASIS, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# See the License for the specific language governing permissions and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# limitations under the License. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import math 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import paddle 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from net import DSSMLayer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class StaticModel(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __init__(self, config): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.cost = None 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -24,59 +9,96 @@ class StaticModel(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self._init_hyper_parameters() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _init_hyper_parameters(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.trigram_d = self.config.get("hyper_parameters.trigram_d") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.neg_num = self.config.get("hyper_parameters.neg_num") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.hidden_layers = self.config.get("hyper_parameters.fc_sizes") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.hidden_acts = self.config.get("hyper_parameters.fc_acts") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.learning_rate = self.config.get("hyper_parameters.learning_rate") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.slice_end = self.config.get("hyper_parameters.slice_end") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.learning_rate = self.config.get( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "hyper_parameters.optimizer.learning_rate") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 修改超参数初始化 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.feature_num = self.config.get("hyper_parameters.feature_num", 5) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.embedding_dim = self.config.get("hyper_parameters.embedding_dim", 8) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.output_dim = self.config.get("hyper_parameters.output_dim", 16) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.hidden_layers = self.config.get("hyper_parameters.hidden_layers", [64, 32]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.hidden_acts = self.config.get("hyper_parameters.hidden_acts", ["relu", "relu"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.learning_rate = self.config.get("hyper_parameters.optimizer.learning_rate", 0.001) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.margin = self.config.get("hyper_parameters.margin", 0.3)  # 用于损失函数的margin参数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def create_feeds(self, is_infer=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        query = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            name="query", shape=[-1, self.trigram_d], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.prune_feed_vars = [query] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        doc_pos = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            name="doc_pos", shape=[-1, self.trigram_d], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 定义输入数据占位符 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sample_id = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            name="sample_id", shape=[-1, 1], dtype='int64') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        left_features = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            name="left_features", shape=[-1, self.feature_num], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        right_features = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            name="right_features", shape=[-1, self.feature_num], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return [query, doc_pos] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        feeds_list = [sample_id, left_features, right_features] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if not is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            label = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                name="label", shape=[-1, 1], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            feeds_list.append(label) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        doc_negs = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                name="doc_neg_" + str(i), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                shape=[-1, self.trigram_d], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                dtype="float32") for i in range(self.neg_num) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        feeds_list = [query, doc_pos] + doc_negs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return feeds_list 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def net(self, input, is_infer=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        dssm_model = DSSMLayer(self.trigram_d, self.neg_num, self.slice_end, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                               self.hidden_layers, self.hidden_acts) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        R_Q_D_p, hit_prob = dssm_model.forward(input, is_infer) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 创建模型实例 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        dssm_model = DSSMLayer( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            feature_num=self.feature_num, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            embedding_dim=self.embedding_dim, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            output_dim=self.output_dim, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            hidden_layers=self.hidden_layers, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            hidden_acts=self.hidden_acts 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.inference_target_var = R_Q_D_p 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.prune_target_var = dssm_model.query_fc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.train_dump_fields = [dssm_model.query_fc, R_Q_D_p] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.train_dump_params = dssm_model.params 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.infer_dump_fields = [dssm_model.doc_pos_fc] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            fetch_dict = {'query_doc_sim': R_Q_D_p} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sample_id, left_features, right_features = input 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sample_id, left_features, right_features, label = input 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 获取相似度和特征向量 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sim_score, left_vec, right_vec = dssm_model(left_features, right_features) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.inference_target_var = sim_score 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.left_vector = left_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.right_vector = right_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            fetch_dict = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'sample_id': sample_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'similarity': sim_score, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'left_vector': left_vec, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'right_vector': right_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             return fetch_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        loss = -paddle.sum(paddle.log(hit_prob), axis=-1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        avg_cost = paddle.mean(x=loss) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # print(avg_cost) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 计算损失 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 使用带margin的二元交叉熵损失 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        pos_mask = paddle.cast(label > 0.5, 'float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        neg_mask = 1.0 - pos_mask 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        positive_loss = -pos_mask * paddle.log(paddle.clip(sim_score, 1e-8, 1.0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        negative_loss = -neg_mask * paddle.log(paddle.clip(1 - sim_score + self.margin, 1e-8, 1.0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        loss = positive_loss + negative_loss 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        avg_cost = paddle.mean(loss) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self._cost = avg_cost 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        fetch_dict = {'Loss': avg_cost} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 计算accuracy 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        predictions = paddle.cast(sim_score > 0.5, 'float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        accuracy = paddle.mean(paddle.cast(paddle.equal(predictions, label), 'float32')) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_dict = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'loss': avg_cost, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'accuracy': accuracy, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            #'similarity': sim_score, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            #'left_vector': left_vec, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            #'right_vector': right_vec 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return fetch_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def create_optimizer(self, strategy=None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         optimizer = paddle.optimizer.Adam( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            learning_rate=self.learning_rate, lazy_mode=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if strategy != None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            learning_rate=self.learning_rate) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if strategy is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             import paddle.distributed.fleet as fleet 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             optimizer = fleet.distributed_optimizer(optimizer, strategy) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         optimizer.minimize(self._cost) 
			 |