|
@@ -12,27 +12,45 @@ class DSSMReader(MultiSlotDataGenerator):
|
|
|
def line_process(self, line):
|
|
|
try:
|
|
|
# 按tab分割样本的各个字段
|
|
|
- sample_id, label, left_features, right_features = line.rstrip('\n').split(' ')
|
|
|
-
|
|
|
- # 转换label为整数
|
|
|
- label = int(label)
|
|
|
-
|
|
|
- # 处理左右视频特征
|
|
|
- left_features = [float(x) for x in left_features.split(',')]
|
|
|
- right_features = [float(x) for x in right_features.split(',')]
|
|
|
-
|
|
|
- # 验证特征维度
|
|
|
- if len(left_features) != self.feature_dim or len(right_features) != self.feature_dim:
|
|
|
- return None
|
|
|
-
|
|
|
- # 构建输出列表
|
|
|
- output = []
|
|
|
- #output.append(("sample_id", [sample_id])) # 样本ID
|
|
|
- output.append(("label", [label])) # 标签
|
|
|
- output.append(("left_features", left_features)) # 左视频特征
|
|
|
- output.append(("right_features", right_features)) # 右视频特征
|
|
|
-
|
|
|
- return output
|
|
|
+ sample_values = line.rstrip('\n').split(' ')
|
|
|
+ if len(sample_values) == 4: # 训练格式
|
|
|
+ sample_id, label, left_features, right_features = sample_values
|
|
|
+ # 转换label为整数
|
|
|
+ label = int(label)
|
|
|
+
|
|
|
+ # 处理左右视频特征
|
|
|
+ left_features = [float(x) for x in left_features.split(',')]
|
|
|
+ right_features = [float(x) for x in right_features.split(',')]
|
|
|
+
|
|
|
+ # 验证特征维度
|
|
|
+ if len(left_features) != self.feature_dim or len(right_features) != self.feature_dim:
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 构建输出列表
|
|
|
+ output = []
|
|
|
+ #output.append(("sample_id", [sample_id])) # 样本ID
|
|
|
+ output.append(("label", [label])) # 标签
|
|
|
+ output.append(("left_features", left_features)) # 左视频特征
|
|
|
+ output.append(("right_features", right_features)) # 右视频特征
|
|
|
+
|
|
|
+ return output
|
|
|
+ else: #测试格式
|
|
|
+ sample_id, left_features = sample_values
|
|
|
+
|
|
|
+ # 处理左视频特征(测试格式只包含左侧特征)
|
|
|
+ left_features = [float(x) for x in left_features.split(',')]
|
|
|
+
|
|
|
+ # 验证特征维度
|
|
|
+ if len(left_features) != self.feature_dim :
|
|
|
+ return None
|
|
|
+
|
|
|
+ # 构建输出列表
|
|
|
+ output = []
|
|
|
+ #output.append(("sample_id", [sample_id])) # 样本ID
|
|
|
+ output.append(("left_features", left_features)) # 左视频特征
|
|
|
+
|
|
|
+
|
|
|
+ return output
|
|
|
|
|
|
except Exception as e:
|
|
|
sys.stderr.write(f"Error processing line: {str(e)}\n")
|