丁云鹏 5 meses atrás
pai
commit
9344b22c49

+ 2 - 6
recommend-model-produce/src/main/python/tools/static_trainer.py

@@ -225,10 +225,6 @@ def main(args):
                 "runner.save_inference_fetch_varnames", [])
             fetchvars = []
 
-            for op in paddle.static.default_main_program().global_block().ops:
-                logger.info("op.input_arg_names {}".format(op.input_arg_names))
-            
-
             for var_name in feed_var_names:
                 if var_name not in paddle.static.default_main_program(
                 ).global_block().vars:
@@ -255,10 +251,10 @@ def main(args):
             inference_model_path = save_inference_model(model_save_path, epoch_id, feedvars,
                                  fetchvars, exe)
             if(upload_oss):
-                compress.compress_tar(inference_model_path, "model.tar.gz")
+                compress.compress_tar(model_save_path, "model.tar.gz")
                 client = HangZhouOSSClient("art-recommend")
                 client.put_object_from_file(oss_object_name, "model.tar.gz")
-                logger.info("file {} upload success".format(inference_model_path + ".tar.gz"))
+                logger.info("file {} upload success".format(model_save_path + ".tar.gz"))
 
 
 def dataset_train(epoch_id, dataset, fetch_vars, exe, config):

+ 1 - 56
recommend-model-produce/src/main/python/tools/utils/save_load.py

@@ -76,66 +76,11 @@ def save_inference_model(model_path,
     model_path = os.path.join(model_path, str(epoch_id))
     _mkdir_if_not_exist(model_path)
     model_prefix = os.path.join(model_path, prefix)
-    fetch_vars2 = fetch_vars
-
-    program = _get_valid_program(None)
-
-    # fix the bug that the activation op's output as target will be pruned.
-    # will affect the inference performance.
-    # TODO(Superjomn) add an IR pass to remove 1-scale op.
-    with program_guard(program):
-        uniq_fetch_vars = []
-        for i, var in enumerate(fetch_vars):
-            if var.dtype != paddle.bool:
-                var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}")
-            uniq_fetch_vars.append(var)
-        fetch_vars = uniq_fetch_vars
-
-    copy_program = program.clone()
-    logger.info("program.clone(): {}".format(copy_program.global_block().vars.keys()))
-
-    global_block = copy_program.global_block()
-    remove_op_idx = []
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            remove_op_idx.append(i)
-
-        if op.type == "pylayer":
-            sub_blocks_ids = op._blocks_attr_ids("blocks")
-            if len(sub_blocks_ids) > 1:
-                # pylayer op ``blocks`` attr contains forward block id and backward block id
-                backward_block_id = sub_blocks_ids[-1]
-                # remove backward block
-                copy_program.blocks.pop(backward_block_id)
-                # update attrs ``blocks``
-                reserverd_blocks = []
-                for block_id in sub_blocks_ids[:-1]:
-                    reserverd_blocks.append(copy_program.block(block_id))
-                op._update_desc_attr("blocks", reserverd_blocks)
-
-    for idx in remove_op_idx[::-1]:
-        global_block._remove_op(idx)
-    copy_program.desc.flush()
-    logger.info("copy_program.desc.flush(): {}".format(copy_program.global_block().vars.keys()))
-
-    feed_var_names = [var.name for var in feed_vars]
-
-    skip_prune_program = False
-    if not skip_prune_program:
-        copy_program = copy_program._prune_with_input(
-            feeded_var_names=feed_var_names, targets=fetch_vars
-        )
-        logger.info("copy_program._prune_with_input(): {}".format(copy_program.global_block().vars.keys()))
-    copy_program = copy_program._inference_optimize(prune_read_op=True)
-    logger.info("copy_program._inference_optimize(): {}".format(copy_program.global_block().vars.keys()))
-    
-
 
     paddle.static.save_inference_model(
         path_prefix=model_prefix,
         feed_vars=feed_vars,
-        fetch_vars=fetch_vars2,
+        fetch_vars=fetch_vars,
         executor=exe)
     return model_prefix