luojunhui 4 månader sedan
förälder
incheckning
ce14d39a4d
1 ändrade filer med 6 tillägg och 2 borttagningar
  1. 6 2
      train_model.py

+ 6 - 2
train_model.py

@@ -130,8 +130,12 @@ def calculate_auc(model, dataloader):
         for batch in dataloader:
             inputs = {k: v.to(device) for k, v in batch.items()}
             outputs = model(**inputs)
-            probs = torch.sigmoid(outputs.logits.squeeze()).cpu().numpy()
-            pred_probs.extend(probs)
+
+            # 修改部分:使用Softmax获取概率并提取正类概率
+            probabilities = torch.softmax(outputs.logits, dim=1)  # 对logits做Softmax
+            positive_class_probs = probabilities[:, 1].cpu().numpy()  # 提取第二列(正类)
+
+            pred_probs.extend(positive_class_probs)
             true_labels.extend(inputs['labels'].cpu().numpy())
 
     auc_score = roc_auc_score(true_labels, pred_probs)