Browse Source

dssm train

丁云鹏 4 months ago
parent
commit
77bb0b6a79
1 changed files with 15 additions and 4 deletions
  1. 15 4
      recommend-model-produce/src/main/python/tools/inferv2.py

+ 15 - 4
recommend-model-produce/src/main/python/tools/inferv2.py

@@ -57,6 +57,10 @@ def main():
     # 设置输入
 
     result = {}
+    i=0
+    with open('/app/data.json', 'w') as json_file:
+        json_file.write("")
+
     for k,v in input_data2.items():
         v2 = np.array([v], dtype=np.float32)
         input_handle.copy_from_cpu(v2)
@@ -64,14 +68,21 @@ def main():
         predictor.run()
         # 获取输出
         output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
-        result[k] = output_data
-
+        result[k] = output_data.tolist()[0]
+        i=i+1
+        if i >= 1000:
+            print("write batch")
+            json_data = str(json.dumps(result))  # indent参数用于美化输出,使其更易读
+            # 写入文件
+            with open('/app/data.json', 'a') as json_file:
+                json_file.write(json_data)
+            result={}
+            i=0
 
     json_data = json.dumps(result, indent=4)  # indent参数用于美化输出,使其更易读
     # 写入文件
-    with open('/app/data.json', 'w') as json_file:
+    with open('/app/data.json', 'a') as json_file:
         json_file.write(json_data)
 
-
 if __name__ == "__main__":
     main()