浏览代码

dssm train

丁云鹏 4 月之前
父节点
当前提交
77b5445b9d
共有 1 个文件被更改,包括 4 次插入2 次删除
  1. 4 2
      recommend-model-produce/src/main/python/tools/inferv2.py

+ 4 - 2
recommend-model-produce/src/main/python/tools/inferv2.py

@@ -12,6 +12,7 @@ import paddle.inference as paddle_infer
 
 
 logging.basicConfig(
+    filename='vec.log', filemode='w',
     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
 logger = logging.getLogger(__name__)
 
@@ -63,12 +64,13 @@ def main():
     
 
     for k,v in input_data2.items():
-        input_handle.copy_from_cpu(v)
+        v2 = np.array([v], dtype=np.float32)
+        input_handle.copy_from_cpu(v2)
         # 运行predictor
         predictor.run()
         # 获取输出
         output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
-        print("input k:{} v:{}".format(k, fake_input))
+        print("input k:{} v:{}".format(k, v))
         print("Output {}".format(output_data))
 
 if __name__ == "__main__":