// test.java — Paddle Inference Java API demo
  1. package com.baidu.paddle.inference;
  2. public class test {
  3. static {
  4. System.out.println(System.getProperty("java.library.path"));
  5. System.out.println(System.mapLibraryName("paddle_inference"));
  6. System.loadLibrary("paddle_inference");
  7. }
  8. public static void main(String[] args) {
  9. Config config = new Config();
  10. config.setCppModel(args[0], args[1]);
  11. config.enableMemoryOptim(true);
  12. config.enableProfile();
  13. config.enableMKLDNN();
  14. System.out.println("summary:\n" + config.summary());
  15. System.out.println("model dir:\n" + config.getCppModelDir());
  16. System.out.println("prog file:\n" + config.getProgFile());
  17. System.out.println("params file:\n" + config.getCppParamsFile());
  18. config.getCpuMathLibraryNumThreads();
  19. config.getFractionOfGpuMemoryForPool();
  20. config.switchIrDebug(false);
  21. System.out.println(config.summary());
  22. Predictor predictor = Predictor.createPaddlePredictor(config);
  23. String inNames = predictor.getInputNameById(0);
  24. Tensor inHandle = predictor.getInputHandle(inNames);
  25. inHandle.reshape(4, new int[]{1, 3, 224, 224});
  26. float[] inData = new float[1 * 3 * 224 * 224];
  27. inHandle.copyFromCpu(inData);
  28. predictor.run();
  29. String outNames = predictor.getOutputNameById(0);
  30. Tensor outHandle = predictor.getOutputHandle(outNames);
  31. float[] outData = new float[outHandle.getSize()];
  32. outHandle.copyToCpu(outData);
  33. predictor.tryShrinkMemory();
  34. predictor.clearIntermediateTensor();
  35. System.out.println("predictor1: " + outData[0]);
  36. System.out.println("predictor1: " + outData.length);
  37. test(predictor);
  38. outHandle.destroyNativeTensor();
  39. inHandle.destroyNativeTensor();
  40. predictor.destroyNativePredictor();
  41. Config newConfig = new Config();
  42. newConfig.setCppModelDir("/model_dir");
  43. newConfig.setCppProgFile("/prog_file");
  44. newConfig.setCppParamsFile("/param");
  45. System.out.println("model dir:\n" + newConfig.getCppModelDir());
  46. System.out.println("prog file:\n" + newConfig.getProgFile());
  47. System.out.println("params file:\n" + newConfig.getCppParamsFile());
  48. config.destroyNativeConfig();
  49. }
  50. private static void test(Predictor predictor) {
  51. Predictor predictor2 = Predictor.clonePaddlePredictor(predictor);
  52. String inNames = predictor.getInputNameById(0);
  53. Tensor inHandle = predictor.getInputHandle(inNames);
  54. inHandle.reshape(4, new int[]{1, 3, 224, 224});
  55. float[] inData = new float[1 * 3 * 224 * 224];
  56. inHandle.copyFromCpu(inData);
  57. predictor.run();
  58. String outNames = predictor.getOutputNameById(0);
  59. Tensor outHandle = predictor.getOutputHandle(outNames);
  60. float[] outData = new float[outHandle.getSize()];
  61. outHandle.copyToCpu(outData);
  62. predictor.tryShrinkMemory();
  63. predictor.clearIntermediateTensor();
  64. System.out.println("predictor2: " + outData[0]);
  65. System.out.println("predictor2: " + outData.length);
  66. outHandle.destroyNativeTensor();
  67. inHandle.destroyNativeTensor();
  68. predictor.destroyNativePredictor();
  69. }
  70. }