often 5 ay önce
ebeveyn
işleme
503a5c46e5

+ 2 - 1
recommend-model-produce/src/main/python/models/dssm/static_model.py

@@ -10,13 +10,14 @@ class StaticModel():
 
     def _init_hyper_parameters(self):
         # 修改超参数初始化
-        self.feature_num = self.config.get("hyper_parameters.feature_nums", [5,5,5,5,5])
+        self.feature_nums = self.config.get("hyper_parameters.feature_nums", [5,5,5,5,5])
         self.embedding_dim = self.config.get("hyper_parameters.embedding_dim", 8)
         self.output_dim = self.config.get("hyper_parameters.output_dim", 16)
         self.hidden_layers = self.config.get("hyper_parameters.hidden_layers", [64, 32])
         self.hidden_acts = self.config.get("hyper_parameters.hidden_acts", ["relu", "relu"])
         self.learning_rate = self.config.get("hyper_parameters.optimizer.learning_rate", 0.001)
         self.margin = self.config.get("hyper_parameters.margin", 0.3)  # 用于损失函数的margin参数
+        self.feature_num = len(self.feature_nums)
 
     def create_feeds(self, is_infer=False):
         # 定义输入数据占位符