|
@@ -130,8 +130,12 @@ def calculate_auc(model, dataloader):
|
|
|
for batch in dataloader:
|
|
|
inputs = {k: v.to(device) for k, v in batch.items()}
|
|
|
outputs = model(**inputs)
|
|
|
- probs = torch.sigmoid(outputs.logits.squeeze()).cpu().numpy()
|
|
|
- pred_probs.extend(probs)
|
|
|
+
|
|
|
+ # 修改部分:使用Softmax获取概率并提取正类概率
|
|
|
+ probabilities = torch.softmax(outputs.logits, dim=1) # 对logits做Softmax
|
|
|
+ positive_class_probs = probabilities[:, 1].cpu().numpy() # 提取第二列(正类)
|
|
|
+
|
|
|
+ pred_probs.extend(positive_class_probs)
|
|
|
true_labels.extend(inputs['labels'].cpu().numpy())
|
|
|
|
|
|
auc_score = roc_auc_score(true_labels, pred_probs)
|