Commit Graph

2692 Commits

Author SHA1 Message Date
nathaniel
d4799a9a11 Update revs 2026-03-31 13:29:13 -04:00
Dilshod Tadjibaev
fcc8dacdf4 Fix attention_fallback NaN for fully-masked rows (#4697)
* Fix attention_fallback NaN for fully-masked rows

Clamp softmax max to -1e4 and sum to 1e-6 so that rows where all
positions are -inf (from combined bool + causal masking) produce 0
instead of NaN.

Fixes #4694

* Extract softmax clamp constants and simplify tests

- Replace magic literals with named SOFTMAX_MAX_FLOOR / SOFTMAX_SUM_EPS
- Use TestTensorBool::full for all-true mask instead of ones().greater_elem()
- Remove redundant comment restating assertion messages

* Improve causal test to exercise partial bool + causal mask combination

The test now masks only key 0 via bool mask. Combined with causal
masking, row 0 is fully masked (can only attend to key 0, which is
masked) while rows 1-3 still have valid positions. This exercises
the actual reported failure mode from #4694.

* Fix formatting

* Fix f16 build: use num_traits for type-generic comparisons

Use num_traits::cast for f64-to-FloatElem conversion and
num_traits::Float for abs()/is_nan(), fixing compilation when
FloatElem is f16.
2026-03-31 11:36:37 -05:00
RunjiaChen
faaf89d381 Add HammingWindow operator to burn-tensor (#4698) 2026-03-31 12:03:28 -04:00
Louis Fortier-Dubois
d4d6cec48b update cubek and cubecl (#4699)
* update cubek and cubecl

* put back vecmat cmma
2026-03-31 10:17:37 -04:00
Nathaniel Simard
e5bdeef7fc Fix fusion consistency checks and binding estimation in burn-cubecl-fusion (#4695) 2026-03-31 08:28:49 -04:00
dependabot[bot]
0f133622ab Bump cc from 1.2.57 to 1.2.58 (#4689)
Bumps [cc](https://github.com/rust-lang/cc-rs) from 1.2.57 to 1.2.58.
- [Release notes](https://github.com/rust-lang/cc-rs/releases)
- [Changelog](https://github.com/rust-lang/cc-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/cc-rs/compare/cc-v1.2.57...cc-v1.2.58)

---
updated-dependencies:
- dependency-name: cc
  dependency-version: 1.2.58
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 08:01:32 -04:00
dependabot[bot]
9b0a5a02d6 Bump opentelemetry-otlp from 0.31.0 to 0.31.1 (#4690)
Bumps [opentelemetry-otlp](https://github.com/open-telemetry/opentelemetry-rust) from 0.31.0 to 0.31.1.
- [Release notes](https://github.com/open-telemetry/opentelemetry-rust/releases)
- [Changelog](https://github.com/open-telemetry/opentelemetry-rust/blob/main/docs/release_0.30.md)
- [Commits](https://github.com/open-telemetry/opentelemetry-rust/compare/v0.31.0...opentelemetry-otlp-0.31.1)

---
updated-dependencies:
- dependency-name: opentelemetry-otlp
  dependency-version: 0.31.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 07:38:21 -04:00
dependabot[bot]
0f8595ae0f Bump uuid from 1.22.0 to 1.23.0 (#4691)
Bumps [uuid](https://github.com/uuid-rs/uuid) from 1.22.0 to 1.23.0.
- [Release notes](https://github.com/uuid-rs/uuid/releases)
- [Commits](https://github.com/uuid-rs/uuid/compare/v1.22.0...v1.23.0)

---
updated-dependencies:
- dependency-name: uuid
  dependency-version: 1.23.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 07:33:22 -04:00
dependabot[bot]
0f813f955d Bump ctor from 0.6.3 to 0.8.0 (#4692)
Bumps [ctor](https://github.com/mmastrac/rust-ctor) from 0.6.3 to 0.8.0.
- [Commits](https://github.com/mmastrac/rust-ctor/commits)

---
updated-dependencies:
- dependency-name: ctor
  dependency-version: 0.8.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 07:32:01 -04:00
dependabot[bot]
09a8118530 Bump openblas-src from 0.10.14 to 0.10.15 (#4693)
Bumps [openblas-src](https://github.com/blas-lapack-rs/openblas-src) from 0.10.14 to 0.10.15.
- [Release notes](https://github.com/blas-lapack-rs/openblas-src/releases)
- [Changelog](https://github.com/blas-lapack-rs/openblas-src/blob/master/CHANGELOG.md)
- [Commits](https://github.com/blas-lapack-rs/openblas-src/compare/openblas-src-v0.10.14...openblas-src-v0.10.15)

---
updated-dependencies:
- dependency-name: openblas-src
  dependency-version: 0.10.15
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-30 07:31:21 -04:00
Louis Fortier-Dubois
ed72d2b125 Update cubek and fix vecmat autotune (#4682)
* add unit vecmat to vecmat autotune

* wip

* minor
2026-03-27 13:14:38 -04:00
dependabot[bot]
78910cc549 Bump codecov/codecov-action from 5 to 6 (#4681)
Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 5 to 6.
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](https://github.com/codecov/codecov-action/compare/v5...v6)

---
updated-dependencies:
- dependency-name: codecov/codecov-action
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-27 07:50:52 -04:00
Guillaume Lagrange
5466db71bf Ignore local tests with pre-trained weights (#4676)
* Ignore local tests with pre-trained weights

* Fix clippy
2026-03-26 12:17:57 -04:00
Guillaume Lagrange
1fd8cff2af Fix dispatch when only wgpu is enabled (maps to webgpu) (#4678)
* Fix dispatch when only wgpu is enabled (-> webgpu)

* Fix native bool check order
2026-03-26 12:16:54 -04:00
Louis Fortier-Dubois
712c06ee25 update cubek (#4677) 2026-03-26 11:58:00 -04:00
AdrianEddy
6cb763107c Fix fusion kernel vector_size mismatch on f16 output writes (#4675) 2026-03-26 11:57:31 -04:00
Louis Fortier-Dubois
fa9f6815d1 Include new vec2mat routine in matmul autotune (#4673) 2026-03-25 14:54:57 -04:00
Guillaume Lagrange
0c6ac4897d Update cubecl & cubek revs (#4672)
* Update cubecl & cubek revs

* Fix vecmat

* Ignore tests that require pre-trained weights
2026-03-25 12:50:14 -04:00
lif
68a42bb490 feat: add categorical sampling for tensors (#4655)
* feat: add multinomial (categorical distribution) sampling for 2D tensors

Implement `Tensor<B, 2>::multinomial(num_samples, replacement)` using an
inverse CDF approach (cumsum + uniform sampling) that works across all
backends without requiring new trait methods.

Closes #1121

Signed-off-by: majiayu000 <1835304752@qq.com>

* fix: clamp multinomial indices to prevent out-of-bounds from float imprecision

- Clamp sampled indices to [0, num_categories-1] to guard against
  floating-point imprecision in cumsum (matches PyTorch behavior)
- Document undefined behavior for all-zero weight rows
- Add statistical distribution test for non-degenerate probabilities

Signed-off-by: majiayu000 <1835304752@qq.com>

* fix: address review feedback on PR #4655

- Rename multinomial to categorical
- Generalize from Tensor<B, 2> to Tensor<B, D> (last dim = categories)
- Add num_samples == 0 validation
- Remove replacement parameter
- Add 1D and 3D tests

Signed-off-by: majiayu000 <1835304752@qq.com>

* style: fix cargo fmt ordering for categorical module

Signed-off-by: majiayu000 <1835304752@qq.com>

* Fix shape manipulation

* Fix tests
2026-03-25 10:30:23 -04:00
Genna Wingert
5b4b334632 chore: Update to upstream changes in cubecl (#4670)
* Update to register changes

* Fix clippy

* Fix candle

* Reformat
2026-03-25 08:13:21 -04:00
Guillaume Lagrange
a615c61451 Refactor backend tests to set device settings at initialization + use Dispatch (#4666)
* Fix autodiff device checkpointing

* Refactor tests and fix some out dtypes

* Cargo fmt

* Update xtask handle_backend_tests

* Remove debug

* Cleanup

* Fix remote dtype usage

* Feature gate f16 tests

* Fix default quant scheme

* Fix cuda enabled by default in workspace

* Cargo fmt

* Fix clippy

* Fix display

* Fix float elem metal

* Update notes

* Fix burn-store tests
2026-03-24 16:04:19 -04:00
Lee hong
543f0bd2ba Add HannWindow operator to burn-tensor (#4631)
* Add HannWindow operator to burn-tensor

1. crates/burn-tensor/src/tensor/api/float.rs
   - Add Tensor<B, 1>::hann_window(size, periodic, options) as an
     associated function under a new impl<B> Tensor<B, 1> block.
   - Implement the Hann window formula w(n) = 0.5 - 0.5*cos(2*pi*n/N)
     using existing ops (arange, float, cos, mul_scalar, add_scalar, cast).
   - Support periodic (N=size) and symmetric (N=size-1) modes.
   - Handle edge cases: size=0 returns empty tensor, size=1 returns [1.0].
   - Use f64 precision for angular increment calculation.
   - Include cfg_attr doc with LaTeX formula and a plain-text fallback.

2. crates/burn-backend-tests/tests/tensor/float/ops/hann_window.rs
   - Add 6 test cases covering periodic mode, symmetric mode, dtype
     options, empty tensor, and size=1 boundary for both modes.
   - Use assert_approx_eq with Tolerance for cross-backend float stability.

3. crates/burn-backend-tests/tests/tensor/float/ops/mod.rs
   - Register mod hann_window to include the new test module.

* Move to signal module + function
2026-03-24 15:16:19 -04:00
TsaoLun
f0987e783c fixup:(burn-ndarray) fix comment and tidy imports (#4668) 2026-03-24 08:18:02 -04:00
Guillaume Lagrange
699fb9e243 Bump deps (#4665)
* chore(deps): bump tokio-tungstenite from 0.28.0 to 0.29.0

Bumps [tokio-tungstenite](https://github.com/snapview/tokio-tungstenite) from 0.28.0 to 0.29.0.
- [Changelog](https://github.com/snapview/tokio-tungstenite/blob/master/CHANGELOG.md)
- [Commits](https://github.com/snapview/tokio-tungstenite/compare/v0.28.0...v0.29.0)

---
updated-dependencies:
- dependency-name: tokio-tungstenite
  dependency-version: 0.29.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* chore(deps): bump gix-tempfile from 21.0.1 to 21.0.2

Bumps [gix-tempfile](https://github.com/GitoxideLabs/gitoxide) from 21.0.1 to 21.0.2.
- [Release notes](https://github.com/GitoxideLabs/gitoxide/releases)
- [Changelog](https://github.com/GitoxideLabs/gitoxide/blob/main/CHANGELOG.md)
- [Commits](https://github.com/GitoxideLabs/gitoxide/compare/gix-tempfile-v21.0.1...gix-tempfile-v21.0.2)

---
updated-dependencies:
- dependency-name: gix-tempfile
  dependency-version: 21.0.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* chore(deps): bump zip from 8.2.0 to 8.3.1

Bumps [zip](https://github.com/zip-rs/zip2) from 8.2.0 to 8.3.1.
- [Release notes](https://github.com/zip-rs/zip2/releases)
- [Changelog](https://github.com/zip-rs/zip2/blob/master/CHANGELOG.md)
- [Commits](https://github.com/zip-rs/zip2/compare/v8.2.0...v8.3.1)

---
updated-dependencies:
- dependency-name: zip
  dependency-version: 8.3.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Remove debug print + dead code

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-23 09:57:23 -04:00
Guillaume Lagrange
c74a2998dd Fix tch int_zeros dtype in sync (#4664) 2026-03-23 08:59:28 -04:00
Guillaume Lagrange
1e5c220db2 [Breaking] Use device settings to provide output dtype (#4653)
* Use device settings to provide output dtype

* Fix bool scalar

* Use thread_local + switch to once lock w/ already initialized error

* Fix doc

* Fix typo

* Fix no-std

* Update docs

* Use resolved device settings instead

* Fix default bool

* Remove print
2026-03-23 08:49:17 -04:00
Guillaume Lagrange
8f1ac9a7ed Bump rustls-webpki to 0.103.10 and tar to 0.4.45 (#4663)
* Bump rustls-webpki to 0.103.10

* Bump tar to 0.4.45
2026-03-23 08:29:42 -04:00
cong-or
7c3be3be02 feat: add FID vision metric (#4644)
* feat: add FID vision metric (#4312)

  Frechet Inception Distance for evaluating generative image quality.
  InceptionV3 feature extractor with pretrained pytorch-fid weights.

* â—Ź fix: deform_conv2d backward input gradient

  Flatten 2D columns tensor to 1D before into_linear_view() to fix incorrect indexing introduced by View launch refactor (#4639).

* refactor: address FID metric review feedback

  Remove GPU syncs, use linalg::trace, drop unnecessary .expand(),
  unify eps constants.
2026-03-20 14:21:59 -04:00
Sepcnt
8f1c2e0fd1 Add Adan optimizer implementation with tests (#4651) 2026-03-20 11:38:06 -04:00
Guillaume Lagrange
bcaabad860 [Breaking] Add bool store dtype + remove bool elem from fusion (#4649)
* Add bool store dtype + remove bool elem from fusion

* Fix bool display test

* Fix burn-store

* Fix candle

* Fix condition
2026-03-19 12:14:21 -04:00
Louis Fortier-Dubois
c52bcba6e1 Selector/attention (#4648)
* wip

* Flash attention selection

* update rev

* update rev

* clippy
2026-03-18 13:51:00 -04:00
TsaoLun
00d9c38df4 fix(burn-ndarray): use owned storage for native heap allocations in from_data (#4647) 2026-03-18 13:38:46 -04:00
Charles23R
f424a080cd add utilities fn to FusionServer (#4640)
* add utilities fn to FusionServer

* cubecl version

* git hash

* lockfile

* sync versions

* commit hash
2026-03-18 12:57:08 -04:00
Guillaume Lagrange
763980e94a Remove int powf and make powi numeric op (#4646)
* Remove int powf and make powi numeric op

* Fix fmt
2026-03-18 09:51:04 -04:00
Guillaume Lagrange
6ef54b2a6b Bump fake to 5.1 + update cube Cargo.lock rev (#4643)
* Bump fake to 5.1

* Update Cargo.lock
2026-03-18 09:09:08 -04:00
Genna Wingert
b962e3f409 refactor: View launch (#4639) 2026-03-18 09:03:44 -04:00
Genna Wingert
832f73b718 chore: Update to cubecl changes (#4630)
* Remove R::supported_line_sizes

* refactor: Metadata optimization

* Revert temp fix

* Rename `ShapeError` to `MetadataError`

* Cleanup

* Bump cubecl and cubek rev

* Fix doc test

* Refactor `CubeOption`

* Migrated fusion

* Migrated cubecl fusion

* WIP

* WIP

* It compiles

* WIP

* Update to changes

* WIP

* WIP

* Burn cubecl

* WIP

* WIP

* WIP

* WIP

* FIX

* Update rev

* Clippy

* Update rev

* Update revs

* Once againt, update the rev to fix no-std

* Fix test compilation issue

* Migrate to runtime option refactor

* Update rev

* Debug + give names to threads

* Update stuff

* Remove prints

* Add 64-bit test for manually enabling and running

* Update to new cubecl rev

* Add rev

* Set rev

* WIP

* Update to line generic refactor

* Rename `Line` to `Vector`

* Rename `line` to `vector`

* FIx fusion

* Disable compilation log

* Update main

* Bump cubecl rev

* Use marker instead of usize for `ElemExpand`

* Rename types for clarity

* Cleanup

* Update cubecl rev

* Fix clippy
2026-03-17 11:14:08 -04:00
Guillaume Lagrange
b37673aeec Dispatch autodiff checkpointing strategy support (#4629)
* Support autodiff checkpointing strategy

* Debug msg

* Update docs

* Fix bool tensor default dtype
2026-03-16 10:45:10 -04:00
cong-or
65c79c198b Implement RNNT loss (#4623)
* feat(nn): add RNNT loss function

* Clean up RNNT loss tests: deduplicate assert_close, rename verbose helper

* Deduplicate assert_close, rename verbose helper, fix clippy too-many-args

* refactor(nn): address RNNT loss review feedback

  - Rename fused_log_softmax to logits
  - Hoist logit_lengths reshape/expand outside forward loop
  - Replace assert_close with assert_approx_eq

* fix(nn): apply cargo fmt to RNNT loss module
2026-03-16 07:50:44 -04:00
dependabot[bot]
ac06e3c2e4 Bump portable-atomic-util from 0.2.5 to 0.2.6 (#4634)
Bumps [portable-atomic-util](https://github.com/taiki-e/portable-atomic-util) from 0.2.5 to 0.2.6.
- [Release notes](https://github.com/taiki-e/portable-atomic-util/releases)
- [Changelog](https://github.com/taiki-e/portable-atomic-util/blob/main/CHANGELOG.md)
- [Commits](https://github.com/taiki-e/portable-atomic-util/compare/v0.2.5...v0.2.6)

---
updated-dependencies:
- dependency-name: portable-atomic-util
  dependency-version: 0.2.6
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-16 07:50:13 -04:00
github-actions[bot]
5e29b3e2fd Combined PRs (#4637)
* Bump tracel-xtask from 4.13.4 to 4.13.5

Bumps [tracel-xtask](https://github.com/tracel-ai/xtask) from 4.13.4 to 4.13.5.
- [Release notes](https://github.com/tracel-ai/xtask/releases)
- [Commits](https://github.com/tracel-ai/xtask/compare/v4.13.4...v4.13.5)

---
updated-dependencies:
- dependency-name: tracel-xtask
  dependency-version: 4.13.5
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump fake from 4.4.0 to 5.0.0

Bumps [fake](https://github.com/cksac/fake-rs) from 4.4.0 to 5.0.0.
- [Release notes](https://github.com/cksac/fake-rs/releases)
- [Commits](https://github.com/cksac/fake-rs/compare/v4.4.0...v5.0.0)

---
updated-dependencies:
- dependency-name: fake
  dependency-version: 5.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump tracing-subscriber from 0.3.22 to 0.3.23

Bumps [tracing-subscriber](https://github.com/tokio-rs/tracing) from 0.3.22 to 0.3.23.
- [Release notes](https://github.com/tokio-rs/tracing/releases)
- [Commits](https://github.com/tokio-rs/tracing/compare/tracing-subscriber-0.3.22...tracing-subscriber-0.3.23)

---
updated-dependencies:
- dependency-name: tracing-subscriber
  dependency-version: 0.3.23
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

* Bump clap from 4.5.60 to 4.6.0

Bumps [clap](https://github.com/clap-rs/clap) from 4.5.60 to 4.6.0.
- [Release notes](https://github.com/clap-rs/clap/releases)
- [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md)
- [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.60...clap_complete-v4.6.0)

---
updated-dependencies:
- dependency-name: clap
  dependency-version: 4.6.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2026-03-16 07:44:24 -04:00
Guillaume Lagrange
4e962f9484 Remove named tensor (#4628) 2026-03-13 15:38:00 -04:00
Nathaniel Simard
3e9a17baf6 Perf: Improve fusion score (#4511)
* Some cleanup

* Fix tests

* Cleanup

* Clippy

* Fix test

* Remove concurrency issue
2026-03-13 15:05:18 -04:00
Genna Wingert
fce1a57211 refactor: Vector size generic (#4624) 2026-03-13 08:36:34 -04:00
Softmaximalist
b7a08e1a95 Fix function arg name inconsistencies (#4626) 2026-03-13 08:05:27 -04:00
Softmaximalist
d6a8fba1f2 Update building-blocks chapter (#4625) 2026-03-12 08:25:22 -04:00
Nathaniel Simard
3c6b710dca Refactor/device handle (#4593) 2026-03-11 18:46:49 -04:00
Dmitry Patsura
3cd6e90e41 feat: Introduce Lanczos3 interpolation method (#4601)
* feat: Introduce Lanczos3 interpolation method

* chore

* chore: fmt/clippy

* chore: analog in pytorch

* chore: improve burn book?

* chore: update comment

* chore: test_1d_lanczos3

* chore: test test_upsample_2x

* feat: add weight normalization to lanczos3 interpolation and test_upsample_2x test

Normalize kernel weights in both ndarray and cubecl lanczos3 implementations
to match standard behavior (TF/JAX/PIL). Add test_upsample_2x for 4x4->8x8
with align_corners=true.

* fix: skip OOB positions in lanczos3 instead of clamping to match TF/JAX/PIL

When sampling near edges, clamping out-of-bounds positions to the edge pixel
double-counted that pixel. Skipping OOB positions and renormalizing over
in-bounds weights matches TF/JAX/PIL behavior.
2026-03-11 12:59:55 -04:00
Softmaximalist
c013482047 Add Gram Matrix Loss for vision tasks (#4595)
* Implement Gram Matrix Loss

* Update computation of normalization factor

* Update forward method to sum layer losses all at once

* Update comments and expect() messages for weights downloading

* Format code
2026-03-11 12:42:15 -04:00
Guillaume Lagrange
176fcaaef9 Fix fusion cumulative op inputs (#4621)
* Fix cumulative op inputs

* Fix audit
2026-03-10 16:51:57 -04:00