丁云鹏 5 months ago
parent
commit
50de4174df

BIN
recommend-model-produce/src/main/python/lib/brpc_flags.so


+ 21 - 11
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -35,10 +35,15 @@ import struct
 from utils.utils_single import auc
 from utils.oss_client import HangZhouOSSClient
 import utils.compress as compress
+from ctypes import cdll
+
 
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
 
+cdll.LoadLibrary("/app/lib/brpc_flags.so")
+import brpc_flags
+
 root_loger = logging.getLogger()
 for handler in root_loger.handlers[:]:
     root_loger.removeHandler(handler)
@@ -148,7 +153,6 @@ class Main(object):
     def run_server(self):
         logger.info("Run Server Begin")
         fleet.init_server(config.get("runner.warmup_model_path"))
-        print('global_flags2: ' + str(list(paddle.base.framework._global_flags().keys())))
         fleet.run_server()
 
     def run_worker(self):
@@ -334,20 +338,26 @@ class Main(object):
 
 if __name__ == "__main__":
 
+    print("get_max_body_size1")
+    print(brpc_flags.get_max_body_size())
+    brpc_flags.set_max_body_size(123456789)
+    print("get_max_body_size2")
+    print(brpc_flags.get_max_body_size())
+
     paddle.enable_static()
 
-    read_env_flags = [
-        key[len("FLAGS_") :]
-        for key in core.globals().keys()
-        if key.startswith("FLAGS_")
-    ]
+    # read_env_flags = [
+    #     key[len("FLAGS_") :]
+    #     for key in core.globals().keys()
+    #     if key.startswith("FLAGS_")
+    # ]
 
-    def remove_flag_if_exists(name):
-        if name in read_env_flags:
-            read_env_flags.remove(name)
+    # def remove_flag_if_exists(name):
+    #     if name in read_env_flags:
+    #         read_env_flags.remove(name)
 
-    core.init_gflags(["--tryfromenv=,max_body_size" + ",".join(read_env_flags)])
-    print('global_flags: ' + str(list(paddle.base.framework._global_flags().keys())))
+    # # core.init_gflags(["--tryfromenv=,max_body_size" + ",".join(read_env_flags)])
+    # print('global_flags: ' + str(list(paddle.base.framework._global_flags().keys())))
 
 
     config = parse_args()