# widedeep_v13_3.py
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. #
  5. # Copyright © 2025 StrayWarrior <i@straywarrior.com>
  6. """
  7. 1.删除容易导致偏差的viewall特征
  8. 2.删除分桶不均匀的cpa特征
  9. 3.减少dense特征
  10. 4.增加U-I交叉统计
  11. 5.增加线性部分dense
  12. 6.减少wide部分embedding
  13. 7.减少部分bucket size
  14. 8.使用protobuf
  15. 9.调整embedding variable (PAI平台目前还不支持)
  16. """
  17. import re
  18. from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
  19. from easy_rec.python.protos.train_pb2 import TrainConfig
  20. from easy_rec.python.protos.eval_pb2 import EvalConfig, AUC, EvalMetrics
  21. from easy_rec.python.protos.dataset_pb2 import DatasetConfig
  22. from easy_rec.python.protos.feature_config_pb2 import FeatureConfig, FeatureGroupConfig, WideOrDeep, EVParams
  23. from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel
  24. from easy_rec.python.protos.deepfm_pb2 import DeepFM
  25. from easy_rec.python.protos.dnn_pb2 import DNN
  26. from easy_rec.python.protos.export_pb2 import ExportConfig
  27. from easy_rec.python.protos.optimizer_pb2 import Optimizer, AdamOptimizer, LearningRate, ConstantLearningRate
  28. from google.protobuf import text_format
  29. raw_input = open("data_fields_v3.config").readlines()
  30. input_fields = dict(
  31. map(lambda x: (x[0], x[1]),
  32. map(lambda x: x.strip().split(' '), raw_input)))
  33. def read_features(filename, excludes=None):
  34. features = open(filename).readlines()
  35. features = [name.strip().lower() for name in features]
  36. if excludes:
  37. for x in excludes:
  38. if x in features:
  39. features.remove(x)
  40. return features
