|
@@ -10,13 +10,14 @@ class StaticModel():
|
|
|
|
|
|
def _init_hyper_parameters(self):
|
|
|
# 修改超参数初始化
|
|
|
- self.feature_num = self.config.get("hyper_parameters.feature_nums", [5,5,5,5,5])
|
|
|
+ self.feature_nums = self.config.get("hyper_parameters.feature_nums", [5,5,5,5,5])
|
|
|
self.embedding_dim = self.config.get("hyper_parameters.embedding_dim", 8)
|
|
|
self.output_dim = self.config.get("hyper_parameters.output_dim", 16)
|
|
|
self.hidden_layers = self.config.get("hyper_parameters.hidden_layers", [64, 32])
|
|
|
self.hidden_acts = self.config.get("hyper_parameters.hidden_acts", ["relu", "relu"])
|
|
|
self.learning_rate = self.config.get("hyper_parameters.optimizer.learning_rate", 0.001)
|
|
|
self.margin = self.config.get("hyper_parameters.margin", 0.3) # 用于损失函数的margin参数
|
|
|
+ self.feature_num = len(self.feature_nums)
|
|
|
|
|
|
def create_feeds(self, is_infer=False):
|
|
|
# 定义输入数据占位符
|