Ver código fonte

add infer code

often 5 meses atrás
pai
commit
fb31486edd

+ 14 - 10
recommend-model-produce/src/main/python/models/dssm/bq_reader_infer.py

@@ -28,14 +28,18 @@ class RecDataset(IterableDataset):
         for file in self.file_list:
             with open(file, "r") as rf:
                 for line in rf:
-                    output_list = []
-                    features = line.rstrip('\n').split('\t')
-                    query = [
-                        float(feature) for feature in features[0].split(',')
-                    ]
-                    output_list.append(np.array(query).astype('float32'))
-                    pos_doc = [
-                        float(feature) for feature in features[1].split(',')
-                    ]
-                    output_list.append(np.array(pos_doc).astype('float32'))
+                    sample_values = line.rstrip('\n').split('    ')
+                    sample_id,  left_features = sample_values
+                    # 处理左右视频特征
+                    left_features = [float(x) for x in left_features.split(',')]
+                    # 验证特征维度
+                    if len(left_features) != self.feature_dim :
+                        return None
+                    
+                    # 构建输出列表
+                    output = []
+                    output.append(("sample_id", [sample_id]))  # 样本ID
+                    output.append(("left_features", left_features))   # 左视频特征
+
+ 
                     yield output_list

+ 3 - 3
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -177,10 +177,10 @@ class Main(object):
 
     def network(self):
         self.model = get_model(self.config)
-        self.inference_feed_var = self.model.create_feeds(is_infer=True)
-        # self.inference_feed_var = self.model.create_feeds()
+        self.input_data = self.model.create_feeds(is_infer=True)
+        self.inference_feed_var = self.input_data
         self.init_reader()
-        self.metrics = self.model.net(self.inference_feed_var)
+        self.metrics = self.model.net(self.inference_feed_var,is_infer=True)
         self.inference_target_var = self.model.inference_target_var
         logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
         self.model.create_optimizer(get_strategy(self.config))