inferv2.py

import os
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from utils.oss_client import HangZhouOSSClient
import utils.compress as compress
from utils.my_hdfs_client import MyHDFSClient
import paddle.inference as paddle_infer

# Hadoop installation directory and cluster configuration
hadoop_home = "/app/env/hadoop-3.2.4"
configs = {
    "fs.defaultFS": "hdfs://192.168.141.208:9000",
    "hadoop.job.ugi": ""
}
hdfs_client = MyHDFSClient(hadoop_home, configs)
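# Note: MyHDFSClient is assumed to wrap the Hadoop CLI; the `text` command used
# in process_file below (i.e. `hadoop fs -text`) decompresses .gz parts on the fly.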

def download_and_extract_model(init_model_path, oss_client, oss_object_name):
    """Download the model archive from OSS and extract it locally."""
    model_tar_path = "model.tar.gz"
    oss_client.get_object_to_file(oss_object_name, model_tar_path)
    compress.uncompress_tar(model_tar_path, init_model_path)
    assert os.path.exists(init_model_path)

def create_paddle_predictor(model_file, params_file):
    """Create a PaddlePaddle inference predictor from a model/params pair."""
    config = paddle_infer.Config(model_file, params_file)
    predictor = paddle_infer.create_predictor(config)
    return predictor
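
# paddle_infer.Config defaults to CPU execution here; on a GPU build of Paddle
# the predictor could be switched to GPU inside create_paddle_predictor, e.g.
# (untested sketch, assuming device 0 and a 100 MB initial memory pool):
#   config.enable_use_gpu(100, 0)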

def process_file(file_path, model_file, params_file):
    """Run inference over a single HDFS file and return the scored lines."""
    predictor = create_paddle_predictor(model_file, params_file)
    ret, out = hdfs_client._run_cmd(f"text {file_path}")

    # Parse each line as "vid<TAB>comma-separated-features".
    input_data = {}
    for line in out:
        sample_values = line.rstrip('\n').split('\t')
        vid, left_features_str = sample_values
        left_features = [float(x) for x in left_features_str.split(',')]
        input_data[vid] = left_features

    i = 0
    count = len(input_data)
    result = []
    for k, v in input_data.items():
        # Feed one sample at a time (batch size 1).
        v2 = np.array([v], dtype=np.float32)
        input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
        input_handle.copy_from_cpu(v2)
        predictor.run()
        output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
        output_data = output_handle.copy_to_cpu()
        result.append(k + "\t" + str(output_data.tolist()[0]))
        i += 1
        if i % 1000 == 0:
            print(f"{os.path.basename(file_path)}: processed {i}/{count}")
    return result
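
# The model is assumed to take a single float32 feature vector and return one
# row of scores per sample; output_data.tolist()[0] relies on that batch-of-1 shape.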

def write_results(results, output_file):
    """Write one result line per sample to the output file."""
    with open(output_file, 'w') as out_f:
        for s in results:
            out_f.write(s + "\n")
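
# Despite the .json suffix chosen in thread_task, each output line is plain
# tab-separated text ("vid<TAB>scores"), not a JSON document.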

def thread_task(name, file_list, model_file, params_file):
    """Thread task: score each assigned file and write one output file per input."""
    print(f"Thread {name}: starting, file_list: {file_list}")
    for file_path in file_list:
        results = process_file(file_path, model_file, params_file)
        output_file = f"/app/data_{os.path.basename(file_path)}.json"
        write_results(results, output_file)
    print(f"Thread {name}: finishing")

def main():
    init_model_path = "/app/output_model_dssm"
    client = HangZhouOSSClient("art-recommend")
    oss_object_name = "dyp/dssm.tar.gz"
    download_and_extract_model(init_model_path, client, oss_object_name)
    model_file = os.path.join(init_model_path, "dssm.pdmodel")
    params_file = os.path.join(init_model_path, "dssm.pdiparams")

    max_workers = 2
    split_file_list = [
        ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00017.gz'],
        ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00018.gz']
    ]
    future_list = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i, file_list in enumerate(split_file_list):
            future_list.append(
                executor.submit(thread_task, f"thread{i}", file_list, model_file, params_file)
            )
        for future in future_list:
            future.result()

    print("Main program ending")

if __name__ == "__main__":
    main()
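
# ThreadPoolExecutor keeps everything in one process; whether the two threads
# actually overlap depends on how much of the Paddle run releases the GIL. If
# they serialize, concurrent.futures.ProcessPoolExecutor is a drop-in alternative.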