丁云鹏 4 bulan lalu
induk
melakukan
e0cb1aaefe
1 mengubah file dengan 62 tambahan dan 38 penghapusan
  1. 62 38
      recommend-model-produce/src/main/python/tools/inferv2.py

+ 62 - 38
recommend-model-produce/src/main/python/tools/inferv2.py

@@ -9,6 +9,7 @@ from utils.my_hdfs_client import MyHDFSClient
 # 引用 paddle inference 推理库
 import paddle.inference as paddle_infer
 import json
+from concurrent.futures import ThreadPoolExecutor
 
 hadoop_home = "/app/env/hadoop-3.2.4"  # Hadoop 安装目录
 configs = {
@@ -18,10 +19,7 @@ configs = {
 hdfs_client = MyHDFSClient(hadoop_home, configs)
 
 
-
-
 def main():
-
     init_model_path = "/app/output_model_dssm"
     client = HangZhouOSSClient("art-recommend")
     oss_object_name = "dyp/dssm.tar.gz"
@@ -29,11 +27,30 @@ def main():
     compress.uncompress_tar("model.tar.gz", init_model_path)
     assert os.path.exists(init_model_path)
 
-    model_file=os.path.join(init_model_path, "dssm.pdmodel")
-    params_file=os.path.join(init_model_path, "dssm.pdiparams")
+    model_file=os.path.join(init_model_path, "dssm.pdmodel")
+    params_file=os.path.join(init_model_path, "dssm.pdiparams")
+
+    max_workers=2
+    split_file_list=[
+    ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00017.gz'],
+    ['/dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00017.gz']
+    ]
+    future_list=[]
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        for i, file_list in enumerate(split_file_list):
+            future_list.append(executor.submit(thread_task, "thread" + str(i), file_list))
+    # 等待所有任务完成
+    for future in future_list:
+        future.result()
+
+    print("Main program ending")
+
+
+def thread_task(name, file_list):
+    print(f"Thread {name}: starting file_list:{file_list}")
 
     # 创建 config
-    config = paddle_infer.Config(model_file, params_file)
+    config = paddle_infer.Config(model_file, params_file)  # NOTE(review): model_file/params_file are assigned in main() — they must be module-level globals or passed as arguments to be visible here; confirm
 
     # 根据 config 创建 predictor
     predictor = paddle_infer.create_predictor(config)
@@ -43,39 +60,46 @@ def main():
     output_names = predictor.get_output_names()
     output_handle = predictor.get_output_handle(output_names[0])
 
+    fi=0
+    file_len = len(file_list)
+    for flie in file_list:
+        ret, out = hdfs_client._run_cmd(f"text {flie}")
+        input_data = {}
+        for line in out:
+            sample_values = line.rstrip('\n').split('\t')
+            vid, left_features_str = sample_values
+            left_features = [float(x) for x in left_features_str.split(',')]
+            input_data[vid] = left_features
+            
+
+        # 设置输入
+
+        result = []
+        i=0
+        fi=fi+1
+        count =  len(input_data)
+        print(f"Thread {name}: current handle {fi}/{file_len} file {flie} count {count}")
+
+
+        for k,v in input_data.items():
+            v2 = np.array([v], dtype=np.float32)
+            input_handle.copy_from_cpu(v2)
+            # 运行predictor
+            predictor.run()
+            # 获取输出
+            output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
+            result.append(k + "\t" + str(output_data.tolist()[0]))
+            i=i+1
+            if i % 1000 == 0:
+                print(f"Thread {name}: write batch {i}/{count}")
+
+        json_data = json.dumps(result, indent=4)  # indent参数用于美化输出,使其更易读
+        # 写入文件
+        with open('/app/data_' + os.path.basename(flie) + '.json', 'w') as json_file:
+            for s in result:
+                json_file.write(s + "\n")
+    print(f"Thread {name}: finishing")
 
-    ret, out = hdfs_client._run_cmd("text /dw/recommend/model/56_dssm_i2i_itempredData/20241206/part-00016.gz")
-    input_data = {}
-    for line in out:
-        sample_values = line.rstrip('\n').split('\t')
-        vid, left_features_str = sample_values
-        left_features = [float(x) for x in left_features_str.split(',')]
-        input_data[vid] = left_features
-        
-
-    # 设置输入
-
-    result = []
-    i=0
-    count =  len(input_data)
-
-    for k,v in input_data.items():
-        v2 = np.array([v], dtype=np.float32)
-        input_handle.copy_from_cpu(v2)
-        # 运行predictor
-        predictor.run()
-        # 获取输出
-        output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
-        result.append(k + "\t" + str(output_data.tolist()[0]))
-        i=i+1
-        if i % 1000 == 0:
-            print("write batch {}/{}".format(i, count))
-
-    json_data = json.dumps(result, indent=4)  # indent参数用于美化输出,使其更易读
-    # 写入文件
-    with open('/app/data.json', 'w') as json_file:
-        for s in result:
-            json_file.write(s + "\n")
 
 if __name__ == "__main__":
     main()