Bladeren bron

dssm train

丁云鹏 4 maanden geleden
bovenliggende
commit
afed16d13f
1 gewijzigde bestanden met toevoegingen van 8 en 12 verwijderingen
  1. 8 12
      recommend-model-produce/src/main/python/tools/inferv2.py

+ 8 - 12
recommend-model-produce/src/main/python/tools/inferv2.py

@@ -55,8 +55,9 @@ def main():
 
     # 设置输入
 
-    result = {}
+    result = []
     i=0
+    count =  len(input_data)
     with open('/app/data.json', 'w') as json_file:
         json_file.write("")
 
@@ -67,21 +68,16 @@ def main():
         predictor.run()
         # 获取输出
         output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
-        result[k] = output_data.tolist()[0]
+        result.append(k + "\t" + output_data.tolist()[0])
         i=i+1
-        if i >= 1000:
-            print("write batch")
-            json_data = str(json.dumps(result))  # indent参数用于美化输出,使其更易读
-            # 写入文件
-            with open('/app/data.json', 'a') as json_file:
-                json_file.write(json_data)
-            result={}
-            i=0
+        if i % 1000 == 0:
+            print("write batch {}/{}".format(i, count))
 
     json_data = json.dumps(result, indent=4)  # indent参数用于美化输出,使其更易读
     # 写入文件
     with open('/app/data.json', 'a') as json_file:
-        json_file.write(json_data)
+        for s in result:
+            json_file.write(s + "\n")
 
 if __name__ == "__main__":
-    main()
+    main()