Browse Source

dssm 分布式

often 7 months ago
parent
commit
fa14750f13

+ 88 - 0
recommend-model-produce/src/main/python/models/dssm/bq_reader_train_ps.py

@@ -0,0 +1,88 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+from net import DSSMLayer
+
+
+class StaticModel():
+    def __init__(self, config):
+        self.cost = None
+        self.config = config
+        self._init_hyper_parameters()
+
+    def _init_hyper_parameters(self):
+        self.trigram_d = self.config.get("hyper_parameters.trigram_d")
+        self.neg_num = self.config.get("hyper_parameters.neg_num")
+        self.hidden_layers = self.config.get("hyper_parameters.fc_sizes")
+        self.hidden_acts = self.config.get("hyper_parameters.fc_acts")
+        self.learning_rate = self.config.get("hyper_parameters.learning_rate")
+        self.slice_end = self.config.get("hyper_parameters.slice_end")
+        self.learning_rate = self.config.get(
+            "hyper_parameters.optimizer.learning_rate")
+
+    def create_feeds(self, is_infer=False):
+        query = paddle.static.data(
+            name="query", shape=[-1, self.trigram_d], dtype='float32')
+        self.prune_feed_vars = [query]
+
+        doc_pos = paddle.static.data(
+            name="doc_pos", shape=[-1, self.trigram_d], dtype='float32')
+
+        if is_infer:
+            return [query, doc_pos]
+
+        doc_negs = [
+            paddle.static.data(
+                name="doc_neg_" + str(i),
+                shape=[-1, self.trigram_d],
+                dtype="float32") for i in range(self.neg_num)
+        ]
+        feeds_list = [query, doc_pos] + doc_negs
+        return feeds_list
+
+    def net(self, input, is_infer=False):
+        dssm_model = DSSMLayer(self.trigram_d, self.neg_num, self.slice_end,
+                               self.hidden_layers, self.hidden_acts)
+        R_Q_D_p, hit_prob = dssm_model.forward(input, is_infer)
+
+        self.inference_target_var = R_Q_D_p
+        self.prune_target_var = dssm_model.query_fc
+        self.train_dump_fields = [dssm_model.query_fc, R_Q_D_p]
+        self.train_dump_params = dssm_model.params
+        self.infer_dump_fields = [dssm_model.doc_pos_fc]
+        if is_infer:
+            fetch_dict = {'query_doc_sim': R_Q_D_p}
+            return fetch_dict
+        
+
+        loss = -paddle.sum(paddle.log(hit_prob), axis=-1)
+        avg_cost = paddle.mean(x=loss)
+        # print(avg_cost)
+        self._cost = avg_cost
+        fetch_dict = {'Loss': avg_cost}
+        return fetch_dict
+
+
+    def create_optimizer(self, strategy=None):
+        optimizer = paddle.optimizer.Adam(
+            learning_rate=self.learning_rate, lazy_mode=True)
+        if strategy != None:
+            import paddle.distributed.fleet as fleet
+            optimizer = fleet.distributed_optimizer(optimizer, strategy)
+        optimizer.minimize(self._cost)
+
+    def infer_net(self, input):
+        return self.net(input, is_infer=True)

+ 46 - 0
recommend-model-produce/src/main/python/models/dssm/config_ps.yaml

@@ -0,0 +1,46 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+runner:
+  train_data_dir: "data/train"
+  train_reader_path: "bq_reader_train"  # importlib format
+  train_batch_size: 8
+  model_save_path: "output_model_dssm"
+
+  reader_type: "QueueDataset"  # DataLoader / QueueDataset / RecDataset
+  pipe_command: "python bq_reader_train_ps.py"
+  thread_num: 1
+  sync_mode: "sync"
+
+  use_gpu: False
+  epochs: 10
+  print_interval: 10
+  
+  test_data_dir: "data/test"
+  infer_reader_path: "bq_reader_infer"  # importlib format
+  infer_batch_size: 1
+  infer_load_path: "output_model_dssm"
+  infer_start_epoch: 0
+  infer_end_epoch: 1
+
+hyper_parameters:
+  optimizer:
+    class: adam
+    learning_rate: 0.001
+    strategy: sync
+  trigram_d: 2900
+  neg_num: 1
+  slice_end: 8
+  fc_sizes: [300, 300, 128]
+  fc_acts: ['relu', 'relu', 'relu']

+ 2 - 2
recommend-model-produce/src/main/python/tools/static_ps_trainer.py

@@ -121,7 +121,7 @@ class Main(object):
     def network(self):
         self.model = get_model(self.config)
         self.input_data = self.model.create_feeds()
-        self.inference_feed_var = self.model.create_feeds(is_infer=False)
+        self.inference_feed_var = self.model.create_feeds(is_infer=True)
         self.init_reader()
         self.metrics = self.model.net(self.input_data)
         self.inference_target_var = self.model.inference_target_var
@@ -320,4 +320,4 @@ if __name__ == "__main__":
     config = parse_args()
     os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
     benchmark_main = Main(config)
-    benchmark_main.run()
+    benchmark_main.run()