infer.py

import os
import sys

__dir__ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, "tools"))

import numpy as np
import json
from concurrent.futures import ThreadPoolExecutor

from utils.oss_client import HangZhouOSSClient
import utils.compress as compress
from utils.my_hdfs_client import MyHDFSClient
import paddle.inference as paddle_infer

# Hadoop installation directory and connection settings
hadoop_home = "/app/env/hadoop-3.2.4"
configs = {
    "fs.defaultFS": "hdfs://192.168.141.208:9000",
    "hadoop.job.ugi": ""
}
hdfs_client = MyHDFSClient(hadoop_home, configs)

def download_and_extract_model(init_model_path, oss_client, oss_object_name):
    """Download the model archive from OSS and extract it locally."""
    model_tar_path = "model.tar.gz"
    oss_client.get_object_to_file(oss_object_name, model_tar_path)
    compress.uncompress_tar(model_tar_path, init_model_path)
    assert os.path.exists(init_model_path)

def create_paddle_predictor(model_file, params_file):
    """Create a PaddlePaddle inference predictor from the exported model files."""
    config = paddle_infer.Config(model_file, params_file)
    predictor = paddle_infer.create_predictor(config)
    return predictor
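
# Optional tuning (not part of the original flow): create_paddle_predictor uses the
# default CPU configuration. On a GPU build of Paddle the config could be extended
# before creating the predictor; the 100 MB initial memory pool and device id 0 below
# are illustrative assumptions, not values this script relies on.
#
#   config = paddle_infer.Config(model_file, params_file)
#   config.enable_use_gpu(100, 0)   # initial GPU memory pool (MB), device id
#   config.enable_memory_optim()    # reuse intermediate tensors where possible
#   predictor = paddle_infer.create_predictor(config)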

def process_file(file_path, model_file, params_file):
    """Run inference over one HDFS file and return a tab-separated 'vid / embedding' line per sample."""
    predictor = create_paddle_predictor(model_file, params_file)
    # _run_cmd presumably shells out to the Hadoop CLI; "text" also decompresses .gz input
    ret, out = hdfs_client._run_cmd(f"text {file_path}")
    input_data = {}
    for line in out:
        sample_values = line.rstrip('\n').split('\t')
        vid, left_features_str = sample_values
        left_features = [float(x) for x in left_features_str.split(',')]
        input_data[vid] = left_features
    result = []
    for k, v in input_data.items():
        v2 = np.array([v], dtype=np.float32)
        input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
        input_handle.copy_from_cpu(v2)
        predictor.run()
        output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
        output_data = output_handle.copy_to_cpu()
        result.append(k + "\t" + str(output_data.tolist()[0]))
    return result
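
# Note: process_file scores one vector per predictor.run() call. If the exported DSSM
# model accepts a dynamic batch dimension (an assumption, not verified here), the
# vectors could be stacked into one float32 array to amortize per-run overhead, e.g.:
#
#   batch = np.array(list(input_data.values()), dtype=np.float32)   # shape [batch, dim]
#   input_handle.copy_from_cpu(batch)
#   predictor.run()
#   rows = output_handle.copy_to_cpu()   # one output row per vid, in insertion order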

def write_results(results, output_file):
    """Write result lines to a local file (plain tab-separated text, one line per sample)."""
    with open(output_file, 'w') as out_f:
        for s in results:
            out_f.write(s + "\n")

def thread_task(name, file_list, model_file, params_file):
    """Worker task: process each assigned file and upload its results to HDFS."""
    print(f"Thread {name}: starting file_list:{file_list}")
    count = len(file_list)
    for i, file_path in enumerate(file_list, start=1):
        print(f"Thread {name}: starting file:{file_path} {i}/{count}")
        results = process_file(file_path, model_file, params_file)
        file_name, file_suffix = os.path.splitext(os.path.basename(file_path))
        output_file = f"/app/vec-{file_name}.json"
        write_results(results, output_file)
        # Pack the result file and replace any previous copy on HDFS
        compress.compress_file_tar(output_file, f"{output_file}.tar.gz")
        hdfs_client.delete(f"/dyp/vec/{file_name}.gz")
        hdfs_client.upload(f"{output_file}.tar.gz", f"/dyp/vec/{file_name}.gz",
                           multi_processes=1, overwrite=False)
        print(f"Thread {name}: ending file:{file_path} {i}/{count}")
    print(f"Thread {name}: finishing")

def main():
    init_model_path = "/app/output_model_dssm"
    client = HangZhouOSSClient("art-recommend")
    oss_object_name = "dyp/dssm.tar.gz"
    download_and_extract_model(init_model_path, client, oss_object_name)
    model_file = os.path.join(init_model_path, "dssm.pdmodel")
    params_file = os.path.join(init_model_path, "dssm.pdiparams")

    sub_dirs, file_list = hdfs_client.ls_dir('/dw/recommend/model/56_dssm_i2i_itempredData/20241212')
    all_file = []
    file_extensions = [".gz"]
    for file in file_list:
        # Keep only files with the expected extensions
        if file_extensions and not any(file.endswith(ext) for ext in file_extensions):
            continue
        all_file.append(file)
    print(f"File list : {all_file}")

    max_workers = 16
    chunk_size = len(all_file) // max_workers
    remaining = len(all_file) % max_workers

    # Split the file list into max_workers contiguous chunks whose sizes differ by at most one
    split_file_list = []
    for i in range(max_workers):
        # Start/end indices of this chunk
        start = i * chunk_size + min(i, remaining)
        end = start + chunk_size + (1 if i < remaining else 0)
        split_file_list.append(all_file[start:end])

    future_list = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i, files in enumerate(split_file_list):
            future_list.append(executor.submit(thread_task, f"thread{i}", files, model_file, params_file))
        for future in future_list:
            future.result()

    print("Main program ending")


if __name__ == "__main__":
    main()
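
# Typical invocation (the OSS bucket, HDFS endpoints and /app paths above are
# environment-specific and assumed to exist):
#
#   python infer.py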