# widedeep_v13_4.py
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. #
  5. # Copyright © 2025 StrayWarrior <i@straywarrior.com>
  6. """
  7. 1.删除容易导致偏差的viewall特征
  8. 2.删除分桶不均匀的cpa特征
  9. 3.减少dense特征
  10. 4.增加U-I交叉统计
  11. 5.增加线性部分dense
  12. 6.减少wide部分embedding
  13. 7.减少部分bucket size
  14. 8.使用protobuf
  15. 9.调整embedding variable (PAI平台目前还不支持)
  16. """
  17. import os
  18. os.environ['PROCESSOR_TEST'] = "1"
  19. import re
  20. from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
  21. from easy_rec.python.protos.train_pb2 import TrainConfig
  22. from easy_rec.python.protos.eval_pb2 import EvalConfig, AUC, EvalMetrics
  23. from easy_rec.python.protos.dataset_pb2 import DatasetConfig
  24. from easy_rec.python.protos.feature_config_pb2 import FeatureConfig, FeatureGroupConfig, WideOrDeep, EVParams
  25. from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel
  26. from easy_rec.python.protos.deepfm_pb2 import DeepFM
  27. from easy_rec.python.protos.dnn_pb2 import DNN
  28. from easy_rec.python.protos.export_pb2 import ExportConfig
  29. from easy_rec.python.protos.optimizer_pb2 import Optimizer, AdamOptimizer, LearningRate, ConstantLearningRate
  30. from google.protobuf import text_format
  31. raw_input = open("data_fields_v3.config").readlines()
  32. input_fields = dict(
  33. map(lambda x: (x[0], x[1]),
  34. map(lambda x: x.strip().split(' '), raw_input)))
  35. def read_features(filename, excludes=None):
  36. features = open(filename).readlines()
  37. features = [name.strip().lower() for name in features]
  38. if excludes:
  39. for x in excludes:
  40. if x in features:
  41. features.remove(x)
  42. return features
  43. exclude_features = ['viewall', 'cpa']
  44. dense_features = read_features("features_top300.config", exclude_features)
  45. top_dense_features = read_features('features_top50.config', exclude_features)
  46. sparse_features = [
  47. "cid", "adid", "adverid",
  48. "region", "city", "brand",
  49. "vid", "cate1", "cate2",
  50. "apptype", "hour", "hour_quarter", "root_source_scene", "root_source_channel", "is_first_layer", "user_has_conver_1y",
  51. "user_adverid_view_3d", "user_adverid_view_7d", "user_adverid_view_30d",
  52. "user_adverid_click_3d", "user_adverid_click_7d", "user_adverid_click_30d",
  53. "user_adverid_conver_3d", "user_adverid_conver_7d", "user_adverid_conver_30d",
  54. "user_skuid_view_3d", "user_skuid_view_7d", "user_skuid_view_30d",
  55. "user_skuid_click_3d", "user_skuid_click_7d", "user_skuid_click_30d",
  56. "user_skuid_conver_3d", "user_skuid_conver_7d", "user_skuid_conver_30d"
  57. ]
  58. tag_features = [
  59. "user_vid_return_tags_2h", "user_vid_return_tags_1d", "user_vid_return_tags_3d",
  60. "user_vid_return_tags_7d", "user_vid_return_tags_14d",
  61. "user_conver_ad_class", "title_split"
  62. ]
  63. seq_features = [
  64. "user_cid_click_list", "user_cid_conver_list"
  65. ]
  66. input_type_map = {
  67. 'BIGINT': DatasetConfig.FieldType.INT64,
  68. 'DOUBLE': DatasetConfig.FieldType.DOUBLE,
  69. 'STRING': DatasetConfig.FieldType.STRING
  70. }
  71. use_ev_features = [
  72. # PAI上运行的TF疑似不支持
  73. ]
  74. bucket_size_map = {
  75. 'adverid': 100000,
  76. 'region': 1000,
  77. 'city': 10000,
  78. 'brand': 10000,
  79. 'cate1': 10000,
  80. 'cate2': 10000,
  81. 'apptype': 1000,
  82. 'hour': 1000, # 实际上可以直接指定词表
  83. 'hour_quarter': 4000,
  84. 'root_source_scene': 100,
  85. 'root_source_channel': 1000,
  86. 'is_first_layer': 100,
  87. 'user_has_conver_1y': 100,
  88. 'user_conver_ad_class': 10000
  89. }
  90. def create_config():
  91. config = EasyRecConfig()
  92. # 训练配置
  93. train_config = TrainConfig()
  94. # 配置多个优化器
  95. optimizers = [
  96. (0.0010, False), # wide参数
  97. (0.0006, False), # dense参数
  98. (0.002, False) # deep embedding参数
  99. ]
  100. for lr, use_moving_avg in optimizers:
  101. optimizer = Optimizer()
  102. adam_optimizer = AdamOptimizer()
  103. learning_rate = LearningRate()
  104. constant_lr = ConstantLearningRate()
  105. constant_lr.learning_rate = lr
  106. learning_rate.constant_learning_rate.CopyFrom(constant_lr)
  107. adam_optimizer.learning_rate.CopyFrom(learning_rate)
  108. optimizer.adam_optimizer.CopyFrom(adam_optimizer)
  109. optimizer.use_moving_average = use_moving_avg
  110. train_config.optimizer_config.append(optimizer)
  111. train_config.num_steps = 200000
  112. train_config.sync_replicas = True
  113. train_config.save_checkpoints_steps = 1100
  114. train_config.log_step_count_steps = 100
  115. train_config.save_summary_steps = 100
  116. config.train_config.CopyFrom(train_config)
  117. # 评估配置
  118. eval_config = EvalConfig()
  119. metrics_set = EvalMetrics()
  120. metrics_set.auc.SetInParent()
  121. eval_config.metrics_set.append(metrics_set)
  122. eval_config.eval_online = True
  123. eval_config.eval_interval_secs = 120
  124. config.eval_config.CopyFrom(eval_config)
  125. # 数据配置
  126. data_config = DatasetConfig()
  127. data_config.batch_size = 512
  128. data_config.num_epochs = 1
  129. data_config.prefetch_size = 32
  130. data_config.input_type = DatasetConfig.InputType.OdpsInputV2
  131. # 添加输入字段
  132. for name in input_fields:
  133. input_field = DatasetConfig.Field()
  134. input_field.input_name = name
  135. input_field.input_type = input_type_map[input_fields[name]]
  136. if name in dense_features:
  137. input_field.default_val = "0"
  138. data_config.input_fields.append(input_field)
  139. # 添加标签字段
  140. data_config.label_fields.append("has_conversion")
  141. config.data_config.CopyFrom(data_config)
  142. # 特征配置
  143. feature_configs = []
  144. # Dense特征配置
  145. boundaries = [ x / 100 for x in range(0, 101)]
  146. for name in dense_features:
  147. feature_config = FeatureConfig()
  148. feature_config.input_names.append(name)
  149. if name not in input_fields:
  150. raise Exception(f"{name} not found in input fields")
  151. feature_config.feature_type = FeatureConfig.RawFeature
  152. feature_config.boundaries.extend(boundaries)
  153. feature_config.embedding_dim = 6
  154. feature_configs.append(feature_config)
  155. # Sparse特征配置
  156. for name in sparse_features:
  157. feature_config = FeatureConfig()
  158. feature_config.input_names.append(name)
  159. if name not in input_fields:
  160. raise Exception(f"{name} not found in input fields")
  161. feature_config.feature_type = FeatureConfig.IdFeature
  162. # 只有INT64类型的特征才能使用embedding variable特性
  163. if name in use_ev_features:
  164. if input_type_map[input_fields[name]] != DatasetConfig.FieldType.INT64:
  165. raise ValueError(f"Feature {name} must be of type INT64 to use embedding variable.")
  166. feature_config.ev_params.filter_freq = 2
  167. else:
  168. feature_config.hash_bucket_size = bucket_size_map.get(name, 1000000)
  169. feature_config.embedding_dim = 6
  170. feature_configs.append(feature_config)
  171. # Tag特征配置
  172. for name in tag_features + seq_features:
  173. feature_config = FeatureConfig()
  174. feature_config.input_names.append(name)
  175. if name not in input_fields:
  176. raise Exception(f"{name} not found in input fields")
  177. feature_config.feature_type = FeatureConfig.TagFeature
  178. feature_config.hash_bucket_size = bucket_size_map.get(name, 1000000)
  179. feature_config.embedding_dim = 6
  180. feature_config.separator = ','
  181. feature_configs.append(feature_config)
  182. config.feature_configs.extend(feature_configs)
  183. # 模型配置
  184. model_config = EasyRecModel()
  185. model_config.model_class = "DeepFM"
  186. # Wide特征组
  187. wide_group = FeatureGroupConfig()
  188. wide_group.group_name = 'wide'
  189. wide_group.feature_names.extend(dense_features + sparse_features)
  190. wide_group.wide_deep = WideOrDeep.WIDE
  191. model_config.feature_groups.append(wide_group)
  192. # Deep特征组
  193. deep_group = FeatureGroupConfig()
  194. deep_group.group_name = 'deep'
  195. deep_group.feature_names.extend(top_dense_features + sparse_features + tag_features + seq_features)
  196. deep_group.wide_deep = WideOrDeep.DEEP
  197. model_config.feature_groups.append(deep_group)
  198. # DeepFM配置
  199. deepfm = DeepFM()
  200. deepfm.wide_output_dim = 2
  201. # DNN配置
  202. dnn = DNN()
  203. dnn.hidden_units.extend([256, 128, 64])
  204. deepfm.dnn.CopyFrom(dnn)
  205. # Final DNN配置
  206. final_dnn = DNN()
  207. final_dnn.hidden_units.extend([64, 32])
  208. deepfm.final_dnn.CopyFrom(final_dnn)
  209. deepfm.l2_regularization = 1e-5
  210. model_config.deepfm.CopyFrom(deepfm)
  211. model_config.embedding_regularization = 1e-6
  212. config.model_config.CopyFrom(model_config)
  213. # 导出配置
  214. export_config = ExportConfig()
  215. export_config.exporter_type = "final"
  216. config.export_config.CopyFrom(export_config)
  217. return config
  218. def merge_repeated_elements(msg_str, field_name):
  219. msg_str = re.sub(
  220. fr'( +{field_name}: [^\n]+\n)+',
  221. lambda m: '{}{}: [{}]\n'.format(
  222. m.group(0).split(field_name)[0],
  223. field_name,
  224. ', '.join(re.findall(fr'{field_name}: ([^\n]+)', m.group(0)))
  225. ),
  226. msg_str
  227. )
  228. return msg_str
  229. def main():
  230. config = create_config()
  231. msg_str = text_format.MessageToString(config)
  232. msg_str = merge_repeated_elements(msg_str, 'boundaries')
  233. msg_str = merge_repeated_elements(msg_str, 'hidden_units')
  234. print(msg_str)
  235. if __name__ == '__main__':
  236. main()