2859 Commits

Author SHA1 Message Date
Darwin Boersma
0194761460 fix: include RMS_NORM in normalization layer detection for safetensors adapter (#5023) 2026-05-29 17:10:26 -04:00
Guillaume Lagrange
bb0242c6ca refactor(tensor): seal TensorKind and Parameter traits and consolidate param impls (#5022) 2026-05-29 16:23:59 -04:00
Jean-Pierre De Jesus DIAZ
45267d0aee fix(train): use recv_timeout instead of try_recv. (#5021)
In the train-renderer thread try_recv is called in the loop without
a sleep condition causing the CPU to spin at full speed mostly
calling try_recv.

This changes it so that recv_timeout waits for the remainder of the
tick and then proceed to render if necessary.

This showed up as ~20% CPU time spent during profiling with perf.
2026-05-29 09:08:50 -04:00
John Kaczman
cc9664db43 perf(flex): bulk-copy inner-contiguous run in to_contiguous (#5019) 2026-05-29 08:24:00 -04:00
Guillaume Lagrange
1b6dd60361 fix(flex): max_pool3d_backward indices type (#5017)
* fix(flex): max_pool3d_backward indices type

* Correctly handle tests
2026-05-28 13:02:11 -04:00
Guillaume Lagrange
cbf016064e refactor(tensor)!: replace TensorKind::id() with const KIND and rename TensorKindId -> Kind (#5018) 2026-05-28 13:01:45 -04:00
Yichi Zhang
c19d2c3b25 fix(doc): inconsistent assertion for non-negative (#5016) 2026-05-28 11:12:47 -04:00
rfi-irfos
f0669572af feat: add Calibration::AbsMean for BitNet b1.58 ternary weight quantization (#4989)
* quantization: add Calibration::Ternary for BitNet b1.58 weight quantization

Extends the Calibration enum with a Ternary variant that maps
{-1, 0, +1} weights to a symmetric [-γ, +γ] range where γ is
mean(|W|) (BitNet b1.58 §3.1) or a caller-supplied threshold.

Intended for use with QuantValue::Q2S + QuantStore::PackedU32,
which packs 16 ternary weights per u32 — 16× smaller than f32,
4× smaller than int8.

  let layer = Linear::new(cfg, &device)
      .quantize(&QuantScheme::default()
          .with_value(QuantValue::Q2S)
          .with_store(QuantStore::PackedU32(0)),
          &Calibration::Ternary { threshold: None });

Adds three tests (per-tensor auto, per-tensor explicit, per-block auto)
to the existing calibration test module.

* quantization: rename Calibration::Ternary to AbsMean, drop threshold field

Per laggui's review: the calibration strategy is absolute-mean (mean(|W|)),
not "ternary" — ternary refers to the quantization values, not the range
finder. Remove the optional threshold override field since it doesn't fit
the calibration-method abstraction. Fix rustfmt failures on long lines.

* Apply suggestions from code review

* calibration tests: restore block level on abs_mean per_block test

The previous suggestion drop accidentally removed .with_level(QuantLevel::block([4]))
from abs_mean_calibration_range_per_block and left the let-binding missing its
semicolon, which broke the build. Restore the block level and terminate the
statement to match min_max_calibration_range_per_block.
2026-05-28 11:02:28 -04:00
David M.
6b721d887c refactor(train)!: transform all Progress items into a global progress struct (#5012)
* feat: add TrainingProgressLogger

* feat: add a debug implementation of TrainingProgressLogger

* feat: expose new optional progress_logger

* feat: rename trait function end_eval to end_epoch and update mnist example

* feat: add EvaluationProgressLogger trait

* fix: fix evaluation_logger override

* doc: add documentation for EvaluationProgressLogger

* feat: refactor training and evaluation logger to be more general and structured

* feat: put the new methods in full.rs pipeline and evaluation pipeline

* feat: put the new methods in full.rs pipeline and evaluation pipeline

* chore: removed boolean for FullEventTrainingProcessor

* chore: remove is_test_started from EvaluationEventPorcessor and add TestStart Event

* feat: add EndEpoch event and add TrainingProgressLogger in minimal.rs

* doc: updated doc

* doc: update documentation in pardigm.rs

* fix: address pr comments

* feat: update method signature from new traits

* fix: change output directory in mnist and add total_test to  EvalEven::start

* feat: add EndTest event and refactor evaluation pipeline to ugrade logging

* fix: reset default values in mnist example

* feat: add new event of TrainingProgressLogger to ddp algorithm

* fix: doc test and lint

* refactor: Change Training and Evaluation progress loggers parameters in trait methods

* refactor: Renderer now uses Progress logger traits for training

* refactor: add renderer calls in all events in full.rs

* refactor: replace all Progress structs for GlobalProgress

* refactor: implemnt OverallProgressLogger in burn-train

* refactor: implemnt OverallProgressLogger in burn-train

* refactor: remove global_progress of TrainingItem

* refactor: add progressEvent and counters in TrainingprogressLogger implementations

* refactor: add log_event to EvaluationProgressLogger

* refactor: add ProgressLogger and events to RL

* refactor: erased ProgressEvents and put strings instead

* fix: put mnist example back to normal

* fix: format

* fix: erased old test artifacts

* fix: erase useless character

* fix : erase duplicated line of code

* fix: correct documentation

* fix: change method call to be more logic

* fix: change label name for consistency

* fix: change method call

* fix: commented one line to remove rl from default features

* refactor: change stafulness from eventProcessor to Implementations

* refactor: change stafulness from eventProcessor to Implementations

* fix: formatting

* fix: address comments

* fix: fix CI
2026-05-28 10:15:48 -04:00
Liheng Yuan
0ba32109a3 added FloatKind::F64 arm for FuseType implementation (#5011) 2026-05-27 13:12:05 -04:00
Genna Wingert
08a12343fd refactor(cubecl): add lifetime to View (#4999)
* refactor: Migrate to changes to references

* Fix rebase

* Port to lifetime views

* Turn off compilation log

* Update to upstream changes

* Cleanup

* Update rev
2026-05-27 07:56:31 -04:00
Guillaume Lagrange
ed4d313b16 fix(dispatch): use the correct webgpu/vulkan/metal/wgpu device (#5010)
* Fix dispatch device for wgpu/webgpu/vulkan/metal and add other cubecl backends to vision extension

* Fix DeviceOps feature gate and vulkan f64
2026-05-26 14:21:30 -04:00
Charles23R
1b8c5321d4 refactor(rl): add inference device and update dqn example (#5009)
* update dqn-agent example

* update device stuff + couple bugs in dispatch

* update main and some utils

* deal with the whole to_device thing for off policy

* remove test code + fix warnings

* burn toml file
2026-05-26 10:07:26 -04:00
Guillaume Lagrange
c6ad760c2e fix(remote): add remote device settings fetched during client init (#5008)
* Add remote device settings fetched during client init

* Fix Device::wgpu doc

* Edit comment
2026-05-25 15:36:06 -04:00
Nathaniel Simard
cd8145ee15 feat(wgpu): use WgpuRuntime compiler generic to enable specialized aliases (#5001)
* WIP

* Add docs

* Fix feature flags

* Fix device type

* Fix compilation

* Update revs

* Fmt
2026-05-25 14:02:04 -04:00
Puneet Dixit
70660fedce fix(ndarray): broadcast remainder operands (#5002) 2026-05-25 09:41:52 -04:00
SamuelBelanger
412aa66542 chore(cubek): update to interpolate refactor (#5003)
* refactor interpolate kernel

* update cargo.toml
2026-05-25 09:34:45 -04:00
Guillaume Lagrange
e18a3c2e30 refactor(backends)!: remove associated element types & replace with device defaults (#5000)
* Remove multi backend router

* Remove associated element types; replaced by device defaults

* Fix wgpu devices

* Fix burn-store

* Fix cubecl backend extensions + device configs

* Missing fixes

* Fix clippy

* Fix wgpu backend feature gate in dispatch

* Fix wgpu/vulkan features for xtask tests

* Fix merge
2026-05-25 09:04:44 -04:00
Luca Cappelletti
34b6abd65d feat(train): add mouse support to TUI metric navigation (#4998)
* feat(train): add mouse support to TUI metric navigation

* feat(train): make TUI tab-strip chevrons clickable

* refactor(train): bundle TUI text hit-state and gate redraws

* fix(train): restore TUI key handling on Windows
2026-05-25 08:20:08 -04:00
Fatih Jawwad
bc4280dfbd fix(docs): improve error messages with shape/dimension context (#4996) 2026-05-25 08:08:53 -04:00
Nathaniel Simard
0034f64bf0 feat(dispatch): add remote backend (#4994) 2026-05-22 17:36:42 -04:00
Luca Cappelletti
746b934765 fix(train): scroll TUI metric tabs to keep selected visible (#4995) 2026-05-22 10:01:04 -04:00
Nathaniel Simard
9ae1cab154 refactor: use obfuscate from burn-std (#4997)
* Use obfuscate from burn-std

* migrate transaction
2026-05-22 09:34:34 -04:00
Louis Fortier-Dubois
d5e145f7fa Update CubeCL & CubeK (#4992) 2026-05-22 08:43:34 -04:00
David M.
6a5b0e64ce feat(train): add training and evaluation progress logger and add them to the event processors (#4980)
* feat: add TrainingProgressLogger

* feat: add a debug implementation of TrainingProgressLogger

* feat: expose new optional progress_logger

* feat: rename trait function end_eval to end_epoch and update mnist example

* feat: add EvaluationProgressLogger trait

* fix: fix evaluation_logger override

* doc: add documentation for EvaluationProgressLogger

* feat: refactor training and evaluation logger to be more general and structured

* feat: put the new methods in full.rs pipeline and evaluation pipeline

* feat: put the new methods in full.rs pipeline and evaluation pipeline

* chore: removed boolean for FullEventTrainingProcessor

* chore: remove is_test_started from EvaluationEventPorcessor and add TestStart Event

* feat: add EndEpoch event and add TrainingProgressLogger in minimal.rs

* doc: updated doc

* doc: update documentation in pardigm.rs

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: address pr comments

* feat: update method signature from new traits

* fix: change output directory in mnist and add total_test to  EvalEven::start

* feat: add EndTest event and refactor evaluation pipeline to ugrade logging

* fix: reset default values in mnist example

* feat: add new event of TrainingProgressLogger to ddp algorithm

* fix: doc test and lint
2026-05-21 15:29:57 -04:00
Nathaniel Simard
8037309704 perf(core)!: improve compile times via opaque inner types to break dependency chain (#4977)
* Initial commit

* Improve compilation time

* Refactor device

* Device blob

* Improvements

* Autodiff fix

* Burn-std no cubecl dep

* WIP

* Autodiff

* Sanitize outside of generic

* Fix display compilation

* Clippy

* Improve device selection

* WIP

* Fix examples wip

* Fix fmt

* Cleanup

* Fmt

* Fix clippy

* Fix autodiff

* Fix test device

* Fix docs

* Fix tests

* Fmt

* Fix display

* Miri fix

* Improve comments

* Fix device

* Fix docs
2026-05-21 14:08:30 -04:00
Fatih Jawwad
28478e81fc fix(doc): correct SgdConfig::init (#4986) 2026-05-21 09:57:14 -04:00
Sai Asish Y
2a86f615e4 fix(train): guard TUI metric navigation against empty state (#4987)
The TUI dashboard's Left/Right arrow handlers call next_metric and
previous_metric, which do (selected + 1) % data.len() and
data.len() - 1 unconditionally. Before the first metric is recorded,
data is empty, so the first arrow press panics with 'remainder with
a divisor of zero' (Right) or 'subtract overflow' (Left), killing
the TUI render thread.

Return early when data is empty in both methods, matching the empty
state that view() already handles by returning NumericMetricView::None.

Add a regression test that exercises both methods on the default
(empty) state.
2026-05-21 08:27:57 -04:00
Crutcher Dunnavant
bef58e2186 fix(features): expose safetensors/pytorch support directly (#4985) 2026-05-21 07:55:07 -04:00
SamuelBelanger
ca57d897ca feat(cube): interpolate nearest exact mode (#4982)
* add nearest exact interpolate mode

* update main hash

* Update cubecl dependencies to new revision

* update cubek

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
2026-05-20 17:17:47 -04:00
Guillaume Lagrange
611fc3e796 fix(fusion): resolve tensor into_ir (re-entrancy bug) (#4984)
* Fix fusion resolve tensor into_ir (re-entrancy bug)

* Add test
2026-05-20 17:00:00 -04:00
Guillaume Lagrange
58194b93aa fix(dispatch): Wgpu and WebGpu re-export feature gating (#4981) 2026-05-20 12:38:49 -04:00
Guillaume Lagrange
2f5dabe332 fix(deps): update enumset to 1.1.13 (#4979) 2026-05-20 11:03:31 -04:00
Genna Wingert
b9c6bdb3dc refactor(cubecl): update to reference changes (#4974)
* refactor: Migrate to changes to references

* Fix rebase

* Fixup

* Fix fusion
2026-05-20 11:02:39 -04:00
cofinite
a2261ae62a feat(ops): autodiff for rfft/irfft (#4956)
* autodiff for rfft/irfft

* removed `ignore` on tests as requested

* feature-gated tests, added test for dim nonzero

* fixed bug when n is Some + more tests

* cargo fmt + typo fix

* Fix tests cond
2026-05-20 10:14:04 -04:00
manuel couto pintos
fb83cb5148 feat(metric)!: add multiclass and multi-label support to AUROC metric (#4960)
AUROC was binary-only and panicked on >2 classes. It now supports
binary, multiclass and multi-label classification via a One-vs-Rest
decomposition aggregated with a Micro/Macro class reduction.

- AurocMetric now uses `ConfusionStatsInput` (like Precision/Recall/
  FBetaScore), so it is implicitly adapted for both
  `ClassificationOutput` and `MultiLabelClassificationOutput`.
- Add `RankingMetricConfig` (threshold-free config) + `From<
  ClassificationMetricConfig>`; AUROC has no decision rule by design.
- Public constructors: `binary()`, `multiclass(ClassReduction)`,
  `multilabel(ClassReduction)`. `new` is now private.
- Degenerate classes (no positive/negative pair) yield NaN and are
  dropped from the Macro mean; all-degenerate -> 0.0 with a warning.
- Remove the now-dead `AurocInput` type and its `Adaptor` impl.
- Update the metric book accordingly.

BREAKING CHANGE:
- `AurocInput` removed. Use the built-in `ClassificationOutput` /
  `MultiLabelClassificationOutput` adaptors, or `ConfusionStatsInput`.
- `AurocMetric::new()` (no-arg) removed -> use `AurocMetric::binary()`.
- Displayed metric name changed: "AUROC" -> "AUROC [Macro|Micro]".
2026-05-20 10:08:35 -04:00
Fatih Jawwad
bfb0978107 feat(metric): add ROUGE-L score metric (#4967)
* feat(metric): add ROUGE-L score metric

* chore: formatting import sequences

* Remove redundant total_f1 update in rouge.rs
2026-05-20 10:06:55 -04:00
Guillaume Lagrange
f6754f53a4 fix(nn): conv1d initializer fan out (#4973) 2026-05-19 14:47:56 -04:00
Nathaniel Simard
4c2aa989dd fix(fusion): shared tensor (#4962)
* Fix

* Fix

* Simplify

* Cleanup and add more tests

* Resolve cleanup

* Add docs
2026-05-19 08:58:34 -04:00
Redhawk
5aac831e5c burn 0.21 cite (#4969) 2026-05-19 08:38:16 -04:00
Guillaume Lagrange
73b72bc92a refactor(extension): feature gate Tensor::from/into_primitive and add from/into_bridge w/ TensorKindId validation (#4961)
* Dispatch default_backend should not be enabled for wgpu_only

* Add extension feature with from/into primitive

* Fix usage in test

* Fix doc
2026-05-15 16:29:55 -04:00
Truffle
275df7e748 Re-export BurnConfig with fusion/autodiff getters (#4959)
Closes #4932.

BurnConfig now ships in the burn umbrella alongside the runtime_config function,
and its inner FusionConfig/AutodiffConfig sub-trees are reached via fusion() and
autodiff() getters instead of public fields.
2026-05-15 14:32:44 -04:00
Guilhem Ané
9b65417e0b Bool tensor API improvements (#4955)
* Generic into tensor data for from_bool, like from_float

* Impl bool ops for bool tensors

* Remove backend generic

* Fixed doc snippets
2026-05-15 14:11:48 -04:00
Guillaume Lagrange
d30a82ae78 refactor(tensor): add BridgeTensor to bridge high-level tensor API with dispatch (#4958)
* Replace tensor primitive type

* Fix into_scalar calls

* Fix clippy

* Fix unused + bounds

* Fix extensions

* Rename to BridgeTensor

* Cargo fmt
2026-05-15 10:59:33 -04:00
June
2fef95179a feat(metric): add BLEU score training metric (#4937)
* Add BLEU score training metric (#544)

Implements sentence-level BLEU (Bilingual Evaluation Understudy) as a
training metric, following the existing CER/WER pattern.

- Configurable max n-gram order via `with_max_n()` (default: BLEU-4)
- Pad token stripping via `with_pad_token()`
- Modified n-gram precision with brevity penalty per Papineni et al.
- 9 tests covering perfect/zero/partial match, brevity penalty,
  padding, batching, bigrams, clear, and naming

* Address review: constructor API, corpus-style batch BLEU, smoothing

- Make `with_max_n` the primary constructor (matches TopKAccuracy pattern)
- Default name now `BLEU-4` instead of `BLEU`
- Accumulate n-gram counts across batch (corpus-style) instead of
  averaging per-sentence scores
- Return 0 for short references without smoothing (standard BLEU)
- Add `BleuSmoothing` enum: None, AddEpsilon, Exponential (Chen & Cherry 2014)
- Document epoch-level aggregation limitation with TODO

* address review: stateful k for exponential smoothing

Per laggui's comment, SacreBLEU's smoothing method 3 uses a stateful
multiplier that doubles for every n-gram order with zero matches, not
the n-gram order itself. Update Exponential smoothing to track a
running multiplier as in Chen & Cherry (2014) / SacreBLEU.

Refs: https://aclanthology.org/W14-3346.pdf
      https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/bleu.py#L282

* Migrate BLEU to backendless Tensor signature (#4717)

Drop the B: Backend generic from BleuScore and BleuInput; switch to
the new Tensor<2, Int> signature introduced by #4717. Mirrors the
shape now used by sister metrics CER/WER.

Tested with cargo test --lib bleu -p burn-train (13/13 pass).

* Cargo fmt
2026-05-15 09:59:44 -04:00
Guillaume Lagrange
27964e0160 Refactor tensor kind traits (#4957)
* Remove generics from tensor kind traits and Elem associated type

* Move backend extension, re-export no types from burn_backend and fix into_scalar missing types

* Remove into/from primitive usage in burn-optim

* Fix backtrace

* Fix float_sort_with_indices int dtype

* Fix feature gated

* Remove dead code

* Add tensor primitive note

* Another note

* High level kind

* Working but still pub traits

* Fix into_scalar doc example

* Fix argsort out dtype

* Restrict ops traits to pub(crate)

* Add other backends for burn-vision
2026-05-14 14:02:44 -04:00
Nathaniel Simard
0c763dcd26 Refactor/isolate burn backend deps (#4954)
* Wip Tensor module moving

* WIP

* WIP

* Fix

* Fix dependencies

* Cleanup

* Fixes

* Fix docs

* Remove registered reference

* Fix refs

* set_default_dtypes is in burn-tensor

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
2026-05-13 16:26:14 -04:00
Guillaume Lagrange
924d5d9ee9 Feature gate the backend ops extension when no backend feature is enabled (#4953) 2026-05-13 13:24:08 -04:00
Guillaume Lagrange
dbf03c516b [Breaking] Remove Tensor backend generic and add high-level Device struct (#4717)
* Remove backend generic from `Tensor` and add high-level `Device` struct

* Fix rfft

* WIP modules

* Cleanup flags

* Carry autodiff intent via checkpointing field

* Fix ctc loss test shape

* Fix imports

* Fix burn-collective & burn-communication

* Working core / nn / optim

* Removed from burn-rl

* Removed from burn-train

* Update lock

* Working mnist

* More examples

* Fix int/bool to_device with float AD device + add enumerate devices

* Working text-classification

* Fixed MultiGradientsParams usage + added gradient checkpointing to trainer

* Updated text-generation example

* Update comment

* Fix merge

* Remove backend generics from merge

* Update lock

* Remove from burn-store

* Fix burn-store dtype tests

* Remove circular dep to burn-tensor from burn-flex

* Cleanup

* Fix burn-store tests

* Fix seed

* Fix imports

* More examples

* Add `backend_extension` proc macro

* Working vision extension

* Missing merge changes

* More merge fixes

* tree

* Whoops

* Cleanup

* Not feature gated

* Fix wgpu kernel example

* Small change

* Update lock

* Cleanup

* Cargo fmt

* Clippy

* Remove burn-collective

* Cleanup

* Fix burn-store import

* Fallback to wgpu for workspace builds & blanket fusion impl

* Add `default_backend` fallback for burn-dispatch

* Gate unused var autodiff

* Fix no-std

* Refactor new tests

* Fix more tests

* Adjust ctc_loss backward f16 tolerance

* Fix test device for burn-core

* Fix test device

* Fix some docs

* Default to ndarray for now due to burn-flex CAS limitation

* Fix docs + ndarray default

* Add portable_atomic_util::Arc in burn-autodiff

* Add optim feature flag because it requires autodiff (fix no-std; not supported on all targets)

* Add Gradients high-level wrapper

* Fix types

* Fix test device

* Fix docs

* Fix burn-train auto merge

* Fix clippy

* WIP burn device override

* Fix dtype usage

* Remove test-metal

* Fix CI test flags

* Remove burn-train vision on metal

* Small note

* Missing metal feature for macos CI

* Fix docs

* Remove backend ops note

* Change set_default_dtypes to &mut self

* Fix mut device

* Fix burn-vision backend features

* Exclude burn-backend-tests from workspace tests (explicitly handled)
2026-05-13 10:03:12 -04:00
Guillaume Lagrange
3035c1bc27 Make RL event types mod public (#4951) 2026-05-12 12:33:34 -04:00