12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import paddle
- from net import DSSMLayer
class StaticModel:
    """Static-graph model definition built around ``DSSMLayer``.

    The class reads its hyper-parameters from ``config`` (any object exposing
    ``get(key)``), declares the input placeholders, builds the network and
    loss, and attaches an Adam optimizer to that loss.
    """

    def __init__(self, config):
        """Store ``config`` and load hyper-parameters from it.

        Args:
            config: object with a ``get(key)`` method, queried with dotted
                keys such as ``"hyper_parameters.trigram_d"``.
        """
        # ``cost`` is kept for backward compatibility; the rest of the class
        # reads and writes ``_cost``, so it must exist from construction on.
        self.cost = None
        self._cost = None
        self.config = config
        self._init_hyper_parameters()

    def _init_hyper_parameters(self):
        """Copy the hyper-parameters used by this model out of ``self.config``."""
        get = self.config.get
        self.trigram_d = get("hyper_parameters.trigram_d")
        self.neg_num = get("hyper_parameters.neg_num")
        self.hidden_layers = get("hyper_parameters.fc_sizes")
        self.hidden_acts = get("hyper_parameters.fc_acts")
        self.slice_end = get("hyper_parameters.slice_end")
        # The optimizer-level key is the effective learning rate; an earlier
        # read of "hyper_parameters.learning_rate" was always overwritten by
        # this one and has been dropped.
        self.learning_rate = get("hyper_parameters.optimizer.learning_rate")

    def create_feeds(self, is_infer=False):
        """Declare the static-graph input placeholders.

        Args:
            is_infer: when true, only the query and positive-document
                placeholders are created and returned.

        Returns:
            ``[query, doc_pos]`` when ``is_infer`` is true, otherwise
            ``[query, doc_pos, doc_neg_0, ..., doc_neg_{neg_num - 1}]``.
            Every placeholder is float32 with shape ``[-1, trigram_d]``.
        """
        feed_shape = [-1, self.trigram_d]

        query = paddle.static.data(
            name="query", shape=feed_shape, dtype='float32')
        self.prune_feed_vars = [query]

        doc_pos = paddle.static.data(
            name="doc_pos", shape=feed_shape, dtype='float32')

        if is_infer:
            return [query, doc_pos]

        doc_negs = [
            paddle.static.data(
                name="doc_neg_" + str(i), shape=feed_shape, dtype="float32")
            for i in range(self.neg_num)
        ]
        return [query, doc_pos] + doc_negs

    def net(self, input, is_infer=False):
        """Build the DSSM network and, for training, its loss.

        Args:
            input: value forwarded unchanged to ``DSSMLayer.forward``
                (presumably the list returned by ``create_feeds`` - confirm
                against the caller).
            is_infer: when true, no loss is built.

        Returns:
            ``{'query_doc_sim': R_Q_D_p}`` when ``is_infer`` is true,
            otherwise ``{'Loss': avg_cost}``.
        """
        dssm_model = DSSMLayer(self.trigram_d, self.neg_num, self.slice_end,
                               self.hidden_layers, self.hidden_acts)
        R_Q_D_p, hit_prob = dssm_model.forward(input, is_infer)

        # Variables exposed as attributes for pruning / dumping / inference.
        self.inference_target_var = R_Q_D_p
        self.prune_target_var = dssm_model.query_fc
        self.train_dump_fields = [dssm_model.query_fc, R_Q_D_p]
        self.train_dump_params = dssm_model.params
        self.infer_dump_fields = [dssm_model.doc_pos_fc]

        if is_infer:
            return {'query_doc_sim': R_Q_D_p}

        # Negative log of hit_prob summed over the last axis, then averaged.
        loss = -paddle.sum(paddle.log(hit_prob), axis=-1)
        avg_cost = paddle.mean(x=loss)
        self._cost = avg_cost
        return {'Loss': avg_cost}

    def create_optimizer(self, strategy=None):
        """Create an Adam optimizer and minimize the training loss.

        Args:
            strategy: optional strategy object; when given, the optimizer is
                wrapped with ``fleet.distributed_optimizer``.

        Raises:
            ValueError: if no loss has been built yet, i.e. ``net`` has not
                been called in training mode.
        """
        if self._cost is None:
            raise ValueError(
                "No loss to minimize: call net(input, is_infer=False) "
                "before create_optimizer().")

        optimizer = paddle.optimizer.Adam(
            learning_rate=self.learning_rate, lazy_mode=True)
        if strategy is not None:
            import paddle.distributed.fleet as fleet
            optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(self._cost)

    def infer_net(self, input):
        """Build the network in inference mode; see ``net``."""
        return self.net(input, is_infer=True)
|