From fb83cb51480d7af413ff183c93df8e5afa39d9da Mon Sep 17 00:00:00 2001 From: manuel couto pintos Date: Wed, 20 May 2026 16:08:35 +0200 Subject: [PATCH] feat(metric)!: add multiclass and multi-label support to AUROC metric (#4960) AUROC was binary-only and panicked on >2 classes. It now supports binary, multiclass and multi-label classification via a One-vs-Rest decomposition aggregated with a Micro/Macro class reduction. - AurocMetric now uses `ConfusionStatsInput` (like Precision/Recall/ FBetaScore), so it is implicitly adapted for both `ClassificationOutput` and `MultiLabelClassificationOutput`. - Add `RankingMetricConfig` (threshold-free config) + `From< ClassificationMetricConfig>`; AUROC has no decision rule by design. - Public constructors: `binary()`, `multiclass(ClassReduction)`, `multilabel(ClassReduction)`. `new` is now private. - Degenerate classes (no positive/negative pair) yield NaN and are dropped from the Macro mean; all-degenerate -> 0.5 (chance level) with a warning. - Remove the now-dead `AurocInput` type and its `Adaptor` impl. - Update the metric book accordingly. BREAKING CHANGE: - `AurocInput` removed. Use the built-in `ClassificationOutput` / `MultiLabelClassificationOutput` adaptors, or `ConfusionStatsInput`. - `AurocMetric::new()` (no-arg) removed -> use `AurocMetric::binary()`. - Displayed metric name changed: "AUROC" -> "AUROC [Macro|Micro]". --- burn-book/src/building-blocks/metric.md | 4 +- .../burn-train/src/learner/classification.rs | 10 +- crates/burn-train/src/metric/auroc.rs | 354 ++++++++++++------ 3 files changed, 236 insertions(+), 132 deletions(-) diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md index 0f2757780..2796a48ae 100644 --- a/burn-book/src/building-blocks/metric.md +++ b/burn-book/src/building-blocks/metric.md @@ -51,7 +51,7 @@ adaptor code yourself. 
- `MultiLabelClassificationOutput`: - Use case: Multi-label classification - Fields: `loss: Tensor`, `output: Tensor`, `targets: Tensor` - - Adapted metrics: HammingScore, Precision\*, Recall\*, FBetaScore\*, Loss + - Adapted metrics: HammingScore, Precision\*, Recall\*, FBetaScore\*, AUROC\*, Loss - `RegressionOutput`: - Use case: Regression tasks - Fields: `loss: Tensor`, `output: Tensor`, `targets: Tensor` @@ -61,7 +61,7 @@ adaptor code yourself. - Fields: `loss: Tensor`, `logits: Tensor`, `predictions: Option>`, `targets: Tensor` - Adapted metrics: Accuracy, TopKAccuracy, Perplexity, CER, WER, Loss -\* Precision, Recall, and FBetaScore all use `ConfusionStatsInput` as its input type so these three +\* Precision, Recall, FBetaScore, and AUROC all use `ConfusionStatsInput` as their input type so these metrics are automatically (implicitly) adapted since `ConfusionStatsInput` is adapted. If your metric isn't already adapted for the appropriate output struct, you can implement `Adaptor` yourself. 
diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs index bf64d1952..7a4a7be1c 100644 --- a/crates/burn-train/src/learner/classification.rs +++ b/crates/burn-train/src/learner/classification.rs @@ -1,6 +1,6 @@ use crate::metric::{ - AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput, - PerplexityInput, TopKAccuracyInput, processor::ItemLazy, + AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, PerplexityInput, + TopKAccuracyInput, processor::ItemLazy, }; use burn_core::tensor::{Device, Int, Tensor, Transaction}; use burn_flex::FlexDevice; @@ -54,12 +54,6 @@ impl Adaptor for ClassificationOutput { } } -impl Adaptor for ClassificationOutput { - fn adapt(&self) -> AurocInput { - AurocInput::new(self.output.clone(), self.targets.clone()) - } -} - impl Adaptor for ClassificationOutput { fn adapt(&self) -> LossInput { LossInput::new(self.loss.clone()) diff --git a/crates/burn-train/src/metric/auroc.rs b/crates/burn-train/src/metric/auroc.rs index ed6ee13f8..ad1f9b0d6 100644 --- a/crates/burn-train/src/metric/auroc.rs +++ b/crates/burn-train/src/metric/auroc.rs @@ -2,100 +2,136 @@ use core::f64; use super::MetricMetadata; use super::state::{FormatOptions, NumericMetricState}; -use crate::metric::{Metric, MetricName, Numeric, SerializedEntry}; -use burn_core::tensor::{Int, Tensor}; +use crate::metric::{ + ClassReduction, ConfusionStatsInput, Metric, MetricName, Numeric, SerializedEntry, +}; +use burn_core::tensor::{Bool, Tensor}; +use std::sync::Arc; -/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification. +/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also +/// referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)). 
+/// +/// Supports binary, multiclass and multi-label classification through a +/// One-vs-Rest decomposition, aggregated with the configured +/// [class reduction](ClassReduction). #[derive(Clone)] pub struct AurocMetric { name: MetricName, state: NumericMetricState, -} - -/// The [AUROC metric](AurocMetric) input type. -#[derive(new)] -pub struct AurocInput { - outputs: Tensor<2>, - targets: Tensor<1, Int>, + class_reduction: ClassReduction, } impl Default for AurocMetric { fn default() -> Self { - Self::new() + Self::new(Default::default()) } } impl AurocMetric { - /// Creates the metric. - pub fn new() -> Self { + fn new(class_reduction: ClassReduction) -> Self { + let state = Default::default(); + let name = Arc::new(format!("AUROC [{:?}]", class_reduction)); + Self { - name: MetricName::new("AUROC".to_string()), - state: Default::default(), + state, + class_reduction, + name, } } - fn binary_auroc(&self, probabilities: &Tensor<1>, targets: &Tensor<1, Int>) -> f64 { - let n = targets.dims()[0]; + /// AUROC metric for binary classification. + #[allow(dead_code)] + pub fn binary() -> Self { + Self::new(ClassReduction::default()) + } - let n_pos = targets.clone().sum().into_scalar::() as usize; + /// AUROC metric for multiclass classification. + /// + /// # Arguments + /// + /// * `class_reduction` - [Class reduction](ClassReduction) type. + #[allow(dead_code)] + pub fn multiclass(class_reduction: ClassReduction) -> Self { + Self::new(class_reduction) + } - // Early return if we don't have both positive and negative samples - if n_pos == 0 || n_pos == n { - if n_pos == 0 { - log::warn!("Metric cannot be computed because all target values are negative.") - } else { - log::warn!("Metric cannot be computed because all target values are positive.") - } - return 0.0; - } + /// AUROC metric for multi-label classification. + /// + /// # Arguments + /// + /// * `class_reduction` - [Class reduction](ClassReduction) type. 
+ #[allow(dead_code)] + pub fn multilabel(class_reduction: ClassReduction) -> Self { + Self::new(class_reduction) + } - let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]); - let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]); + fn pairwise_auc(scores: Tensor<2>, targets: Tensor<2>) -> Tensor<1> { + let [n, c] = scores.dims(); - let valid_pairs = pos_mask * neg_mask; + let si = scores.clone().reshape([n, 1, c]); + let sj = scores.reshape([1, n, c]); - let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n); - let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n); + let yi = targets.clone().reshape([n, 1, c]); + let yj = targets.reshape([1, n, c]); - let correct_order = prob_i.clone().greater(prob_j.clone()).int(); + let valid: Tensor<3> = yi * (1.0 - yj); - let ties = prob_i.equal(prob_j).int(); + let reduce = |t: Tensor<3>| t.sum_dim(0).sum_dim(1).squeeze_dims::<1>(&[0, 1]); - // Calculate AUC components - let num_pairs = valid_pairs.clone().sum().into_scalar::(); - let correct_pairs = (correct_order * valid_pairs.clone()) - .sum() - .into_scalar::(); - let tied_pairs = (ties * valid_pairs).sum().into_scalar::(); + let num_pairs = reduce(valid.clone()); + let correct_pairs = reduce(si.clone().greater(sj.clone()).float() * valid.clone()); + let tied_pairs = reduce(si.equal(sj).float() * valid); (correct_pairs + 0.5 * tied_pairs) / num_pairs } + + fn compute_auc(&self, predictions: &Tensor<2>, targets: &Tensor<2, Bool>) -> f64 { + let [n, c] = predictions.dims(); + + let (scores, targets) = match self.class_reduction { + ClassReduction::Macro => (predictions.clone(), targets.clone().float()), + ClassReduction::Micro => ( + predictions.clone().reshape([n * c, 1]), + targets.clone().float().reshape([n * c, 1]), + ), + }; + + let auc = Self::pairwise_auc(scores, targets); + + let keep = auc + .clone() + .is_nan() + .bool_not() + .argwhere() + .squeeze_dim::<1>(1); + + if keep.dims()[0] == 0 { + log::warn!( + 
"AUROC is undefined (no class has both positive and negative samples in the \ + batch); reporting 0.5 (chance level)." + ); + return 0.5; + } + + auc.select(0, keep).mean().into_scalar() + } } impl Metric for AurocMetric { - type Input = AurocInput; + type Input = ConfusionStatsInput; - fn update(&mut self, input: &AurocInput, _metadata: &MetricMetadata) -> SerializedEntry { - let [batch_size, num_classes] = input.outputs.dims(); + fn update( + &mut self, + input: &ConfusionStatsInput, + _metadata: &MetricMetadata, + ) -> SerializedEntry { + let [sample_size, _] = input.predictions.dims(); - assert_eq!( - num_classes, 2, - "Currently only binary classification is supported" - ); - - let probabilities = { - let exponents = input.outputs.clone().exp(); - let sum = exponents.clone().sum_dim(1); - (exponents / sum) - .select(1, Tensor::arange(1..2, &input.outputs.device())) - .squeeze_dim(1) - }; - - let area_under_curve = self.binary_auroc(&probabilities, &input.targets); + let metric = self.compute_auc(&input.predictions, &input.targets); self.state.update( - 100.0 * area_under_curve, - batch_size, + 100.0 * metric, + sample_size, FormatOptions::new(self.name()).unit("%").precision(2), ) } @@ -122,59 +158,127 @@ impl Numeric for AurocMetric { #[cfg(test)] mod tests { use super::*; + use crate::metric::ClassReduction::{self, *}; + use burn_core::tensor::{TensorData, Tolerance}; + use rstest::rstest; - #[test] - fn test_auroc() { - let device = Default::default(); - let mut metric = AurocMetric::new(); + /// Inputs and expected AUROC computed with an independent reference + /// equivalent to scikit-learn's `roc_auc_score` (Mann-Whitney U: + /// `(#pos>neg + 0.5·ties) / (P·N)`, One-vs-Rest, macro/micro). Scores + /// are distinct so the statistic is unambiguous and matches sklearn. 
+ #[derive(Clone, Copy)] + enum Data { + Binary, + Multiclass, + Multilabel, + } - let input = AurocInput::new( - Tensor::from_data( - [ - [0.1, 0.9], // High confidence positive - [0.7, 0.3], // Low confidence negative - [0.6, 0.4], // Low confidence negative - [0.2, 0.8], // High confidence positive - ], - &device, + fn input(data: Data) -> ConfusionStatsInput { + let dev = Default::default(); + match data { + Data::Binary => ConfusionStatsInput::new( + Tensor::from_data([[0.34], [0.64], [0.12], [0.19], [0.53], [0.38]], &dev), + Tensor::from_data([[0], [0], [0], [0], [1], [1]], &dev), ), - Tensor::from_data([1, 0, 0, 1], &device), // True labels + Data::Multiclass => ConfusionStatsInput::new( + Tensor::from_data( + [ + [0.79, 0.41, 0.16], + [0.25, 0.93, 0.78], + [0.61, 0.09, 0.21], + [0.9, 0.31, 0.33], + [0.16, 0.82, 0.57], + [0.57, 0.18, 0.63], + ], + &dev, + ), + Tensor::from_data( + [ + [1, 0, 0], + [1, 0, 0], + [1, 0, 0], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + ], + &dev, + ), + ), + Data::Multilabel => ConfusionStatsInput::new( + Tensor::from_data( + [ + [0.11, 0.57, 0.9], + [0.13, 0.66, 0.37], + [0.71, 0.85, 0.6], + [0.29, 0.69, 0.49], + [0.68, 0.45, 0.25], + [0.33, 0.36, 0.31], + ], + &dev, + ), + Tensor::from_data( + [ + [1, 1, 1], + [0, 0, 1], + [0, 1, 0], + [1, 1, 0], + [0, 1, 1], + [1, 1, 1], + ], + &dev, + ), + ), + } + } + + #[rstest] + // Binary is a single column -> Macro == Micro. 
+ #[case::binary_macro(Data::Binary, Macro, 0.75)] + #[case::binary_micro(Data::Binary, Micro, 0.75)] + #[case::multiclass_macro(Data::Multiclass, Macro, 0.5666666666666667)] + #[case::multiclass_micro(Data::Multiclass, Micro, 0.6458333333333333)] + #[case::multilabel_macro(Data::Multilabel, Macro, 0.2907407407407407)] + #[case::multilabel_micro(Data::Multilabel, Micro, 0.3611111111111111)] + fn test_auroc( + #[case] data: Data, + #[case] class_reduction: ClassReduction, + #[case] expected: f64, + ) { + let mut metric = AurocMetric::new(class_reduction); + + let _entry = metric.update(&input(data), &MetricMetadata::fake()); + + TensorData::from([metric.value().current()]) + .assert_approx_eq::(&TensorData::from([expected * 100.0]), Tolerance::default()); + } + + #[rstest] + #[case::macro_reduction(Macro)] + #[case::micro_reduction(Micro)] + fn test_auroc_perfect_separation(#[case] class_reduction: ClassReduction) { + let device = Default::default(); + let mut metric = AurocMetric::new(class_reduction); + + let input = ConfusionStatsInput::new( + Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device), + Tensor::from_data([[0, 1], [1, 0], [1, 0], [0, 1]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); assert_eq!(metric.value().current(), 100.0); } - #[test] - fn test_auroc_perfect_separation() { + #[rstest] + #[case::macro_reduction(Macro)] + #[case::micro_reduction(Micro)] + fn test_auroc_chance_level(#[case] class_reduction: ClassReduction) { let device = Default::default(); - let mut metric = AurocMetric::new(); + let mut metric = AurocMetric::new(class_reduction); - let input = AurocInput::new( - Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device), - Tensor::from_data([1, 0, 0, 1], &device), - ); - - let _entry = metric.update(&input, &MetricMetadata::fake()); - assert_eq!(metric.value().current(), 100.0); // Perfect AUC - } - - #[test] - fn test_auroc_random() { - let device = 
Default::default(); - let mut metric = AurocMetric::new(); - - let input = AurocInput::new( - Tensor::from_data( - [ - [0.5, 0.5], // Random predictions - [0.5, 0.5], - [0.5, 0.5], - [0.5, 0.5], - ], - &device, - ), - Tensor::from_data([1, 0, 0, 1], &device), + // All scores tied -> every pair is a tie -> AUROC = 0.5. + let input = ConfusionStatsInput::new( + Tensor::from_data([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], &device), + Tensor::from_data([[0, 1], [1, 0], [1, 0], [0, 1]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); @@ -182,44 +286,50 @@ mod tests { } #[test] - fn test_auroc_all_one_class() { + fn test_auroc_macro_drops_degenerate_class() { let device = Default::default(); - let mut metric = AurocMetric::new(); + let mut metric = AurocMetric::new(Macro); - let input = AurocInput::new( + // Class 2 never appears (column all-negative) -> its AUROC is undefined + // and must be dropped, leaving the two well-separated classes at 1.0. + let input = ConfusionStatsInput::new( Tensor::from_data( [ - [0.1, 0.9], // All positives predictions - [0.2, 0.8], - [0.3, 0.7], - [0.4, 0.6], + [0.9, 0.1, 0.0], + [0.2, 0.8, 0.0], + [0.7, 0.3, 0.0], + [0.1, 0.6, 0.0], ], &device, ), - Tensor::from_data([1, 1, 1, 1], &device), // All positive class + Tensor::from_data([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); - assert_eq!(metric.value().current(), 0.0); + assert_eq!(metric.value().current(), 100.0); } #[test] - #[should_panic(expected = "Currently only binary classification is supported")] - fn test_auroc_multiclass_error() { + fn test_auroc_all_degenerate_is_chance() { let device = Default::default(); - let mut metric = AurocMetric::new(); + let mut metric = AurocMetric::binary(); - let input = AurocInput::new( - Tensor::from_data( - [ - [0.1, 0.2, 0.7], // More than 2 classes not supported - [0.3, 0.5, 0.2], - ], - &device, - ), - Tensor::from_data([2, 1], 
&device), + // Only positives -> no valid pair in any column -> undefined -> + // reported as chance level (0.5). + let input = ConfusionStatsInput::new( + Tensor::from_data([[0.9], [0.8], [0.7], [0.6]], &device), + Tensor::from_data([[1], [1], [1], [1]], &device), ); let _entry = metric.update(&input, &MetricMetadata::fake()); + assert_eq!(metric.value().current(), 50.0); + } + + #[test] + fn test_auroc_reduction_changes_name() { + let macro_metric = AurocMetric::new(Macro); + let micro_metric = AurocMetric::new(Micro); + + assert_ne!(macro_metric.name(), micro_metric.name()); } }