丁云鹏 5 hónapja
szülő
commit
6e2aed0589

+ 18 - 0
recommend-model-produce/src/main/python/lib/brpc_flags.cc

@@ -0,0 +1,18 @@
+#include <gflags/gflags.h>
+#include <pybind11/pybind11.h>
+
+namespace brpc {
+    DECLARE_uint64(max_body_size);
+}
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(brpc_flags, m) {
+    m.def("get_max_body_size", []() {
+        return brpc::FLAGS_max_body_size;
+    }, "A function that returns the max body size.");
+
+    m.def("set_max_body_size", [](int64_t new_size) {
+        brpc::FLAGS_max_body_size = new_size;
+    }, "A function that sets the max body size.");
+}

+ 35 - 38
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -38,10 +38,7 @@ import utils.compress as compress
 
 
 sys.path.append(os.path.abspath("") + os.sep  +  "lib")
-print(os.path.abspath("") + os.sep  +  "lib1")
 import brpc_flags
-print(os.path.abspath("") + os.sep  +  "lib2")
-
 
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
@@ -245,41 +242,41 @@ class Main(object):
                     [feed.name for feed in self.inference_feed_var],
                     [self.inference_target_var], self.exe)
 
-
-            if fleet.is_first_worker():
-                # trans to new format
-                # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
-                program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
-                    os.path.join(model_dir, "dnn_plugin"),
-                    self.exe,
-                    model_filename=None,
-                    params_filename=None)
-
-                feedvars = []
-                fetchvars = fetch_targets
-
-                for var_name in feed_target_names:
-                    if var_name not in program.global_block().vars:
-                        raise ValueError(
-                            "Feed variable: {} not in default_main_program, global block has follow vars: {}".
-                            format(var_name,
-                                   program.global_block().vars.keys()))
-                    else:
-                        feedvars.append(program.global_block().vars[var_name])
-
-                paddle.static.save_inference_model(
-                    os.path.join(model_dir, "dnn_plugin_new/dssm"),
-                    feedvars,
-                    fetchvars, 
-                    self.exe,
-                    program = program)
-
-                compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
-                client = HangZhouOSSClient("art-recommend")
-                client.put_object_from_file("dyp/dssm.tar.gz", "dnn_plugin_new.tar.gz")
-                while True:
-                    time.sleep(300)
-                    continue;
+        if fleet.is_first_worker():
+            model_dir = "{}/{}".format(save_model_path, epochs)
+            oss_object_name = self.config.get("runner.oss_object_name")
+            # trans to new format
+            # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
+            program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
+                os.path.join(model_dir, "dnn_plugin"),
+                self.exe,
+                model_filename=None,
+                params_filename=None)
+
+            feedvars = []
+            fetchvars = fetch_targets
+
+            for var_name in feed_target_names:
+                if var_name not in program.global_block().vars:
+                    raise ValueError(
+                        "Feed variable: {} not in default_main_program, global block has follow vars: {}".
+                        format(var_name,
+                               program.global_block().vars.keys()))
+                else:
+                    feedvars.append(program.global_block().vars[var_name])
+            paddle.static.save_inference_model(
+                os.path.join(model_dir, "dnn_plugin_new/dssm"),
+                feedvars,
+                fetchvars, 
+                self.exe,
+                program = program)
+
+            compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
+            client = HangZhouOSSClient("art-recommend")
+            client.put_object_from_file(oss_object_name, "dnn_plugin_new.tar.gz")
+            while True:
+                time.sleep(300)
+                continue;
 
         if reader_type == "InmemoryDataset":
             self.reader.release_memory()

+ 3 - 3
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/DSSMModel.java

@@ -52,11 +52,11 @@ public class DSSMModel implements Model {
 
     @Override
     public boolean loadFromStream(InputStream in) throws Exception {
-        String modelDir = PropertiesUtil.getString("model.dir") + "/demo";
+        String modelDir = PropertiesUtil.getString("model.dir") + "/dssm";
         CompressUtil.decompressGzFile(in, modelDir);
 
-        String modelFile = modelDir + "/inference.pdmodel";
-        String paramFile = modelDir + "/inference.pdiparams";
+        String modelFile = modelDir + "/dssm.pdmodel";
+        String paramFile = modelDir + "/dssm.pdiparams";
 
         Config config = new Config();
         config.setCppModel(modelFile, paramFile);

+ 1 - 1
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/ModelEnum.java

@@ -3,7 +3,7 @@ package com.tzld.piaoquan.recommend.model.service.model;
 import org.apache.commons.lang3.StringUtils;
 
 public enum ModelEnum {
-    VIDEO_DSSM("videoDssm", "", DSSMModel.class),
+    VIDEO_DSSM("videoDssm", "dyp/dssm.tar.gz", DSSMModel.class),
     DEMO("demo", "zhangbo/model_paddle_demo.tar.gz", DemoModel.class),
     DNN("dnn", "dyp/dnn.tar.gz", DNNModel.class),
     NULL("null", "null", null);