@@ -7,6 +7,7 @@ class StaticModel():
         self.cost = None
         self.config = config
         self._init_hyper_parameters()
+        self.is_infer = False

     def _init_hyper_parameters(self):
         # modified hyper-parameter initialization
@@ -18,7 +19,7 @@ class StaticModel():
         self.learning_rate = self.config.get("hyper_parameters.optimizer.learning_rate", 0.001)
         self.margin = self.config.get("hyper_parameters.margin", 0.3)  # margin parameter used by the loss function
         self.feature_num = len(self.feature_nums)
-
+        self.is_infer = self.config.get("hyper_parameters.is_infer", False)

     def create_feeds(self, is_infer=False):
         # define the input data placeholders
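
Note: the dotted key "hyper_parameters.is_infer" is looked up through the same config.get(key, default) interface as the other hyper-parameters above. Below is a minimal sketch of a stand-in config object, assuming get() simply resolves a dotted key against a flat mapping; DotConfig and its contents are illustrative, not the real configuration utility.

# Illustrative stand-in: `get(key, default)` resolves a dotted key,
# as the self.config.get(...) calls above suggest.
class DotConfig:
    def __init__(self, values):
        self._values = values  # flat dict keyed by dotted paths

    def get(self, key, default=None):
        return self._values.get(key, default)

config = DotConfig({
    "hyper_parameters.optimizer.learning_rate": 0.001,
    "hyper_parameters.margin": 0.3,
    "hyper_parameters.is_infer": True,  # switch the model into inference mode
})
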
@@ -53,7 +54,7 @@ class StaticModel():
             hidden_layers=self.hidden_layers,
             hidden_acts=self.hidden_acts
         )
-
+
         if is_infer:
             sample_id, left_features = input
             left_vec = dssm_model(left_features, None, is_infer=True)
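
Note: in the inference branch the feed list carries (sample_id, left_features) and only the left tower of the DSSM is evaluated (the right input is passed as None). A minimal sketch of how this path might be driven under Paddle's static graph; the driver lines are an assumption, only StaticModel, create_feeds, and infer_net come from this file.

import paddle

paddle.enable_static()                            # the model builds a static graph
model = StaticModel(config)
infer_feeds = model.create_feeds(is_infer=True)   # e.g. sample_id and left_features
fetch_dict = model.infer_net(infer_feeds)         # builds the left-tower graph, no loss
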
@@ -98,13 +99,15 @@ class StaticModel():
         }
         return fetch_dict

-    def create_optimizer(self, strategy=None):
+    def create_optimizer(self, strategy=None, is_infer=False):
         optimizer = paddle.optimizer.Adam(
             learning_rate=self.learning_rate)
         if strategy is not None:
             import paddle.distributed.fleet as fleet
             optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        optimizer.minimize(self._cost)
+        # The inference graph has no loss to optimize, so only call minimize() when training.
+        if not is_infer:
+            optimizer.minimize(self._cost)

     def infer_net(self, input):
         return self.net(input, is_infer=True)
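
Note: minimize() appends the backward pass and optimizer ops for a training loss, so it only makes sense when the training graph has assigned that loss to self._cost (presumably inside net()); the inference graph has no loss and needs no optimizer ops, which is why minimize() is skipped when is_infer is set. Below is a sketch of a driver that branches on the flag; the driver itself is an assumption, only the StaticModel methods come from this file.

import paddle

paddle.enable_static()
model = StaticModel(config)

if model.is_infer:
    feeds = model.create_feeds(is_infer=True)
    fetch_dict = model.infer_net(feeds)   # inference graph, no optimizer needed
else:
    feeds = model.create_feeds()
    fetch_dict = model.net(feeds)         # assumed to set self._cost (training loss)
    model.create_optimizer()              # appends Adam + backward ops for self._cost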