update data process

often 7 months ago
parent commit 6aff14f2eb
1 changed file with 59 additions and 77 deletions

+ 59 - 77
recommend-model-produce/src/main/python/models/dssm/bq_reader_train_ps.py

@@ -1,88 +1,70 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+from paddle.distributed.fleet.data_generator import MultiSlotDataGenerator
+import sys
 
-import math
-import paddle
-from net import DSSMLayer
+class DSSMReader(MultiSlotDataGenerator):
+    def __init__(self):
+        super(DSSMReader, self).__init__()
+        self.feature_dim = 3  # feature dimension expected for every slot
 
+    def init(self, config=None):
+        pass
 
-class StaticModel():
-    def __init__(self, config):
-        self.cost = None
-        self.config = config
-        self._init_hyper_parameters()
+    def line_process(self, line):
+        try:
+            features = line.rstrip('\n').split('\t')
+            if len(features) < 3:  # require at least query, pos_doc and one neg_doc
+                return None
+            
+            # validate and parse the query features
+            query = features[0].split(',')
+            if len(query) != self.feature_dim:
+                return None
+            query = [float(x) for x in query]
 
-    def _init_hyper_parameters(self):
-        self.trigram_d = self.config.get("hyper_parameters.trigram_d")
-        self.neg_num = self.config.get("hyper_parameters.neg_num")
-        self.hidden_layers = self.config.get("hyper_parameters.fc_sizes")
-        self.hidden_acts = self.config.get("hyper_parameters.fc_acts")
-        self.learning_rate = self.config.get("hyper_parameters.learning_rate")
-        self.slice_end = self.config.get("hyper_parameters.slice_end")
-        self.learning_rate = self.config.get(
-            "hyper_parameters.optimizer.learning_rate")
+            # validate and parse the pos_doc features
+            pos_doc = features[1].split(',')
+            if len(pos_doc) != self.feature_dim:
+                return None
+            pos_doc = [float(x) for x in pos_doc]
 
-    def create_feeds(self, is_infer=False):
-        query = paddle.static.data(
-            name="query", shape=[-1, self.trigram_d], dtype='float32')
-        self.prune_feed_vars = [query]
+            # validate and parse the neg_doc features
+            neg_docs = []
+            for i in range(2, len(features)):
+                neg_doc = features[i].split(',')
+                if len(neg_doc) != self.feature_dim:
+                    continue
+                neg_docs.append([float(x) for x in neg_doc])
 
-        doc_pos = paddle.static.data(
-            name="doc_pos", shape=[-1, self.trigram_d], dtype='float32')
+            if not neg_docs:  # no valid neg_doc found
+                return None
 
-        if is_infer:
-            return [query, doc_pos]
+            # build the output slot list
+            output = []
+            output.append(("query", query))
+            output.append(("pos_doc", pos_doc))
+            for i, neg_doc in enumerate(neg_docs):
+                output.append((f"neg_doc_{i}", neg_doc))
 
-        doc_negs = [
-            paddle.static.data(
-                name="doc_neg_" + str(i),
-                shape=[-1, self.trigram_d],
-                dtype="float32") for i in range(self.neg_num)
-        ]
-        feeds_list = [query, doc_pos] + doc_negs
-        return feeds_list
+            return output
 
-    def net(self, input, is_infer=False):
-        dssm_model = DSSMLayer(self.trigram_d, self.neg_num, self.slice_end,
-                               self.hidden_layers, self.hidden_acts)
-        R_Q_D_p, hit_prob = dssm_model.forward(input, is_infer)
+        except Exception as e:
+            sys.stderr.write(f"Error processing line: {str(e)}\n")
+            return None
 
-        self.inference_target_var = R_Q_D_p
-        self.prune_target_var = dssm_model.query_fc
-        self.train_dump_fields = [dssm_model.query_fc, R_Q_D_p]
-        self.train_dump_params = dssm_model.params
-        self.infer_dump_fields = [dssm_model.doc_pos_fc]
-        if is_infer:
-            fetch_dict = {'query_doc_sim': R_Q_D_p}
-            return fetch_dict
-        
+    def generate_sample(self, line):
+        def reader():
+            try:
+                result = self.line_process(line)
+                if result is not None:
+                    yield result
+            except Exception as e:
+                sys.stderr.write(f"Error in generate_sample: {str(e)}\n")
+        return reader
 
-        loss = -paddle.sum(paddle.log(hit_prob), axis=-1)
-        avg_cost = paddle.mean(x=loss)
-        # print(avg_cost)
-        self._cost = avg_cost
-        fetch_dict = {'Loss': avg_cost}
-        return fetch_dict
-
-
-    def create_optimizer(self, strategy=None):
-        optimizer = paddle.optimizer.Adam(
-            learning_rate=self.learning_rate, lazy_mode=True)
-        if strategy != None:
-            import paddle.distributed.fleet as fleet
-            optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        optimizer.minimize(self._cost)
-
-    def infer_net(self, input):
-        return self.net(input, is_infer=True)
+if __name__ == "__main__":
+    reader = DSSMReader()
+    reader.init()
+    try:
+        reader.run_from_stdin()
+    except Exception as e:
+        sys.stderr.write(f"Error in main: {str(e)}\n")