# Features excluded everywhere: 'viewall' introduces bias and 'cpa' buckets
# unevenly (module docstring, items 1-2).
exclude_features = ['viewall', 'cpa']
# Dense statistic features: the top-300 list feeds the wide group, the
# top-50 subset feeds the deep group (see create_config).
dense_features = read_features("features_top300.config", exclude_features)
top_dense_features = read_features('features_top50.config', exclude_features)
# Single-valued categorical (id) features. The user_*_{view,click,conver}_*d
# names appear to be U-I cross statistics (docstring item 4) — TODO confirm
# their upstream encoding.
sparse_features = [
    "cid", "adid", "adverid",
    "region", "city", "brand",
    "vid", "cate1", "cate2",
    "apptype", "hour", "hour_quarter", "root_source_scene", "root_source_channel", "is_first_layer", "title_split", "user_has_conver_1y",
    "user_adverid_view_3d", "user_adverid_view_7d", "user_adverid_view_30d",
    "user_adverid_click_3d", "user_adverid_click_7d", "user_adverid_click_30d",
    "user_adverid_conver_3d", "user_adverid_conver_7d", "user_adverid_conver_30d",
    "user_skuid_view_3d", "user_skuid_view_7d", "user_skuid_view_30d",
    "user_skuid_click_3d", "user_skuid_click_7d", "user_skuid_click_30d",
    "user_skuid_conver_3d", "user_skuid_conver_7d", "user_skuid_conver_30d"
]
# Multi-valued features, configured as comma-separated TagFeature below.
tag_features = [
    "user_vid_return_tags_2h", "user_vid_return_tags_1d", "user_vid_return_tags_3d",
    "user_vid_return_tags_7d", "user_vid_return_tags_14d"
]
# Behavior-sequence features; handled the same way as tag features below.
seq_features = [
    "user_cid_click_list", "user_cid_conver_list"
]
# ODPS column type -> EasyRec dataset field type.
input_type_map = {
    'BIGINT': DatasetConfig.FieldType.INT64,
    'DOUBLE': DatasetConfig.FieldType.DOUBLE,
    'STRING': DatasetConfig.FieldType.STRING
}
# Id features stored as embedding variables; these must be INT64 columns
# (validated in create_config).
use_ev_features = [
    "cid", "adid", "adverid", "vid"
]
# Per-feature hash bucket sizes; features absent here fall back to 1,000,000.
bucket_size_map = {
    'adverid': 100000,
    'region': 1000,
    'city': 10000,
    'brand': 10000,
    'cate1': 10000,
    'cate2': 10000,
    'apptype': 1000,
    'hour': 1000,  # could simply use an explicit vocabulary instead
    'hour_quarter': 4000,
    'root_source_scene': 100,
    'root_source_channel': 1000,
    'is_first_layer': 100,
    'user_has_conver_1y': 100,
}
  86. def create_config():
  87. config = EasyRecConfig()
  88. # 训练配置
  89. train_config = TrainConfig()
  90. # 配置多个优化器
  91. optimizers = [
  92. (0.0010, False), # wide参数
  93. (0.0006, False), # dense参数
  94. (0.002, False) # deep embedding参数
  95. ]
  96. for lr, use_moving_avg in optimizers:
  97. optimizer = Optimizer()
  98. adam_optimizer = AdamOptimizer()
  99. learning_rate = LearningRate()
  100. constant_lr = ConstantLearningRate()
  101. constant_lr.learning_rate = lr
  102. learning_rate.constant_learning_rate.CopyFrom(constant_lr)
  103. adam_optimizer.learning_rate.CopyFrom(learning_rate)
  104. optimizer.adam_optimizer.CopyFrom(adam_optimizer)
  105. optimizer.use_moving_average = use_moving_avg
  106. train_config.optimizer_config.append(optimizer)
  107. train_config.num_steps = 200000
  108. train_config.sync_replicas = True
  109. train_config.save_checkpoints_steps = 1100
  110. train_config.log_step_count_steps = 100
  111. train_config.save_summary_steps = 100
  112. config.train_config.CopyFrom(train_config)
  113. # 评估配置
  114. eval_config = EvalConfig()
  115. metrics_set = EvalMetrics()
  116. metrics_set.auc.SetInParent()
  117. eval_config.metrics_set.append(metrics_set)
  118. eval_config.eval_online = True
  119. eval_config.eval_interval_secs = 120
  120. config.eval_config.CopyFrom(eval_config)
  121. # 数据配置
  122. data_config = DatasetConfig()
  123. data_config.batch_size = 512
  124. data_config.num_epochs = 1
  125. data_config.prefetch_size = 32
  126. data_config.input_type = DatasetConfig.InputType.OdpsInputV2
  127. # 添加输入字段
  128. for name in input_fields:
  129. input_field = DatasetConfig.Field()
  130. input_field.input_name = name
  131. input_field.input_type = input_type_map[input_fields[name]]
  132. if name in dense_features:
  133. input_field.default_val = "0"
  134. data_config.input_fields.append(input_field)
  135. # 添加标签字段
  136. data_config.label_fields.append("has_conversion")
  137. config.data_config.CopyFrom(data_config)
  138. # 特征配置
  139. feature_configs = []
  140. # Dense特征配置
  141. boundaries = [ x / 100 for x in range(0, 101)]
  142. for name in dense_features:
  143. feature_config = FeatureConfig()
  144. feature_config.input_names.append(name)
  145. feature_config.feature_type = FeatureConfig.RawFeature
  146. feature_config.boundaries.extend(boundaries)
  147. feature_config.embedding_dim = 6
  148. feature_configs.append(feature_config)
  149. # Sparse特征配置
  150. for name in sparse_features:
  151. feature_config = FeatureConfig()
  152. feature_config.input_names.append(name)
  153. feature_config.feature_type = FeatureConfig.IdFeature
  154. # 只有INT64类型的特征才能使用embedding variable特性
  155. if name in use_ev_features:
  156. if input_type_map[input_fields[name]] != DatasetConfig.FieldType.INT64:
  157. raise ValueError(f"Feature {name} must be of type INT64 to use embedding variable.")
  158. feature_config.ev_params.filter_freq = 2
  159. else:
  160. feature_config.hash_bucket_size = bucket_size_map.get(name, 1000000)
  161. feature_config.embedding_dim = 6
  162. feature_configs.append(feature_config)
  163. # Tag特征配置
  164. for name in tag_features + seq_features:
  165. feature_config = FeatureConfig()
  166. feature_config.input_names.append(name)
  167. feature_config.feature_type = FeatureConfig.TagFeature
  168. feature_config.hash_bucket_size = bucket_size_map.get(name, 1000000)
  169. feature_config.embedding_dim = 6
  170. feature_config.separator = ','
  171. feature_configs.append(feature_config)
  172. config.feature_configs.extend(feature_configs)
  173. # 模型配置
  174. model_config = EasyRecModel()
  175. model_config.model_class = "DeepFM"
  176. # Wide特征组
  177. wide_group = FeatureGroupConfig()
  178. wide_group.group_name = 'wide'
  179. wide_group.feature_names.extend(dense_features + sparse_features)
  180. wide_group.wide_deep = WideOrDeep.WIDE
  181. model_config.feature_groups.append(wide_group)
  182. # Deep特征组
  183. deep_group = FeatureGroupConfig()
  184. deep_group.group_name = 'deep'
  185. deep_group.feature_names.extend(top_dense_features + sparse_features + tag_features + seq_features)
  186. deep_group.wide_deep = WideOrDeep.DEEP
  187. model_config.feature_groups.append(deep_group)
  188. # DeepFM配置
  189. deepfm = DeepFM()
  190. deepfm.wide_output_dim = 2
  191. # DNN配置
  192. dnn = DNN()
  193. dnn.hidden_units.extend([256, 128, 64])
  194. deepfm.dnn.CopyFrom(dnn)
  195. # Final DNN配置
  196. final_dnn = DNN()
  197. final_dnn.hidden_units.extend([64, 32])
  198. deepfm.final_dnn.CopyFrom(final_dnn)
  199. deepfm.l2_regularization = 1e-5
  200. model_config.deepfm.CopyFrom(deepfm)
  201. model_config.embedding_regularization = 1e-6
  202. config.model_config.CopyFrom(model_config)
  203. # 导出配置
  204. export_config = ExportConfig()
  205. export_config.exporter_type = "final"
  206. config.export_config.CopyFrom(export_config)
  207. return config
  208. def merge_repeated_elements(msg_str, field_name):
  209. msg_str = re.sub(
  210. fr'( +{field_name}: [^\n]+\n)+',
  211. lambda m: '{}{}: [{}]\n'.format(
  212. m.group(0).split(field_name)[0],
  213. field_name,
  214. ', '.join(re.findall(fr'{field_name}: ([^\n]+)', m.group(0)))
  215. ),
  216. msg_str
  217. )
  218. return msg_str
  219. def main():
  220. config = create_config()
  221. msg_str = text_format.MessageToString(config)
  222. msg_str = merge_repeated_elements(msg_str, 'boundaries')
  223. msg_str = merge_repeated_elements(msg_str, 'hidden_units')
  224. print(msg_str)
# Script entry point: emit the generated config to stdout.
if __name__ == '__main__':
    main()