# benchmark_reader.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import sys
  15. import yaml
  16. import six
  17. import os
  18. import copy
  19. import xxhash
  20. import paddle.distributed.fleet as fleet
  21. import logging
  22. cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  23. cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
  24. cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
  25. hash_dim_ = 1000001
  26. continuous_range_ = range(1, 14)
  27. categorical_range_ = range(14, 40)
  28. logging.basicConfig(
  29. format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
  30. logger = logging.getLogger(__name__)
  31. class Reader(fleet.MultiSlotDataGenerator):
  32. def init(self, config):
  33. self.config = config
  34. def line_process(self, line):
  35. features = line.rstrip('\n').split('\t')
  36. dense_feature = []
  37. sparse_feature = []
  38. for idx in continuous_range_:
  39. if features[idx] == "":
  40. dense_feature.append(0.0)
  41. else:
  42. dense_feature.append(
  43. (float(features[idx]) - cont_min_[idx - 1]) /
  44. cont_diff_[idx - 1])
  45. for idx in categorical_range_:
  46. sparse_feature.append([
  47. xxhash.xxh32(str(idx) + features[idx]).intdigest() % hash_dim_
  48. ])
  49. label = [int(features[0])]
  50. return [label] + sparse_feature + [dense_feature]
  51. def generate_sample(self, line):
  52. "Dataset Generator"
  53. def reader():
  54. input_data = self.line_process(line)
  55. feature_name = ["dense_input"]
  56. for idx in categorical_range_:
  57. feature_name.append("C" + str(idx - 13))
  58. feature_name.append("label")
  59. yield zip(feature_name, input_data)
  60. return reader
  61. def dataloader(self, file_list):
  62. "DataLoader Pyreader Generator"
  63. def reader():
  64. for file in file_list:
  65. with open(file, 'r') as f:
  66. for line in f:
  67. input_data = self.line_process(line)
  68. yield input_data
  69. return reader
  70. if __name__ == "__main__":
  71. yaml_path = sys.argv[1]
  72. utils_path = sys.argv[2]
  73. sys.path.append(utils_path)
  74. import common_ps
  75. yaml_helper = common_ps.YamlHelper()
  76. config = yaml_helper.load_yaml(yaml_path)
  77. r = Reader()
  78. r.init(config)
  79. r.run_from_stdin()