|
@@ -0,0 +1,88 @@
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+import math
|
|
|
+import paddle
|
|
|
+from net import DSSMLayer
|
|
|
+
|
|
|
+
|
|
|
class StaticModel:
    """Static-graph wrapper that wires a ``DSSMLayer`` into feeds, loss and optimizer.

    The instance reads its hyper-parameters from ``config`` (anything exposing a
    dict-style ``get(key)``), declares the ``paddle.static.data`` placeholders,
    builds the network through ``DSSMLayer`` and attaches an Adam optimizer to
    the resulting loss.
    """

    def __init__(self, config):
        """Store ``config`` and load all hyper-parameters from it.

        Args:
            config: mapping-like object queried with flattened keys such as
                ``"hyper_parameters.trigram_d"``.
        """
        # Legacy attribute; nothing in this class reads it. Kept so external
        # code that inspects ``model.cost`` keeps working.
        self.cost = None
        # The loss actually consumed by ``create_optimizer``; set by ``net``.
        self._cost = None
        self.config = config
        self._init_hyper_parameters()

    def _init_hyper_parameters(self):
        """Copy the hyper-parameters used by this model from ``self.config``."""
        get = self.config.get
        self.trigram_d = get("hyper_parameters.trigram_d")
        self.neg_num = get("hyper_parameters.neg_num")
        self.hidden_layers = get("hyper_parameters.fc_sizes")
        self.hidden_acts = get("hyper_parameters.fc_acts")
        self.slice_end = get("hyper_parameters.slice_end")

        # The optimizer-scoped key is authoritative. The flat key used to be
        # read and then unconditionally overwritten; it now serves as a
        # fallback when the optimizer-scoped key is absent.
        learning_rate = get("hyper_parameters.optimizer.learning_rate")
        if learning_rate is None:
            learning_rate = get("hyper_parameters.learning_rate")
        self.learning_rate = learning_rate

    def create_feeds(self, is_infer=False):
        """Declare the input placeholders of the static program.

        Every placeholder is ``float32`` with shape ``[-1, self.trigram_d]``.

        Args:
            is_infer: when true, the negative-document placeholders are not
                created.

        Returns:
            ``[query, doc_pos]`` for inference, otherwise
            ``[query, doc_pos, doc_neg_0, ..., doc_neg_{neg_num - 1}]``.
        """
        feed_shape = [-1, self.trigram_d]

        query = paddle.static.data(
            name="query", shape=feed_shape, dtype='float32')
        self.prune_feed_vars = [query]

        doc_pos = paddle.static.data(
            name="doc_pos", shape=feed_shape, dtype='float32')

        if is_infer:
            return [query, doc_pos]

        doc_negs = [
            paddle.static.data(
                name="doc_neg_" + str(i),
                shape=[-1, self.trigram_d],
                dtype="float32") for i in range(self.neg_num)
        ]
        return [query, doc_pos] + doc_negs

    def net(self, input, is_infer=False):
        """Build the DSSM network and, for training, its loss.

        Besides returning the fetch dict, this records several variables on the
        instance (``inference_target_var``, ``prune_target_var``,
        ``train_dump_fields``, ``train_dump_params``, ``infer_dump_fields``).

        Args:
            input: the feed variables, forwarded unchanged to
                ``DSSMLayer.forward`` (presumably the list produced by
                ``create_feeds`` -- confirm against the caller).
            is_infer: when true, no loss is built.

        Returns:
            ``{'query_doc_sim': R_Q_D_p}`` for inference, otherwise
            ``{'Loss': avg_cost}``.
        """
        dssm_model = DSSMLayer(self.trigram_d, self.neg_num, self.slice_end,
                               self.hidden_layers, self.hidden_acts)
        R_Q_D_p, hit_prob = dssm_model.forward(input, is_infer)

        self.inference_target_var = R_Q_D_p
        self.prune_target_var = dssm_model.query_fc
        self.train_dump_fields = [dssm_model.query_fc, R_Q_D_p]
        self.train_dump_params = dssm_model.params
        self.infer_dump_fields = [dssm_model.doc_pos_fc]

        if is_infer:
            return {'query_doc_sim': R_Q_D_p}

        # Negative log of ``hit_prob`` summed over the last axis, then averaged.
        loss = -paddle.sum(paddle.log(hit_prob), axis=-1)
        avg_cost = paddle.mean(x=loss)

        self._cost = avg_cost
        return {'Loss': avg_cost}

    def create_optimizer(self, strategy=None):
        """Create an Adam optimizer and minimize the loss built by ``net``.

        Args:
            strategy: optional distributed strategy; when given, the optimizer
                is wrapped with ``fleet.distributed_optimizer``.

        Raises:
            RuntimeError: if no training loss has been built yet.
        """
        if self._cost is None:
            raise RuntimeError(
                "create_optimizer() requires a training loss; call "
                "net(input, is_infer=False) first.")

        optimizer = paddle.optimizer.Adam(
            learning_rate=self.learning_rate, lazy_mode=True)
        if strategy is not None:
            # Imported lazily so single-process runs do not need fleet.
            import paddle.distributed.fleet as fleet
            optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(self._cost)

    def infer_net(self, input):
        """Build the network in inference mode; see ``net``."""
        return self.net(input, is_infer=True)
|