often 5 місяців тому
батько
коміт
39fbc47cc9

+ 1 - 1
recommend-model-produce/src/main/python/models/dssm/bq_reader_train_ps.py

@@ -46,7 +46,7 @@ class DSSMReader(MultiSlotDataGenerator):
                 
                 # 构建输出列表
                 output = []
-                #output.append(("sample_id", [sample_id]))  # 样本ID
+                output.append(("sample_id", [sample_id]))  # 样本ID
                 output.append(("left_features", left_features))   # 左视频特征
 
                 

+ 17 - 15
recommend-model-produce/src/main/python/models/dssm/static_model.py

@@ -21,22 +21,25 @@ class StaticModel():
 
     def create_feeds(self, is_infer=False):
         # 定义输入数据占位符
-        # sample_id = paddle.static.data(
-        #    name="sample_id", shape=[-1, 1], dtype='int64')
+
         feeds_list = []
         if not is_infer:
             label = paddle.static.data(
                 name="label", shape=[-1, 1], dtype='float32')
-            feeds_list.append(label)
-                    
-        left_features = paddle.static.data(
-            name="left_features", shape=[-1, self.feature_num], dtype='float32')
-        feeds_list.append(left_features)
-        right_features = paddle.static.data(
-            name="right_features", shape=[-1, self.feature_num], dtype='float32')
-        feeds_list.append(right_features)
-        
-        
+            feeds_list.append(label)     
+            left_features = paddle.static.data(
+                name="left_features", shape=[-1, self.feature_num], dtype='float32')
+            feeds_list.append(left_features)
+            right_features = paddle.static.data(
+                name="right_features", shape=[-1, self.feature_num], dtype='float32')
+            feeds_list.append(right_features)
+        else:
+            sample_id = paddle.static.data(
+                name="sample_id", shape=[-1, 1], dtype='int64')
+            feeds_list.append(label)     
+            left_features = paddle.static.data(
+                name="left_features", shape=[-1, self.feature_num], dtype='float32')
+            feeds_list.append(left_features)
 
 
         return feeds_list
@@ -52,10 +55,11 @@ class StaticModel():
         )
 
         if is_infer:
-            left_features = input
+            sample_id,left_features = input
             left_vec = dssm_model(left_features,None,is_infer=True)
             self.inference_target_var = left_vec
             fetch_dict = {
+                'sample_id': sample_id
                 'left_vector': left_vec
             }
             return fetch_dict
@@ -68,8 +72,6 @@ class StaticModel():
             self.left_vector = left_vec
             self.right_vector = right_vec
 
-
-
             # 计算损失
             # 使用带margin的二元交叉熵损失
             pos_mask = paddle.cast(label > 0.5, 'float32')

+ 3 - 3
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -177,10 +177,10 @@ class Main(object):
 
     def network(self):
         self.model = get_model(self.config)
-        self.input_data = self.model.create_feeds()
-        self.inference_feed_var = self.model.create_feeds()
+        self.inference_feed_var = self.model.create_feeds(is_infer=True)
+        # self.inference_feed_var = self.model.create_feeds()
         self.init_reader()
-        self.metrics = self.model.net(self.input_data)
+        self.metrics = self.model.net(self.inference_feed_var)
         self.inference_target_var = self.model.inference_target_var
         logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
         self.model.create_optimizer(get_strategy(self.config))