feat(metric)!: add multiclass and multi-label support to AUROC metric (#4960)

AUROC was binary-only and panicked on >2 classes. It now supports
binary, multiclass and multi-label classification via a One-vs-Rest
decomposition aggregated with a Micro/Macro class reduction.

- AurocMetric now uses `ConfusionStatsInput` (like Precision/Recall/
  FBetaScore), so it is implicitly adapted for both
  `ClassificationOutput` and `MultiLabelClassificationOutput`.
- Add `RankingMetricConfig` (threshold-free config) + `From<
  ClassificationMetricConfig>`; AUROC has no decision rule by design.
- Public constructors: `binary()`, `multiclass(ClassReduction)`,
  `multilabel(ClassReduction)`. `new` is now private.
- Degenerate classes (no positive/negative pair) yield NaN and are
  dropped from the Macro mean; all-degenerate -> 0.0 with a warning.
- Remove the now-dead `AurocInput` type and its `Adaptor` impl.
- Update the metric book accordingly.

BREAKING CHANGE:
- `AurocInput` removed. Use the built-in `ClassificationOutput` /
  `MultiLabelClassificationOutput` adaptors, or `ConfusionStatsInput`.
- `AurocMetric::new()` (no-arg) removed -> use `AurocMetric::binary()`.
- Displayed metric name changed: "AUROC" -> "AUROC [Macro|Micro]".
This commit is contained in:
manuel couto pintos
2026-05-20 16:08:35 +02:00
committed by GitHub
parent bfb0978107
commit fb83cb5148
3 changed files with 236 additions and 132 deletions

View File

@@ -51,7 +51,7 @@ adaptor code yourself.
- `MultiLabelClassificationOutput<B>`:
- Use case: Multi-label classification
- Fields: `loss: Tensor<B, 1>`, `output: Tensor<B, 2>`, `targets: Tensor<B, 2, Int>`
- Adapted metrics: HammingScore, Precision\*, Recall\*, FBetaScore\*, Loss
- Adapted metrics: HammingScore, Precision\*, Recall\*, FBetaScore\*, AUROC\*, Loss
- `RegressionOutput<B>`:
- Use case: Regression tasks
- Fields: `loss: Tensor<B, 1>`, `output: Tensor<B, 2>`, `targets: Tensor<B, 2>`
@@ -61,7 +61,7 @@ adaptor code yourself.
- Fields: `loss: Tensor<B, 1>`, `logits: Tensor<B, 3>`, `predictions: Option<Tensor<B, 2, Int>>`, `targets: Tensor<B, 2, Int>`
- Adapted metrics: Accuracy, TopKAccuracy, Perplexity, CER, WER, Loss
\* Precision, Recall, and FBetaScore all use `ConfusionStatsInput` as its input type so these three
\* Precision, Recall, FBetaScore, and AUROC all use `ConfusionStatsInput` as their input type so these
metrics are automatically (implicitly) adapted since `ConfusionStatsInput` is adapted.
If your metric isn't already adapted for the appropriate output struct, you can implement `Adaptor` yourself.

View File

@@ -1,6 +1,6 @@
use crate::metric::{
AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput,
PerplexityInput, TopKAccuracyInput, processor::ItemLazy,
AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, PerplexityInput,
TopKAccuracyInput, processor::ItemLazy,
};
use burn_core::tensor::{Device, Int, Tensor, Transaction};
use burn_flex::FlexDevice;
@@ -54,12 +54,6 @@ impl Adaptor<AccuracyInput> for ClassificationOutput {
}
}
impl Adaptor<AurocInput> for ClassificationOutput {
fn adapt(&self) -> AurocInput {
AurocInput::new(self.output.clone(), self.targets.clone())
}
}
impl Adaptor<LossInput> for ClassificationOutput {
fn adapt(&self) -> LossInput {
LossInput::new(self.loss.clone())

View File

@@ -2,100 +2,136 @@ use core::f64;
use super::MetricMetadata;
use super::state::{FormatOptions, NumericMetricState};
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
use burn_core::tensor::{Int, Tensor};
use crate::metric::{
ClassReduction, ConfusionStatsInput, Metric, MetricName, Numeric, SerializedEntry,
};
use burn_core::tensor::{Bool, Tensor};
use std::sync::Arc;
/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification.
/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also
/// referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)).
///
/// Supports binary, multiclass and multi-label classification through a
/// One-vs-Rest decomposition, aggregated with the configured
/// [class reduction](ClassReduction).
#[derive(Clone)]
pub struct AurocMetric {
name: MetricName,
state: NumericMetricState,
}
/// The [AUROC metric](AurocMetric) input type.
#[derive(new)]
pub struct AurocInput {
outputs: Tensor<2>,
targets: Tensor<1, Int>,
class_reduction: ClassReduction,
}
impl Default for AurocMetric {
fn default() -> Self {
Self::new()
Self::new(Default::default())
}
}
impl AurocMetric {
/// Creates the metric.
pub fn new() -> Self {
fn new(class_reduction: ClassReduction) -> Self {
let state = Default::default();
let name = Arc::new(format!("AUROC [{:?}]", class_reduction));
Self {
name: MetricName::new("AUROC".to_string()),
state: Default::default(),
state,
class_reduction,
name,
}
}
fn binary_auroc(&self, probabilities: &Tensor<1>, targets: &Tensor<1, Int>) -> f64 {
let n = targets.dims()[0];
/// AUROC metric for binary classification.
#[allow(dead_code)]
pub fn binary() -> Self {
Self::new(ClassReduction::default())
}
let n_pos = targets.clone().sum().into_scalar::<u64>() as usize;
/// AUROC metric for multiclass classification.
///
/// # Arguments
///
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multiclass(class_reduction: ClassReduction) -> Self {
Self::new(class_reduction)
}
// Early return if we don't have both positive and negative samples
if n_pos == 0 || n_pos == n {
if n_pos == 0 {
log::warn!("Metric cannot be computed because all target values are negative.")
} else {
log::warn!("Metric cannot be computed because all target values are positive.")
}
return 0.0;
}
/// AUROC metric for multi-label classification.
///
/// # Arguments
///
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multilabel(class_reduction: ClassReduction) -> Self {
Self::new(class_reduction)
}
let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]);
let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]);
fn pairwise_auc(scores: Tensor<2>, targets: Tensor<2>) -> Tensor<1> {
let [n, c] = scores.dims();
let valid_pairs = pos_mask * neg_mask;
let si = scores.clone().reshape([n, 1, c]);
let sj = scores.reshape([1, n, c]);
let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);
let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);
let yi = targets.clone().reshape([n, 1, c]);
let yj = targets.reshape([1, n, c]);
let correct_order = prob_i.clone().greater(prob_j.clone()).int();
let valid: Tensor<3> = yi * (1.0 - yj);
let ties = prob_i.equal(prob_j).int();
let reduce = |t: Tensor<3>| t.sum_dim(0).sum_dim(1).squeeze_dims::<1>(&[0, 1]);
// Calculate AUC components
let num_pairs = valid_pairs.clone().sum().into_scalar::<f64>();
let correct_pairs = (correct_order * valid_pairs.clone())
.sum()
.into_scalar::<f64>();
let tied_pairs = (ties * valid_pairs).sum().into_scalar::<f64>();
let num_pairs = reduce(valid.clone());
let correct_pairs = reduce(si.clone().greater(sj.clone()).float() * valid.clone());
let tied_pairs = reduce(si.equal(sj).float() * valid);
(correct_pairs + 0.5 * tied_pairs) / num_pairs
}
fn compute_auc(&self, predictions: &Tensor<2>, targets: &Tensor<2, Bool>) -> f64 {
let [n, c] = predictions.dims();
let (scores, targets) = match self.class_reduction {
ClassReduction::Macro => (predictions.clone(), targets.clone().float()),
ClassReduction::Micro => (
predictions.clone().reshape([n * c, 1]),
targets.clone().float().reshape([n * c, 1]),
),
};
let auc = Self::pairwise_auc(scores, targets);
let keep = auc
.clone()
.is_nan()
.bool_not()
.argwhere()
.squeeze_dim::<1>(1);
if keep.dims()[0] == 0 {
log::warn!(
"AUROC is undefined (no class has both positive and negative samples in the \
batch); reporting 0.5 (chance level)."
);
return 0.5;
}
auc.select(0, keep).mean().into_scalar()
}
}
impl Metric for AurocMetric {
type Input = AurocInput;
type Input = ConfusionStatsInput;
fn update(&mut self, input: &AurocInput, _metadata: &MetricMetadata) -> SerializedEntry {
let [batch_size, num_classes] = input.outputs.dims();
fn update(
&mut self,
input: &ConfusionStatsInput,
_metadata: &MetricMetadata,
) -> SerializedEntry {
let [sample_size, _] = input.predictions.dims();
assert_eq!(
num_classes, 2,
"Currently only binary classification is supported"
);
let probabilities = {
let exponents = input.outputs.clone().exp();
let sum = exponents.clone().sum_dim(1);
(exponents / sum)
.select(1, Tensor::arange(1..2, &input.outputs.device()))
.squeeze_dim(1)
};
let area_under_curve = self.binary_auroc(&probabilities, &input.targets);
let metric = self.compute_auc(&input.predictions, &input.targets);
self.state.update(
100.0 * area_under_curve,
batch_size,
100.0 * metric,
sample_size,
FormatOptions::new(self.name()).unit("%").precision(2),
)
}
@@ -122,59 +158,127 @@ impl Numeric for AurocMetric {
#[cfg(test)]
mod tests {
use super::*;
use crate::metric::ClassReduction::{self, *};
use burn_core::tensor::{TensorData, Tolerance};
use rstest::rstest;
#[test]
fn test_auroc() {
let device = Default::default();
let mut metric = AurocMetric::new();
/// Inputs and expected AUROC computed with an independent reference
/// equivalent to scikit-learn's `roc_auc_score` (Mann-Whitney U:
/// `(#pos>neg + 0.5·ties) / (P·N)`, One-vs-Rest, macro/micro). Scores
/// are distinct so the statistic is unambiguous and matches sklearn.
#[derive(Clone, Copy)]
enum Data {
Binary,
Multiclass,
Multilabel,
}
let input = AurocInput::new(
Tensor::from_data(
[
[0.1, 0.9], // High confidence positive
[0.7, 0.3], // Low confidence negative
[0.6, 0.4], // Low confidence negative
[0.2, 0.8], // High confidence positive
],
&device,
fn input(data: Data) -> ConfusionStatsInput {
let dev = Default::default();
match data {
Data::Binary => ConfusionStatsInput::new(
Tensor::from_data([[0.34], [0.64], [0.12], [0.19], [0.53], [0.38]], &dev),
Tensor::from_data([[0], [0], [0], [0], [1], [1]], &dev),
),
Tensor::from_data([1, 0, 0, 1], &device), // True labels
Data::Multiclass => ConfusionStatsInput::new(
Tensor::from_data(
[
[0.79, 0.41, 0.16],
[0.25, 0.93, 0.78],
[0.61, 0.09, 0.21],
[0.9, 0.31, 0.33],
[0.16, 0.82, 0.57],
[0.57, 0.18, 0.63],
],
&dev,
),
Tensor::from_data(
[
[1, 0, 0],
[1, 0, 0],
[1, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
],
&dev,
),
),
Data::Multilabel => ConfusionStatsInput::new(
Tensor::from_data(
[
[0.11, 0.57, 0.9],
[0.13, 0.66, 0.37],
[0.71, 0.85, 0.6],
[0.29, 0.69, 0.49],
[0.68, 0.45, 0.25],
[0.33, 0.36, 0.31],
],
&dev,
),
Tensor::from_data(
[
[1, 1, 1],
[0, 0, 1],
[0, 1, 0],
[1, 1, 0],
[0, 1, 1],
[1, 1, 1],
],
&dev,
),
),
}
}
#[rstest]
// Binary is a single column -> Macro == Micro.
#[case::binary_macro(Data::Binary, Macro, 0.75)]
#[case::binary_micro(Data::Binary, Micro, 0.75)]
#[case::multiclass_macro(Data::Multiclass, Macro, 0.5666666666666667)]
#[case::multiclass_micro(Data::Multiclass, Micro, 0.6458333333333333)]
#[case::multilabel_macro(Data::Multilabel, Macro, 0.2907407407407407)]
#[case::multilabel_micro(Data::Multilabel, Micro, 0.3611111111111111)]
fn test_auroc(
#[case] data: Data,
#[case] class_reduction: ClassReduction,
#[case] expected: f64,
) {
let mut metric = AurocMetric::new(class_reduction);
let _entry = metric.update(&input(data), &MetricMetadata::fake());
TensorData::from([metric.value().current()])
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default());
}
#[rstest]
#[case::macro_reduction(Macro)]
#[case::micro_reduction(Micro)]
fn test_auroc_perfect_separation(#[case] class_reduction: ClassReduction) {
let device = Default::default();
let mut metric = AurocMetric::new(class_reduction);
let input = ConfusionStatsInput::new(
Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
Tensor::from_data([[0, 1], [1, 0], [1, 0], [0, 1]], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 100.0);
}
#[test]
fn test_auroc_perfect_separation() {
#[rstest]
#[case::macro_reduction(Macro)]
#[case::micro_reduction(Micro)]
fn test_auroc_chance_level(#[case] class_reduction: ClassReduction) {
let device = Default::default();
let mut metric = AurocMetric::new();
let mut metric = AurocMetric::new(class_reduction);
let input = AurocInput::new(
Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
Tensor::from_data([1, 0, 0, 1], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 100.0); // Perfect AUC
}
#[test]
fn test_auroc_random() {
let device = Default::default();
let mut metric = AurocMetric::new();
let input = AurocInput::new(
Tensor::from_data(
[
[0.5, 0.5], // Random predictions
[0.5, 0.5],
[0.5, 0.5],
[0.5, 0.5],
],
&device,
),
Tensor::from_data([1, 0, 0, 1], &device),
// All scores tied -> every pair is a tie -> AUROC = 0.5.
let input = ConfusionStatsInput::new(
Tensor::from_data([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], &device),
Tensor::from_data([[0, 1], [1, 0], [1, 0], [0, 1]], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
@@ -182,44 +286,50 @@ mod tests {
}
#[test]
fn test_auroc_all_one_class() {
fn test_auroc_macro_drops_degenerate_class() {
let device = Default::default();
let mut metric = AurocMetric::new();
let mut metric = AurocMetric::new(Macro);
let input = AurocInput::new(
// Class 2 never appears (column all-negative) -> its AUROC is undefined
// and must be dropped, leaving the two well-separated classes at 1.0.
let input = ConfusionStatsInput::new(
Tensor::from_data(
[
[0.1, 0.9], // All positives predictions
[0.2, 0.8],
[0.3, 0.7],
[0.4, 0.6],
[0.9, 0.1, 0.0],
[0.2, 0.8, 0.0],
[0.7, 0.3, 0.0],
[0.1, 0.6, 0.0],
],
&device,
),
Tensor::from_data([1, 1, 1, 1], &device), // All positive class
Tensor::from_data([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 0.0);
assert_eq!(metric.value().current(), 100.0);
}
#[test]
#[should_panic(expected = "Currently only binary classification is supported")]
fn test_auroc_multiclass_error() {
fn test_auroc_all_degenerate_is_chance() {
let device = Default::default();
let mut metric = AurocMetric::new();
let mut metric = AurocMetric::binary();
let input = AurocInput::new(
Tensor::from_data(
[
[0.1, 0.2, 0.7], // More than 2 classes not supported
[0.3, 0.5, 0.2],
],
&device,
),
Tensor::from_data([2, 1], &device),
// Only positives -> no valid pair in any column -> undefined ->
// reported as chance level (0.5).
let input = ConfusionStatsInput::new(
Tensor::from_data([[0.9], [0.8], [0.7], [0.6]], &device),
Tensor::from_data([[1], [1], [1], [1]], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(metric.value().current(), 50.0);
}
#[test]
fn test_auroc_reduction_changes_name() {
let macro_metric = AurocMetric::new(Macro);
let micro_metric = AurocMetric::new(Micro);
assert_ne!(macro_metric.name(), micro_metric.name());
}
}