Explorar o código

add infer code

often hai 5 meses
pai
achega
a7a2bc2f77

+ 39 - 21
recommend-model-produce/src/main/python/models/dssm/bq_reader_train_ps.py

@@ -12,27 +12,45 @@ class DSSMReader(MultiSlotDataGenerator):
     def line_process(self, line):
         try:
             # 按tab分割样本的各个字段
-            sample_id, label, left_features, right_features = line.rstrip('\n').split('    ')
-            
-            # 转换label为整数
-            label = int(label)
-            
-            # 处理左右视频特征
-            left_features = [float(x) for x in left_features.split(',')]
-            right_features = [float(x) for x in right_features.split(',')]
-            
-            # 验证特征维度
-            if len(left_features) != self.feature_dim or len(right_features) != self.feature_dim:
-                return None
-            
-            # 构建输出列表
-            output = []
-            #output.append(("sample_id", [sample_id]))  # 样本ID
-            output.append(("label", [label]))          # 标签
-            output.append(("left_features", left_features))   # 左视频特征
-            output.append(("right_features", right_features)) # 右视频特征
-            
-            return output
+            sample_values = line.rstrip('\n').split('    ')
+            if len(sample_values) == 4: # 训练格式
+                sample_id, label, left_features, right_features = sample_values
+                # 转换label为整数
+                label = int(label)
+                
+                # 处理左右视频特征
+                left_features = [float(x) for x in left_features.split(',')]
+                right_features = [float(x) for x in right_features.split(',')]
+                
+                # 验证特征维度
+                if len(left_features) != self.feature_dim or len(right_features) != self.feature_dim:
+                    return None
+                
+                # 构建输出列表
+                output = []
+                #output.append(("sample_id", [sample_id]))  # 样本ID
+                output.append(("label", [label]))          # 标签
+                output.append(("left_features", left_features))   # 左视频特征
+                output.append(("right_features", right_features)) # 右视频特征
+                
+                return output
+            else: #测试格式
+                sample_id,  left_features = sample_values
+
+                # Parse the left video features only (this format has no right_features field)
+                left_features = [float(x) for x in left_features.split(',')]
+                
+                # 验证特征维度
+                if len(left_features) != self.feature_dim :
+                    return None
+                
+                # 构建输出列表
+                output = []
+                #output.append(("sample_id", [sample_id]))  # 样本ID
+                output.append(("left_features", left_features))   # 左视频特征
+
+                
+                return output                
 
         except Exception as e:
             sys.stderr.write(f"Error processing line: {str(e)}\n")

+ 5 - 5
recommend-model-produce/src/main/python/models/dssm/data/test/test.txt

@@ -1,5 +1,5 @@
-djise-19293414-39429345-1789989892    0    1,3,5,2,4    2,16,5,5,8
-agsse-19290414-08429345-1709989892    1    1,3,5,2,4    2,16,5,5,8
-sdfsg-192980914-300345-1789969892    1    1,1,1,1,1    1,1,2,2,1
-gasrew-803414-139429345-1789989892    0    1,3,5,2,4    2,16,5,5,8
-gewt-9293414-429345-1789989852    0    12,3,12,2,4    8,16,9,5,8
+djise-19293414-39429345-1789989892    1,3,0,2,4
+agsse-19290414-08429345-1709989892    1,3,0,2,4
+sdfsg-192980914-300345-1789969892    1,1,1,1,1
+gasrew-803414-139429345-1789989892    1,3,0,2,4
+gewt-9293414-429345-1789989852    2,3,2,2,4