丁云鹏 5 månader sedan
förälder
incheckning
7ac490d8d2
1 ändrade filer med 40 tillägg och 7 borttagningar
  1. 40 7
      recommend-model-produce/src/main/python/tools/utils/save_load.py

+ 40 - 7
recommend-model-produce/src/main/python/tools/utils/save_load.py

@@ -21,7 +21,7 @@ from paddle.static.io import _get_valid_program, normalize_program
 logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
 logger = logging.getLogger(__name__)
-
+logger.setLevel(logging.INFO)
 
 def save_model(net, optimizer, model_path, epoch_id, prefix='rec'):
     model_path = os.path.join(model_path, str(epoch_id))
@@ -78,12 +78,45 @@ def save_inference_model(model_path,
 
 
     program = _get_valid_program(None)
-    program = normalize_program(
-        program,
-        feed_vars,
-        fetch_vars,
-        skip_prune_program=True)
-    logger.info("global block has follow vars: {}".format(program.global_block().vars.keys()))
+    copy_program = program.clone()
+    logger.info("program.clone(): {}".format(program.global_block().vars.keys()))
+
+    global_block = copy_program.global_block()
+    remove_op_idx = []
+    for i, op in enumerate(global_block.ops):
+        op.desc.set_is_target(False)
+        if op.type == "feed" or op.type == "fetch":
+            remove_op_idx.append(i)
+
+        if op.type == "pylayer":
+            sub_blocks_ids = op._blocks_attr_ids("blocks")
+            if len(sub_blocks_ids) > 1:
+                # pylayer op ``blocks`` attr contains forward block id and backward block id
+                backward_block_id = sub_blocks_ids[-1]
+                # remove backward block
+                copy_program.blocks.pop(backward_block_id)
+                # update attrs ``blocks``
+                reserverd_blocks = []
+                for block_id in sub_blocks_ids[:-1]:
+                    reserverd_blocks.append(copy_program.block(block_id))
+                op._update_desc_attr("blocks", reserverd_blocks)
+
+    for idx in remove_op_idx[::-1]:
+        global_block._remove_op(idx)
+    copy_program.desc.flush()
+    logger.info("copy_program.desc.flush(): {}".format(program.global_block().vars.keys()))
+
+    feed_var_names = [var.name for var in feed_vars]
+
+    skip_prune_program = False
+    if not skip_prune_program:
+        copy_program = copy_program._prune_with_input(
+            feeded_var_names=feed_var_names, targets=fetch_vars
+        )
+        logger.info("copy_program._prune_with_input(): {}".format(program.global_block().vars.keys()))
+    copy_program = copy_program._inference_optimize(prune_read_op=True)
+    logger.info("copy_program._inference_optimize(): {}".format(program.global_block().vars.keys()))
+    
 
 
     paddle.static.save_inference_model(