mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
feat(metric)!: add multiclass and multi-label support to AUROC metric (#4960)
AUROC was binary-only and panicked on >2 classes. It now supports binary, multiclass and multi-label classification via a One-vs-Rest decomposition aggregated with a Micro/Macro class reduction. - AurocMetric now uses `ConfusionStatsInput` (like Precision/Recall/ FBetaScore), so it is implicitly adapted for both `ClassificationOutput` and `MultiLabelClassificationOutput`. - Add `RankingMetricConfig` (threshold-free config) + `From< ClassificationMetricConfig>`; AUROC has no decision rule by design. - Public constructors: `binary()`, `multiclass(ClassReduction)`, `multilabel(ClassReduction)`. `new` is now private. - Degenerate classes (no positive/negative pair) yield NaN and are dropped from the Macro mean; all-degenerate -> 0.0 with a warning. - Remove the now-dead `AurocInput` type and its `Adaptor` impl. - Update the metric book accordingly. BREAKING CHANGE: - `AurocInput` removed. Use the built-in `ClassificationOutput` / `MultiLabelClassificationOutput` adaptors, or `ConfusionStatsInput`. - `AurocMetric::new()` (no-arg) removed -> use `AurocMetric::binary()`. - Displayed metric name changed: "AUROC" -> "AUROC [Macro|Micro]".
This commit is contained in:
committed by
GitHub
parent
bfb0978107
commit
fb83cb5148
@@ -51,7 +51,7 @@ adaptor code yourself.
|
||||
- `MultiLabelClassificationOutput<B>`:
|
||||
- Use case: Multi-label classification
|
||||
- Fields: `loss: Tensor<B, 1>`, `output: Tensor<B, 2>`, `targets: Tensor<B, 2, Int>`
|
||||
- Adapted metrics: HammingScore, Precision\*, Recall\*, FBetaScore\*, Loss
|
||||
- Adapted metrics: HammingScore, Precision\*, Recall\*, FBetaScore\*, AUROC\*, Loss
|
||||
- `RegressionOutput<B>`:
|
||||
- Use case: Regression tasks
|
||||
- Fields: `loss: Tensor<B, 1>`, `output: Tensor<B, 2>`, `targets: Tensor<B, 2>`
|
||||
@@ -61,7 +61,7 @@ adaptor code yourself.
|
||||
- Fields: `loss: Tensor<B, 1>`, `logits: Tensor<B, 3>`, `predictions: Option<Tensor<B, 2, Int>>`, `targets: Tensor<B, 2, Int>`
|
||||
- Adapted metrics: Accuracy, TopKAccuracy, Perplexity, CER, WER, Loss
|
||||
|
||||
\* Precision, Recall, and FBetaScore all use `ConfusionStatsInput` as its input type so these three
|
||||
\* Precision, Recall, FBetaScore, and AUROC all use `ConfusionStatsInput` as their input type so these
|
||||
metrics are automatically (implicitly) adapted since `ConfusionStatsInput` is adapted.
|
||||
|
||||
If your metric isn't already adapted for the appropriate output struct, you can implement `Adaptor` yourself.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::metric::{
|
||||
AccuracyInput, Adaptor, AurocInput, ConfusionStatsInput, HammingScoreInput, LossInput,
|
||||
PerplexityInput, TopKAccuracyInput, processor::ItemLazy,
|
||||
AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, PerplexityInput,
|
||||
TopKAccuracyInput, processor::ItemLazy,
|
||||
};
|
||||
use burn_core::tensor::{Device, Int, Tensor, Transaction};
|
||||
use burn_flex::FlexDevice;
|
||||
@@ -54,12 +54,6 @@ impl Adaptor<AccuracyInput> for ClassificationOutput {
|
||||
}
|
||||
}
|
||||
|
||||
impl Adaptor<AurocInput> for ClassificationOutput {
|
||||
fn adapt(&self) -> AurocInput {
|
||||
AurocInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Adaptor<LossInput> for ClassificationOutput {
|
||||
fn adapt(&self) -> LossInput {
|
||||
LossInput::new(self.loss.clone())
|
||||
|
||||
@@ -2,100 +2,136 @@ use core::f64;
|
||||
|
||||
use super::MetricMetadata;
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use crate::metric::{Metric, MetricName, Numeric, SerializedEntry};
|
||||
use burn_core::tensor::{Int, Tensor};
|
||||
use crate::metric::{
|
||||
ClassReduction, ConfusionStatsInput, Metric, MetricName, Numeric, SerializedEntry,
|
||||
};
|
||||
use burn_core::tensor::{Bool, Tensor};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)) for binary classification.
|
||||
/// The Area Under the Receiver Operating Characteristic Curve (AUROC, also
|
||||
/// referred to as [ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic)).
|
||||
///
|
||||
/// Supports binary, multiclass and multi-label classification through a
|
||||
/// One-vs-Rest decomposition, aggregated with the configured
|
||||
/// [class reduction](ClassReduction).
|
||||
#[derive(Clone)]
|
||||
pub struct AurocMetric {
|
||||
name: MetricName,
|
||||
state: NumericMetricState,
|
||||
}
|
||||
|
||||
/// The [AUROC metric](AurocMetric) input type.
|
||||
#[derive(new)]
|
||||
pub struct AurocInput {
|
||||
outputs: Tensor<2>,
|
||||
targets: Tensor<1, Int>,
|
||||
class_reduction: ClassReduction,
|
||||
}
|
||||
|
||||
impl Default for AurocMetric {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
Self::new(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl AurocMetric {
|
||||
/// Creates the metric.
|
||||
pub fn new() -> Self {
|
||||
fn new(class_reduction: ClassReduction) -> Self {
|
||||
let state = Default::default();
|
||||
let name = Arc::new(format!("AUROC [{:?}]", class_reduction));
|
||||
|
||||
Self {
|
||||
name: MetricName::new("AUROC".to_string()),
|
||||
state: Default::default(),
|
||||
state,
|
||||
class_reduction,
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
fn binary_auroc(&self, probabilities: &Tensor<1>, targets: &Tensor<1, Int>) -> f64 {
|
||||
let n = targets.dims()[0];
|
||||
/// AUROC metric for binary classification.
|
||||
#[allow(dead_code)]
|
||||
pub fn binary() -> Self {
|
||||
Self::new(ClassReduction::default())
|
||||
}
|
||||
|
||||
let n_pos = targets.clone().sum().into_scalar::<u64>() as usize;
|
||||
/// AUROC metric for multiclass classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multiclass(class_reduction: ClassReduction) -> Self {
|
||||
Self::new(class_reduction)
|
||||
}
|
||||
|
||||
// Early return if we don't have both positive and negative samples
|
||||
if n_pos == 0 || n_pos == n {
|
||||
if n_pos == 0 {
|
||||
log::warn!("Metric cannot be computed because all target values are negative.")
|
||||
} else {
|
||||
log::warn!("Metric cannot be computed because all target values are positive.")
|
||||
}
|
||||
return 0.0;
|
||||
}
|
||||
/// AUROC metric for multi-label classification.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `class_reduction` - [Class reduction](ClassReduction) type.
|
||||
#[allow(dead_code)]
|
||||
pub fn multilabel(class_reduction: ClassReduction) -> Self {
|
||||
Self::new(class_reduction)
|
||||
}
|
||||
|
||||
let pos_mask = targets.clone().equal_elem(1).int().reshape([n, 1]);
|
||||
let neg_mask = targets.clone().equal_elem(0).int().reshape([1, n]);
|
||||
fn pairwise_auc(scores: Tensor<2>, targets: Tensor<2>) -> Tensor<1> {
|
||||
let [n, c] = scores.dims();
|
||||
|
||||
let valid_pairs = pos_mask * neg_mask;
|
||||
let si = scores.clone().reshape([n, 1, c]);
|
||||
let sj = scores.reshape([1, n, c]);
|
||||
|
||||
let prob_i = probabilities.clone().reshape([n, 1]).repeat_dim(1, n);
|
||||
let prob_j = probabilities.clone().reshape([1, n]).repeat_dim(0, n);
|
||||
let yi = targets.clone().reshape([n, 1, c]);
|
||||
let yj = targets.reshape([1, n, c]);
|
||||
|
||||
let correct_order = prob_i.clone().greater(prob_j.clone()).int();
|
||||
let valid: Tensor<3> = yi * (1.0 - yj);
|
||||
|
||||
let ties = prob_i.equal(prob_j).int();
|
||||
let reduce = |t: Tensor<3>| t.sum_dim(0).sum_dim(1).squeeze_dims::<1>(&[0, 1]);
|
||||
|
||||
// Calculate AUC components
|
||||
let num_pairs = valid_pairs.clone().sum().into_scalar::<f64>();
|
||||
let correct_pairs = (correct_order * valid_pairs.clone())
|
||||
.sum()
|
||||
.into_scalar::<f64>();
|
||||
let tied_pairs = (ties * valid_pairs).sum().into_scalar::<f64>();
|
||||
let num_pairs = reduce(valid.clone());
|
||||
let correct_pairs = reduce(si.clone().greater(sj.clone()).float() * valid.clone());
|
||||
let tied_pairs = reduce(si.equal(sj).float() * valid);
|
||||
|
||||
(correct_pairs + 0.5 * tied_pairs) / num_pairs
|
||||
}
|
||||
|
||||
fn compute_auc(&self, predictions: &Tensor<2>, targets: &Tensor<2, Bool>) -> f64 {
|
||||
let [n, c] = predictions.dims();
|
||||
|
||||
let (scores, targets) = match self.class_reduction {
|
||||
ClassReduction::Macro => (predictions.clone(), targets.clone().float()),
|
||||
ClassReduction::Micro => (
|
||||
predictions.clone().reshape([n * c, 1]),
|
||||
targets.clone().float().reshape([n * c, 1]),
|
||||
),
|
||||
};
|
||||
|
||||
let auc = Self::pairwise_auc(scores, targets);
|
||||
|
||||
let keep = auc
|
||||
.clone()
|
||||
.is_nan()
|
||||
.bool_not()
|
||||
.argwhere()
|
||||
.squeeze_dim::<1>(1);
|
||||
|
||||
if keep.dims()[0] == 0 {
|
||||
log::warn!(
|
||||
"AUROC is undefined (no class has both positive and negative samples in the \
|
||||
batch); reporting 0.5 (chance level)."
|
||||
);
|
||||
return 0.5;
|
||||
}
|
||||
|
||||
auc.select(0, keep).mean().into_scalar()
|
||||
}
|
||||
}
|
||||
|
||||
impl Metric for AurocMetric {
|
||||
type Input = AurocInput;
|
||||
type Input = ConfusionStatsInput;
|
||||
|
||||
fn update(&mut self, input: &AurocInput, _metadata: &MetricMetadata) -> SerializedEntry {
|
||||
let [batch_size, num_classes] = input.outputs.dims();
|
||||
fn update(
|
||||
&mut self,
|
||||
input: &ConfusionStatsInput,
|
||||
_metadata: &MetricMetadata,
|
||||
) -> SerializedEntry {
|
||||
let [sample_size, _] = input.predictions.dims();
|
||||
|
||||
assert_eq!(
|
||||
num_classes, 2,
|
||||
"Currently only binary classification is supported"
|
||||
);
|
||||
|
||||
let probabilities = {
|
||||
let exponents = input.outputs.clone().exp();
|
||||
let sum = exponents.clone().sum_dim(1);
|
||||
(exponents / sum)
|
||||
.select(1, Tensor::arange(1..2, &input.outputs.device()))
|
||||
.squeeze_dim(1)
|
||||
};
|
||||
|
||||
let area_under_curve = self.binary_auroc(&probabilities, &input.targets);
|
||||
let metric = self.compute_auc(&input.predictions, &input.targets);
|
||||
|
||||
self.state.update(
|
||||
100.0 * area_under_curve,
|
||||
batch_size,
|
||||
100.0 * metric,
|
||||
sample_size,
|
||||
FormatOptions::new(self.name()).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
@@ -122,59 +158,127 @@ impl Numeric for AurocMetric {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::metric::ClassReduction::{self, *};
|
||||
use burn_core::tensor::{TensorData, Tolerance};
|
||||
use rstest::rstest;
|
||||
|
||||
#[test]
|
||||
fn test_auroc() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::new();
|
||||
/// Inputs and expected AUROC computed with an independent reference
|
||||
/// equivalent to scikit-learn's `roc_auc_score` (Mann-Whitney U:
|
||||
/// `(#pos>neg + 0.5·ties) / (P·N)`, One-vs-Rest, macro/micro). Scores
|
||||
/// are distinct so the statistic is unambiguous and matches sklearn.
|
||||
#[derive(Clone, Copy)]
|
||||
enum Data {
|
||||
Binary,
|
||||
Multiclass,
|
||||
Multilabel,
|
||||
}
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.1, 0.9], // High confidence positive
|
||||
[0.7, 0.3], // Low confidence negative
|
||||
[0.6, 0.4], // Low confidence negative
|
||||
[0.2, 0.8], // High confidence positive
|
||||
],
|
||||
&device,
|
||||
fn input(data: Data) -> ConfusionStatsInput {
|
||||
let dev = Default::default();
|
||||
match data {
|
||||
Data::Binary => ConfusionStatsInput::new(
|
||||
Tensor::from_data([[0.34], [0.64], [0.12], [0.19], [0.53], [0.38]], &dev),
|
||||
Tensor::from_data([[0], [0], [0], [0], [1], [1]], &dev),
|
||||
),
|
||||
Tensor::from_data([1, 0, 0, 1], &device), // True labels
|
||||
Data::Multiclass => ConfusionStatsInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.79, 0.41, 0.16],
|
||||
[0.25, 0.93, 0.78],
|
||||
[0.61, 0.09, 0.21],
|
||||
[0.9, 0.31, 0.33],
|
||||
[0.16, 0.82, 0.57],
|
||||
[0.57, 0.18, 0.63],
|
||||
],
|
||||
&dev,
|
||||
),
|
||||
Tensor::from_data(
|
||||
[
|
||||
[1, 0, 0],
|
||||
[1, 0, 0],
|
||||
[1, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[1, 0, 0],
|
||||
],
|
||||
&dev,
|
||||
),
|
||||
),
|
||||
Data::Multilabel => ConfusionStatsInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.11, 0.57, 0.9],
|
||||
[0.13, 0.66, 0.37],
|
||||
[0.71, 0.85, 0.6],
|
||||
[0.29, 0.69, 0.49],
|
||||
[0.68, 0.45, 0.25],
|
||||
[0.33, 0.36, 0.31],
|
||||
],
|
||||
&dev,
|
||||
),
|
||||
Tensor::from_data(
|
||||
[
|
||||
[1, 1, 1],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[1, 1, 0],
|
||||
[0, 1, 1],
|
||||
[1, 1, 1],
|
||||
],
|
||||
&dev,
|
||||
),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
// Binary is a single column -> Macro == Micro.
|
||||
#[case::binary_macro(Data::Binary, Macro, 0.75)]
|
||||
#[case::binary_micro(Data::Binary, Micro, 0.75)]
|
||||
#[case::multiclass_macro(Data::Multiclass, Macro, 0.5666666666666667)]
|
||||
#[case::multiclass_micro(Data::Multiclass, Micro, 0.6458333333333333)]
|
||||
#[case::multilabel_macro(Data::Multilabel, Macro, 0.2907407407407407)]
|
||||
#[case::multilabel_micro(Data::Multilabel, Micro, 0.3611111111111111)]
|
||||
fn test_auroc(
|
||||
#[case] data: Data,
|
||||
#[case] class_reduction: ClassReduction,
|
||||
#[case] expected: f64,
|
||||
) {
|
||||
let mut metric = AurocMetric::new(class_reduction);
|
||||
|
||||
let _entry = metric.update(&input(data), &MetricMetadata::fake());
|
||||
|
||||
TensorData::from([metric.value().current()])
|
||||
.assert_approx_eq::<f64>(&TensorData::from([expected * 100.0]), Tolerance::default());
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::macro_reduction(Macro)]
|
||||
#[case::micro_reduction(Micro)]
|
||||
fn test_auroc_perfect_separation(#[case] class_reduction: ClassReduction) {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::new(class_reduction);
|
||||
|
||||
let input = ConfusionStatsInput::new(
|
||||
Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
|
||||
Tensor::from_data([[0, 1], [1, 0], [1, 0], [0, 1]], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 100.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auroc_perfect_separation() {
|
||||
#[rstest]
|
||||
#[case::macro_reduction(Macro)]
|
||||
#[case::micro_reduction(Micro)]
|
||||
fn test_auroc_chance_level(#[case] class_reduction: ClassReduction) {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::new();
|
||||
let mut metric = AurocMetric::new(class_reduction);
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data([[0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]], &device),
|
||||
Tensor::from_data([1, 0, 0, 1], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 100.0); // Perfect AUC
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auroc_random() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::new();
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.5, 0.5], // Random predictions
|
||||
[0.5, 0.5],
|
||||
[0.5, 0.5],
|
||||
[0.5, 0.5],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([1, 0, 0, 1], &device),
|
||||
// All scores tied -> every pair is a tie -> AUROC = 0.5.
|
||||
let input = ConfusionStatsInput::new(
|
||||
Tensor::from_data([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], &device),
|
||||
Tensor::from_data([[0, 1], [1, 0], [1, 0], [0, 1]], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
@@ -182,44 +286,50 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auroc_all_one_class() {
|
||||
fn test_auroc_macro_drops_degenerate_class() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::new();
|
||||
let mut metric = AurocMetric::new(Macro);
|
||||
|
||||
let input = AurocInput::new(
|
||||
// Class 2 never appears (column all-negative) -> its AUROC is undefined
|
||||
// and must be dropped, leaving the two well-separated classes at 1.0.
|
||||
let input = ConfusionStatsInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.1, 0.9], // All positives predictions
|
||||
[0.2, 0.8],
|
||||
[0.3, 0.7],
|
||||
[0.4, 0.6],
|
||||
[0.9, 0.1, 0.0],
|
||||
[0.2, 0.8, 0.0],
|
||||
[0.7, 0.3, 0.0],
|
||||
[0.1, 0.6, 0.0],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([1, 1, 1, 1], &device), // All positive class
|
||||
Tensor::from_data([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 0.0);
|
||||
assert_eq!(metric.value().current(), 100.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Currently only binary classification is supported")]
|
||||
fn test_auroc_multiclass_error() {
|
||||
fn test_auroc_all_degenerate_is_chance() {
|
||||
let device = Default::default();
|
||||
let mut metric = AurocMetric::new();
|
||||
let mut metric = AurocMetric::binary();
|
||||
|
||||
let input = AurocInput::new(
|
||||
Tensor::from_data(
|
||||
[
|
||||
[0.1, 0.2, 0.7], // More than 2 classes not supported
|
||||
[0.3, 0.5, 0.2],
|
||||
],
|
||||
&device,
|
||||
),
|
||||
Tensor::from_data([2, 1], &device),
|
||||
// Only positives -> no valid pair in any column -> undefined ->
|
||||
// reported as chance level (0.5).
|
||||
let input = ConfusionStatsInput::new(
|
||||
Tensor::from_data([[0.9], [0.8], [0.7], [0.6]], &device),
|
||||
Tensor::from_data([[1], [1], [1], [1]], &device),
|
||||
);
|
||||
|
||||
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||
assert_eq!(metric.value().current(), 50.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auroc_reduction_changes_name() {
|
||||
let macro_metric = AurocMetric::new(Macro);
|
||||
let micro_metric = AurocMetric::new(Micro);
|
||||
|
||||
assert_ne!(macro_metric.name(), micro_metric.name());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user