static_model.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import math
  2. import paddle
  3. from net import DSSMLayer
  4. class StaticModel():
  5. def __init__(self, config):
  6. self.cost = None
  7. self.config = config
  8. self._init_hyper_parameters()
  9. self.is_infer = False
  10. def _init_hyper_parameters(self):
  11. # 修改超参数初始化
  12. self.feature_nums = self.config.get("hyper_parameters.feature_nums", [5,5,5,5,5])
  13. self.embedding_dim = self.config.get("hyper_parameters.embedding_dim", 8)
  14. self.output_dim = self.config.get("hyper_parameters.output_dim", 16)
  15. self.hidden_layers = self.config.get("hyper_parameters.hidden_layers", [40, 32])
  16. self.hidden_acts = self.config.get("hyper_parameters.hidden_acts", ["relu", "relu"])
  17. self.learning_rate = self.config.get("hyper_parameters.optimizer.learning_rate", 0.001)
  18. self.margin = self.config.get("hyper_parameters.margin", 0.0) # 用于损失函数的margin参数
  19. self.feature_num = len(self.feature_nums)
  20. self.is_infer = self.config.get("hyper_parameters.is_infer", False)
  21. def create_feeds(self, is_infer=False):
  22. # 定义输入数据占位符
  23. feeds_list = []
  24. if not is_infer:
  25. label = paddle.static.data(
  26. name="label", shape=[-1, 1], dtype='float32')
  27. feeds_list.append(label)
  28. left_features = paddle.static.data(
  29. name="left_features", shape=[-1, self.feature_num], dtype='float32')
  30. feeds_list.append(left_features)
  31. right_features = paddle.static.data(
  32. name="right_features", shape=[-1, self.feature_num], dtype='float32')
  33. feeds_list.append(right_features)
  34. else:
  35. #sample_id = paddle.static.data(
  36. # name="sample_id", shape=[-1, 1], dtype='int64')
  37. #feeds_list.append(sample_id)
  38. left_features = paddle.static.data(
  39. name="left_features", shape=[-1, self.feature_num], dtype='float32')
  40. feeds_list.append(left_features)
  41. return feeds_list
  42. def net(self, input, is_infer=False):
  43. # 创建模型实例
  44. dssm_model = DSSMLayer(
  45. feature_nums=self.feature_nums,
  46. embedding_dim=self.embedding_dim,
  47. output_dim=self.output_dim,
  48. hidden_layers=self.hidden_layers,
  49. hidden_acts=self.hidden_acts
  50. )
  51. if is_infer:
  52. left_features = input[0]
  53. left_vec = dssm_model(left_features,None,is_infer=True)
  54. self.inference_target_var = left_vec
  55. fetch_dict = {
  56. 'left_vector': left_vec
  57. }
  58. return fetch_dict
  59. else:
  60. label,left_features, right_features = input
  61. print(f"Label shape: {label.shape}")
  62. # 获取相似度和特征向量
  63. sim_score, left_vec, right_vec = dssm_model(left_features, right_features)
  64. self.inference_target_var = left_vec
  65. # self.left_vector = left_vec
  66. # self.right_vector = right_vec
  67. # 计算损失
  68. # 使用带margin的二元交叉熵损失
  69. pos_mask = paddle.cast(label > 0.5, 'float32')
  70. neg_mask = 1.0 - pos_mask
  71. positive_loss = -pos_mask * paddle.log(paddle.clip(sim_score, 1e-8, 1.0))
  72. negative_loss = -neg_mask * paddle.log(paddle.clip(1 - sim_score + self.margin, 1e-8, 1.0))
  73. loss = positive_loss + negative_loss
  74. avg_cost = paddle.mean(loss)
  75. self._cost = avg_cost
  76. # 计算accuracy
  77. predictions = paddle.cast(sim_score > 0.5, 'float32')
  78. accuracy = paddle.mean(paddle.cast(paddle.equal(predictions, label), 'float32'))
  79. fetch_dict = {
  80. 'loss': avg_cost,
  81. 'accuracy': accuracy,
  82. #'similarity': sim_score,
  83. #'left_vector': left_vec,
  84. #'right_vector': right_vec
  85. }
  86. return fetch_dict
  87. def create_optimizer(self, strategy=None,is_infer=False):
  88. optimizer = paddle.optimizer.Adam(
  89. learning_rate=self.learning_rate)
  90. if strategy is not None:
  91. import paddle.distributed.fleet as fleet
  92. optimizer = fleet.distributed_optimizer(optimizer, strategy)
  93. if is_infer:
  94. zero_var = paddle.zeros(shape=[1], dtype='float32')
  95. optimizer.minimize(paddle.mean(zero_var))
  96. else:
  97. optimizer.minimize(self._cost)
  98. def infer_net(self, input):
  99. return self.net(input, is_infer=True)