丁云鹏 5 months ago
parent
commit
76537ecb9b
1 changed files with 25 additions and 3 deletions
  1. 25 3
      recommend-model-produce/src/main/python/tools/utils/save_load.py

+ 25 - 3
recommend-model-produce/src/main/python/tools/utils/save_load.py

@@ -16,7 +16,7 @@ import paddle
 import os
 import logging
 import numpy as np
-from paddle.static.io import _get_valid_program, normalize_program
+from paddle.static.io import _get_valid_program, normalize_program, program_guard
 
 logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
@@ -75,9 +75,31 @@ def save_inference_model(model_path,
     model_path = os.path.join(model_path, str(epoch_id))
     _mkdir_if_not_exist(model_path)
     model_prefix = os.path.join(model_path, prefix)
-
+    fetch_vars2 = fetch_vars
 
     program = _get_valid_program(None)
+
+    for op in program.global_block().ops:
+        # clear device of Op
+        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
+        op._set_attr(device_attr_name, "")
+        if op.type == 'auc':
+            warnings.warn(
+                "Be sure that you have set auc states to 0 before saving inference model."
+            )
+            break
+
+    # Fix the bug where an activation op's output, used as a fetch target, is pruned:
+    # wrap each non-bool fetch var in a scale-by-1.0 op (this will affect inference performance).
+    # TODO(Superjomn): add an IR pass to remove the 1-scale op.
+    with program_guard(program):
+        uniq_fetch_vars = []
+        for i, var in enumerate(fetch_vars):
+            if var.dtype != paddle.bool:
+                var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}")
+            uniq_fetch_vars.append(var)
+        fetch_vars = uniq_fetch_vars
+
     copy_program = program.clone()
     logger.info("program.clone(): {}".format(copy_program.global_block().vars.keys()))
 
@@ -122,7 +144,7 @@ def save_inference_model(model_path,
     paddle.static.save_inference_model(
         path_prefix=model_prefix,
         feed_vars=feed_vars,
-        fetch_vars=fetch_vars,
+        fetch_vars=fetch_vars2,
         executor=exe)
     return model_prefix