丁云鹏 4 месяцев назад
Родитель
Сommit
a74b969dcc
1 измененных файлов с 67 добавлено и 76 удалено
  1. 67 76
      recommend-model-produce/src/main/python/tools/inferv2.py

+ 67 - 76
recommend-model-produce/src/main/python/tools/inferv2.py

@@ -1,105 +1,96 @@
 import os
 import sys
 import numpy as np
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
+import json
+from concurrent.futures import ThreadPoolExecutor
 from utils.oss_client import HangZhouOSSClient
 import utils.compress as compress
 from utils.my_hdfs_client import MyHDFSClient
-# 引用 paddle inference 推理库
 import paddle.inference as paddle_infer
-import json
-from concurrent.futures import ThreadPoolExecutor
 
-hadoop_home = "/app/env/hadoop-3.2.4"  # Hadoop 安装目录
+# Hadoop 安装目录和配置信息
+hadoop_home = "/app/env/hadoop-3.2.4"
 configs = {
-    "fs.default.name": "hdfs://192.168.141.208:9000",  # HDFS 名称和端口
-    "hadoop.job.ugi": ""  # HDFS 用户和密码
+    "fs.defaultFS": "hdfs://192.168.141.208:9000",
+    "hadoop.job.ugi": ""
 }
 hdfs_client = MyHDFSClient(hadoop_home, configs)
 
+def download_and_extract_model(init_model_path, oss_client, oss_object_name):
+    """下载并解压模型"""
+    model_tar_path = "model.tar.gz"
+    oss_client.get_object_to_file(oss_object_name, model_tar_path)
+    compress.uncompress_tar(model_tar_path, init_model_path)
+    assert os.path.exists(init_model_path)
+
+def create_paddle_predictor(model_file, params_file):
+    """创建PaddlePaddle的predictor"""
+    config = paddle_infer.Config(model_file, params_file)
+    predictor = paddle_infer.create_predictor(config)
+    return predictor
+
+def process_file(file_path, model_file, params_file):
+    """处理单个文件"""
+    predictor = create_paddle_predictor(model_file, params_file)
+    ret, out = hdfs_client._run_cmd(f"text {file_path}")
+    input_data = {}
+    for line in out:
+        sample_values = line.rstrip('\n').split('\t')
+        vid, left_features_str = sample_values
+        left_features = [float(x) for x in left_features_str.split(',')]
+        input_data[vid] = left_features
+
+    result = []
+    for k, v in input_data.items():
+        v2 = np.array([v], dtype=np.float32)
+        input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+        input_handle.copy_from_cpu(v2)
+        predictor.run()
+        output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
+        output_data = output_handle.copy_to_cpu()
+        result.append(k + "\t" + str(output_data.tolist()[0]))
+
+    return result
+
+def write_results(results, output_file):
+    """将结果写入文件"""
+    with open(output_file, 'w') as json_file:
+        for s in results:
+            json_file.write(s + "\n")
+
+def thread_task(name, file_list, model_file, params_file):
+    """线程任务"""
+    print(f"Thread {name}: starting file_list:{file_list}")
+    results = []
+    for file_path in file_list:
+        results.extend(process_file(file_path, model_file, params_file))
+    output_file = f"/app/data_{os.path.basename(file_list[0])}.json"
+    write_results(results, output_file)
+    print(f"Thread {name}: finishing")
 
 def main():
     init_model_path = "/app/output_model_dssm"
     client = HangZhouOSSClient("art-recommend")
     oss_object_name = "dyp/dssm.tar.gz"
-    client.get_object_to_file(oss_object_name, "model.tar.gz")
-    compress.uncompress_tar("model.tar.gz", init_model_path)
-    assert os.path.exists(init_model_path)
+    download_and_extract_model(init_model_path, client, oss_object_name)
 
-    self.model_file=os.path.join(init_model_path, "dssm.pdmodel")
-    self.params_file=os.path.join(init_model_path, "dssm.pdiparams")
+    model_file = os.path.join(init_model_path, "dssm.pdmodel")
+    params_file = os.path.join(init_model_path, "dssm.pdiparams")
 
-    max_workers=2
-    spilt_file_list=[
-    ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00017.gz'],
-    ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00017.gz']
+    max_workers = 2
+    split_file_list = [
+        ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00017.gz'],
+        ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00017.gz']
     ]
-    future_list=[]
+    future_list = []
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         for i, file_list in enumerate(split_file_list):
-            future_list.append(executor.submit(thread_task, "thread" + str(i), file_list))
-    # 等待所有任务完成
+            future_list.append(executor.submit(thread_task, f"thread{i}", file_list, model_file, params_file))
+
     for future in future_list:
         future.result()
 
     print("Main program ending")
 
-
-def thread_task(name, file_list):
-    print(f"Thread {name}: starting file_list:{file_list}"):
-
-    # 创建 config
-    config = paddle_infer.Config(self.model_file, self.params_file)
-
-    # 根据 config 创建 predictor
-    predictor = paddle_infer.create_predictor(config)
-    # 获取输入的名称
-    input_names = predictor.get_input_names()
-    input_handle = predictor.get_input_handle(input_names[0])
-    output_names = predictor.get_output_names()
-    output_handle = predictor.get_output_handle(output_names[0])
-
-    fi=0
-    file_len = len(file_list)
-    for flie in file_list:
-        ret, out = hdfs_client._run_cmd(f"text {file}")
-        input_data = {}
-        for line in out:
-            sample_values = line.rstrip('\n').split('\t')
-            vid, left_features_str = sample_values
-            left_features = [float(x) for x in left_features_str.split(',')]
-            input_data[vid] = left_features
-            
-
-        # 设置输入
-
-        result = []
-        i=0
-        fi=fi+1
-        count =  len(input_data)
-        print(f"Thread {name}: current handle {fi}/{file_len} file {flie} count {count}")
-
-
-        for k,v in input_data.items():
-            v2 = np.array([v], dtype=np.float32)
-            input_handle.copy_from_cpu(v2)
-            # 运行predictor
-            predictor.run()
-            # 获取输出
-            output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
-            result.append(k + "\t" + str(output_data.tolist()[0]))
-            i=i+1
-            if i % 1000 == 0:
-                print(f"Thread {name}: write batch {i}/{count}")
-
-        json_data = json.dumps(result, indent=4)  # indent参数用于美化输出,使其更易读
-        # 写入文件
-        with open('/app/data_' + os.path.basename(flie) + '.json', 'w') as json_file:
-            for s in result:
-                json_file.write(s + "\n")
-    print(f"Thread {name}: finishing")
-
-
 if __name__ == "__main__":
     main()