| 
					
				 | 
			
			
				@@ -1,88 +1,70 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# Licensed under the Apache License, Version 2.0 (the "License"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# you may not use this file except in compliance with the License. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# You may obtain a copy of the License at 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#     http://www.apache.org/licenses/LICENSE-2.0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# Unless required by applicable law or agreed to in writing, software 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# distributed under the License is distributed on an "AS IS" BASIS, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# See the License for the specific language governing permissions and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-# limitations under the License. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from paddle.distributed.fleet.data_generator import MultiSlotDataGenerator 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import sys 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import math 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import paddle 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from net import DSSMLayer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class DSSMReader(MultiSlotDataGenerator): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        super(DSSMReader, self).__init__() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.feature_dim = 3  # 设置特征维度 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def init(self, config=None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        pass 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-class StaticModel(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __init__(self, config): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.cost = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.config = config 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self._init_hyper_parameters() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def line_process(self, line): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            features = line.rstrip('\n').split('\t') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if len(features) < 3:  # 确保至少有query、pos_doc和一个neg_doc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+             
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 验证并处理query特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            query = features[0].split(',') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if len(query) != self.feature_dim: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            query = [float(x) for x in query] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def _init_hyper_parameters(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.trigram_d = self.config.get("hyper_parameters.trigram_d") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.neg_num = self.config.get("hyper_parameters.neg_num") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.hidden_layers = self.config.get("hyper_parameters.fc_sizes") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.hidden_acts = self.config.get("hyper_parameters.fc_acts") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.learning_rate = self.config.get("hyper_parameters.learning_rate") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.slice_end = self.config.get("hyper_parameters.slice_end") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.learning_rate = self.config.get( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            "hyper_parameters.optimizer.learning_rate") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 验证并处理pos_doc特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            pos_doc = features[1].split(',') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if len(pos_doc) != self.feature_dim: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            pos_doc = [float(x) for x in pos_doc] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def create_feeds(self, is_infer=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        query = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            name="query", shape=[-1, self.trigram_d], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.prune_feed_vars = [query] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 验证并处理neg_doc特征 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            neg_docs = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for i in range(2, len(features)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                neg_doc = features[i].split(',') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if len(neg_doc) != self.feature_dim: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                neg_docs.append([float(x) for x in neg_doc]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        doc_pos = paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            name="doc_pos", shape=[-1, self.trigram_d], dtype='float32') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if not neg_docs:  # 如果没有有效的neg_doc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return [query, doc_pos] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 构建输出列表 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            output = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            output.append(("query", query)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            output.append(("pos_doc", pos_doc)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for i, neg_doc in enumerate(neg_docs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                output.append((f"neg_doc_{i}", neg_doc)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        doc_negs = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            paddle.static.data( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                name="doc_neg_" + str(i), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                shape=[-1, self.trigram_d], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                dtype="float32") for i in range(self.neg_num) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        feeds_list = [query, doc_pos] + doc_negs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return feeds_list 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return output 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def net(self, input, is_infer=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        dssm_model = DSSMLayer(self.trigram_d, self.neg_num, self.slice_end, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                               self.hidden_layers, self.hidden_acts) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        R_Q_D_p, hit_prob = dssm_model.forward(input, is_infer) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        except Exception as e: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            sys.stderr.write(f"Error processing line: {str(e)}\n") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.inference_target_var = R_Q_D_p 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.prune_target_var = dssm_model.query_fc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.train_dump_fields = [dssm_model.query_fc, R_Q_D_p] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.train_dump_params = dssm_model.params 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.infer_dump_fields = [dssm_model.doc_pos_fc] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if is_infer: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            fetch_dict = {'query_doc_sim': R_Q_D_p} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return fetch_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def generate_sample(self, line): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        def reader(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                result = self.line_process(line) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if result is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    yield result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            except Exception as e: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sys.stderr.write(f"Error in generate_sample: {str(e)}\n") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return reader 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        loss = -paddle.sum(paddle.log(hit_prob), axis=-1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        avg_cost = paddle.mean(x=loss) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # print(avg_cost) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self._cost = avg_cost 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        fetch_dict = {'Loss': avg_cost} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return fetch_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def create_optimizer(self, strategy=None): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        optimizer = paddle.optimizer.Adam( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            learning_rate=self.learning_rate, lazy_mode=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if strategy != None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            import paddle.distributed.fleet as fleet 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            optimizer = fleet.distributed_optimizer(optimizer, strategy) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        optimizer.minimize(self._cost) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def infer_net(self, input): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return self.net(input, is_infer=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+if __name__ == "__main__": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    reader = DSSMReader() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    reader.init() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        reader.run_from_stdin() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    except Exception as e: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sys.stderr.write(f"Error in main: {str(e)}\n") 
			 |