@@ -1,88 +1,70 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+from paddle.distributed.fleet.data_generator import MultiSlotDataGenerator
+import sys
-import math
-import paddle
-from net import DSSMLayer
+class DSSMReader(MultiSlotDataGenerator):
+    def __init__(self):
+        super(DSSMReader, self).__init__()
+        self.feature_dim = 3  # set the feature dimension
-class StaticModel():
-    def __init__(self, config):
-        self.cost = None
-        self.config = config
-        self._init_hyper_parameters()
+    def init(self, config=None):
+        pass
+    def line_process(self, line):
+        try:
+            features = line.rstrip('\n').split('\t')
+            if len(features) < 3:  # need at least query, pos_doc and one neg_doc
+                return None
+
+            # validate and process the query features
+            query = features[0].split(',')
+            if len(query) != self.feature_dim:
+                return None
+            query = [float(x) for x in query]
-    def _init_hyper_parameters(self):
-        self.trigram_d = self.config.get("hyper_parameters.trigram_d")
-        self.neg_num = self.config.get("hyper_parameters.neg_num")
-        self.hidden_layers = self.config.get("hyper_parameters.fc_sizes")
-        self.hidden_acts = self.config.get("hyper_parameters.fc_acts")
-        self.learning_rate = self.config.get("hyper_parameters.learning_rate")
-        self.slice_end = self.config.get("hyper_parameters.slice_end")
-        self.learning_rate = self.config.get(
-            "hyper_parameters.optimizer.learning_rate")
+            # validate and process the pos_doc features
+            pos_doc = features[1].split(',')
+            if len(pos_doc) != self.feature_dim:
+                return None
+            pos_doc = [float(x) for x in pos_doc]
-    def create_feeds(self, is_infer=False):
-        query = paddle.static.data(
-            name="query", shape=[-1, self.trigram_d], dtype='float32')
-        self.prune_feed_vars = [query]
+            # validate and process the neg_doc features
+            neg_docs = []
+            for i in range(2, len(features)):
+                neg_doc = features[i].split(',')
+                if len(neg_doc) != self.feature_dim:
+                    continue
+                neg_docs.append([float(x) for x in neg_doc])
-        doc_pos = paddle.static.data(
-            name="doc_pos", shape=[-1, self.trigram_d], dtype='float32')
+            if not neg_docs:  # no valid neg_doc found
+                return None
-        if is_infer:
-            return [query, doc_pos]
+            # build the output list
+            output = []
+            output.append(("query", query))
+            output.append(("pos_doc", pos_doc))
+            for i, neg_doc in enumerate(neg_docs):
+                output.append((f"neg_doc_{i}", neg_doc))
-        doc_negs = [
-            paddle.static.data(
-                name="doc_neg_" + str(i),
-                shape=[-1, self.trigram_d],
-                dtype="float32") for i in range(self.neg_num)
-        ]
-        feeds_list = [query, doc_pos] + doc_negs
-        return feeds_list
+            return output
-    def net(self, input, is_infer=False):
-        dssm_model = DSSMLayer(self.trigram_d, self.neg_num, self.slice_end,
-                               self.hidden_layers, self.hidden_acts)
-        R_Q_D_p, hit_prob = dssm_model.forward(input, is_infer)
+        except Exception as e:
+            sys.stderr.write(f"Error processing line: {str(e)}\n")
+            return None
-        self.inference_target_var = R_Q_D_p
-        self.prune_target_var = dssm_model.query_fc
-        self.train_dump_fields = [dssm_model.query_fc, R_Q_D_p]
-        self.train_dump_params = dssm_model.params
-        self.infer_dump_fields = [dssm_model.doc_pos_fc]
-        if is_infer:
-            fetch_dict = {'query_doc_sim': R_Q_D_p}
-            return fetch_dict
-
+    def generate_sample(self, line):
+        def reader():
+            try:
+                result = self.line_process(line)
+                if result is not None:
+                    yield result
+            except Exception as e:
+                sys.stderr.write(f"Error in generate_sample: {str(e)}\n")
+        return reader
-        loss = -paddle.sum(paddle.log(hit_prob), axis=-1)
-        avg_cost = paddle.mean(x=loss)
-        # print(avg_cost)
-        self._cost = avg_cost
-        fetch_dict = {'Loss': avg_cost}
-        return fetch_dict
-
-
-    def create_optimizer(self, strategy=None):
-        optimizer = paddle.optimizer.Adam(
-            learning_rate=self.learning_rate, lazy_mode=True)
-        if strategy != None:
-            import paddle.distributed.fleet as fleet
-            optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        optimizer.minimize(self._cost)
-
-    def infer_net(self, input):
-        return self.net(input, is_infer=True)
+if __name__ == "__main__":
+    reader = DSSMReader()
+    reader.init()
+    try:
+        reader.run_from_stdin()
+    except Exception as e:
+        sys.stderr.write(f"Error in main: {str(e)}\n")
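
A minimal sketch of how the new reader could be smoke-tested outside of the fleet dataset pipeline. It assumes the added file is saved as dssm_reader.py (hypothetical filename, not part of this diff) and that each input line carries tab-separated slots of comma-separated floats:

# Hypothetical local check of DSSMReader.line_process; dssm_reader.py is an
# assumed module name for the file added in this diff.
from dssm_reader import DSSMReader

reader = DSSMReader()
reader.init()

# One query, one pos_doc and two neg_doc slots, each with
# feature_dim == 3 comma-separated float features.
sample = "0.1,0.2,0.3\t0.4,0.5,0.6\t0.7,0.8,0.9\t1.0,1.1,1.2"
for name, values in reader.line_process(sample):
    print(name, values)  # e.g. query [0.1, 0.2, 0.3]

The same script can also be driven end to end with something like `cat sample.txt | python dssm_reader.py`, since `run_from_stdin()` consumes raw lines from standard input.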