feat(cube): interpolate nearest exact mode (#4982)

* add nearest exact interpolate mode

* update main hash

* Update cubecl dependencies to new revision

* update cubek

---------

Co-authored-by: louisfd <louisfd94@gmail.com>
This commit is contained in:
SamuelBelanger
2026-05-20 17:17:47 -04:00
committed by GitHub
parent 611fc3e796
commit ca57d897ca
12 changed files with 102 additions and 83 deletions

92
Cargo.lock generated
View File

@@ -2117,7 +2117,7 @@ dependencies = [
[[package]]
name = "cubecl"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"cubecl-core",
"cubecl-cpu",
@@ -2133,7 +2133,7 @@ dependencies = [
[[package]]
name = "cubecl-common"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"backtrace",
"bytemuck",
@@ -2174,7 +2174,7 @@ dependencies = [
[[package]]
name = "cubecl-core"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"bitflags 2.11.1",
"bytemuck",
@@ -2203,7 +2203,7 @@ dependencies = [
[[package]]
name = "cubecl-cpp"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2219,7 +2219,7 @@ dependencies = [
[[package]]
name = "cubecl-cpu"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2240,7 +2240,7 @@ dependencies = [
[[package]]
name = "cubecl-cuda"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2258,7 +2258,7 @@ dependencies = [
[[package]]
name = "cubecl-hip"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"bytemuck",
"cubecl-common",
@@ -2287,7 +2287,7 @@ dependencies = [
[[package]]
name = "cubecl-ir"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"bumpalo",
"cubecl-common",
@@ -2311,7 +2311,7 @@ dependencies = [
[[package]]
name = "cubecl-macros"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"cubecl-common",
"darling 0.23.0",
@@ -2328,7 +2328,7 @@ dependencies = [
[[package]]
name = "cubecl-macros-internal"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"darling 0.23.0",
"proc-macro2",
@@ -2339,7 +2339,7 @@ dependencies = [
[[package]]
name = "cubecl-opt"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"cubecl-common",
"cubecl-core",
@@ -2357,7 +2357,7 @@ dependencies = [
[[package]]
name = "cubecl-runtime"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"ahash",
"async-channel",
@@ -2388,7 +2388,7 @@ dependencies = [
[[package]]
name = "cubecl-spirv"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"bitflags 2.11.1",
"cubecl-common",
@@ -2404,7 +2404,7 @@ dependencies = [
[[package]]
name = "cubecl-std"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"cubecl-common",
"cubecl-core",
@@ -2420,7 +2420,7 @@ dependencies = [
[[package]]
name = "cubecl-wgpu"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"ash",
"async-channel",
@@ -2439,7 +2439,6 @@ dependencies = [
"hashbrown 0.17.1",
"itertools 0.14.0",
"log",
"renderdoc",
"sanitize-filename",
"tracel-ash",
"tracing",
@@ -2450,7 +2449,7 @@ dependencies = [
[[package]]
name = "cubecl-zspace"
version = "0.11.0-pre.1"
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
dependencies = [
"derive-new",
"serde",
@@ -2460,7 +2459,7 @@ dependencies = [
[[package]]
name = "cubek"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
"cubek-attention",
@@ -2478,7 +2477,7 @@ dependencies = [
[[package]]
name = "cubek-attention"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"bytemuck",
"cubecl",
@@ -2492,7 +2491,7 @@ dependencies = [
[[package]]
name = "cubek-convolution"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"bytemuck",
"cubecl",
@@ -2508,7 +2507,7 @@ dependencies = [
[[package]]
name = "cubek-fft"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
]
@@ -2516,7 +2515,7 @@ dependencies = [
[[package]]
name = "cubek-interpolate"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2526,7 +2525,7 @@ dependencies = [
[[package]]
name = "cubek-matmul"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"bytemuck",
"cubecl",
@@ -2539,7 +2538,7 @@ dependencies = [
[[package]]
name = "cubek-pool"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2549,7 +2548,7 @@ dependencies = [
[[package]]
name = "cubek-quant"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2560,7 +2559,7 @@ dependencies = [
[[package]]
name = "cubek-random"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
"cubecl-common",
@@ -2573,7 +2572,7 @@ dependencies = [
[[package]]
name = "cubek-reduce"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
"cubek-std",
@@ -2586,7 +2585,7 @@ dependencies = [
[[package]]
name = "cubek-std"
version = "0.3.0-pre.1"
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
dependencies = [
"cubecl",
"cubecl-common",
@@ -3507,15 +3506,6 @@ dependencies = [
"zlib-rs",
]
[[package]]
name = "float-cmp"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4"
dependencies = [
"num-traits",
]
[[package]]
name = "float-cmp"
version = "0.10.0"
@@ -6993,7 +6983,7 @@ version = "0.22.0-pre.1"
dependencies = [
"burn",
"burn-store",
"float-cmp 0.10.0",
"float-cmp",
"serde",
]
@@ -7517,21 +7507,6 @@ version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
[[package]]
name = "renderdoc"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac633a08f39bf3268714799bd8c4a1a19c19c203d817d3448bb8b91c97817cf0"
dependencies = [
"bitflags 2.11.1",
"float-cmp 0.9.0",
"libloading 0.8.9",
"once_cell",
"renderdoc-sys",
"winapi",
"wio",
]
[[package]]
name = "renderdoc-sys"
version = "1.1.0"
@@ -7820,7 +7795,7 @@ dependencies = [
"burn-autodiff",
"burn-flex",
"burn-store",
"float-cmp 0.10.0",
"float-cmp",
"serde",
]
@@ -10417,15 +10392,6 @@ version = "0.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904"
[[package]]
name = "wio"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d129932f4644ac2396cb456385cbf9e63b5b30c6e8dc4820bdca4eb082037a5"
dependencies = [
"winapi",
]
[[package]]
name = "wit-bindgen"
version = "0.51.0"

View File

@@ -218,10 +218,10 @@ burn-vision = { path = "crates/burn-vision", version = "0.22.0-pre.1", default-f
burn-wgpu = { path = "crates/burn-wgpu", version = "0.22.0-pre.1", default-features = false }
### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a8ddb49dc928d0ea6cfc86101013082e24a5ff18" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a8ddb49dc928d0ea6cfc86101013082e24a5ff18" }
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a8ddb49dc928d0ea6cfc86101013082e24a5ff18" }
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "629cec7f784a211e15290863ed71f68d4b59c25c" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0751cf6a5005c1755c1cd44d1c9383e7b1de814c" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0751cf6a5005c1755c1cd44d1c9383e7b1de814c" }
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0751cf6a5005c1755c1cd44d1c9383e7b1de814c" }
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "0d259b53ff086b727dc1cdf001d064585bbc78bc" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

View File

@@ -294,6 +294,9 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
.tensor
.upsample_nearest2d(output_size[0], output_size[1])
.unwrap(),
InterpolateMode::NearestExact => {
panic!("nearest exact interpolation is not implemented in Candle backend")
}
InterpolateMode::Bilinear => {
panic!("bilinear interpolation is not supported by Candle")
}

View File

@@ -34,8 +34,8 @@ use cubek::{
BlueprintStrategy, Routine,
double_buffering::{CyclicDoubleBufferingAlgorithm, DoubleBufferingArgs},
double_unit::DoubleUnitAlgorithm,
gemm::GemmRoutine,
gemv_innerproduct::{DoubleVecMatInnerProductAlgorithm, VecMatInnerProductAlgorithm},
gemv_plane_parallel::GemvPlaneParallelRoutine,
gemv_unit_perpendicular::GemvUnitPerpendicularRoutine,
ordered_double_buffering::{OrderedDoubleBufferingAlgorithm, OrderedSelectionArgs},
simple::{SimpleAlgorithm, SimpleArgs},
@@ -225,7 +225,7 @@ pub enum FusedMatmulSelector {
},
SimpleVecMat,
DoubleVecMat,
GemvPlaneParallel,
GemmNoStage,
GemvUnitPerpendicular,
SimpleUnit,
DoubleUnit,
@@ -254,7 +254,7 @@ impl FusedMatmulSelector {
}
FusedMatmulSelector::SimpleVecMat => "simple_vec_mat".into(),
FusedMatmulSelector::DoubleVecMat => "double_buffering_vec_mat".into(),
FusedMatmulSelector::GemvPlaneParallel => "gemv_plane_parallel".into(),
FusedMatmulSelector::GemmNoStage => "gemm".into(),
FusedMatmulSelector::GemvUnitPerpendicular => "gemv_unit_perpendicular".into(),
FusedMatmulSelector::SimpleUnit => "simple_unit".into(),
FusedMatmulSelector::DoubleUnit => "double_buffering_unit".into(),
@@ -639,8 +639,8 @@ impl FusedMatmulLaunch<'_> {
}
}
FusedMatmulSelector::GemvPlaneParallel => {
match launch_inner_fix_dtype::<R, GemvPlaneParallelRoutine>(
FusedMatmulSelector::GemmNoStage => {
match launch_inner_fix_dtype::<R, GemmRoutine>(
client,
FusedMatmulInputLaunch::new(
inputs,

View File

@@ -139,7 +139,7 @@ pub fn fused_matmul_autotune<R: Runtime>(
for (selector, double_buf) in [
(FusedMatmulSelector::SimpleVecMat, false),
(FusedMatmulSelector::DoubleVecMat, true),
(FusedMatmulSelector::GemvPlaneParallel, false),
(FusedMatmulSelector::GemmNoStage, false),
(FusedMatmulSelector::GemvUnitPerpendicular, false),
] {
set = set.with(

View File

@@ -7,7 +7,8 @@ use crate::{
use burn_backend::{Shape, TensorMetadata, ops::InterpolateMode, ops::InterpolateOptions};
use cubek::interpolate::{
definition::InterpolateMode as CubekInterpolateMode,
definition::InterpolateOptions as CubekInterpolateOptions, interpolate as cubek_interpolate,
definition::InterpolateOptions as CubekInterpolateOptions,
definition::NearestMode as CubekNearestMode, interpolate as cubek_interpolate,
interpolate_backward as cubek_interpolate_backward,
};
@@ -91,7 +92,10 @@ fn map_options(options: InterpolateOptions) -> CubekInterpolateOptions {
CubekInterpolateOptions {
mode: {
match options.mode {
InterpolateMode::Nearest => CubekInterpolateMode::Nearest,
InterpolateMode::Nearest => CubekInterpolateMode::Nearest(CubekNearestMode::Floor),
InterpolateMode::NearestExact => {
CubekInterpolateMode::Nearest(CubekNearestMode::Exact)
}
InterpolateMode::Bilinear => CubekInterpolateMode::Bilinear,
InterpolateMode::Bicubic => CubekInterpolateMode::Bicubic,
InterpolateMode::Lanczos3 => CubekInterpolateMode::Lanczos3,

View File

@@ -14,8 +14,9 @@ use cubek::matmul::{
launch::{MatmulAutotuneKey, MatmulGlobalScale, Strategy, should_tune_double_buffering},
routines::{
BlueprintStrategy, TileSizeSelection, double_buffering::DoubleBufferingArgs,
double_unit::DoubleUnitSelectionArgs, ordered_double_buffering::OrderedSelectionArgs,
simple::SimpleArgs, simple_unit::SimpleUnitSelectionArgs,
double_unit::DoubleUnitSelectionArgs, gemm::GemmStrategy,
ordered_double_buffering::OrderedSelectionArgs, simple::SimpleArgs,
simple_unit::SimpleUnitSelectionArgs,
},
};
@@ -187,7 +188,7 @@ pub fn matmul_autotune<R: CubeRuntime>(
false,
),
(
Strategy::GemvPlaneParallel(BlueprintStrategy::Inferred(Default::default())),
Strategy::Gemm(BlueprintStrategy::Inferred(Default::default())),
false,
),
(
@@ -239,6 +240,22 @@ pub fn matmul_autotune<R: CubeRuntime>(
}
}
// Gemm (no stage).
// Registered in the `unit` group because it is not an accelerated routine.
let gemm_no_stage_strategy = Strategy::Gemm(BlueprintStrategy::Inferred(GemmStrategy {
target_num_planes: None,
}));
set = set.with(
Tunable::new(
&gemm_no_stage_strategy.to_string(),
move |(lhs, rhs, out)| {
launch_matmul::<R>(&gemm_no_stage_strategy, lhs, rhs, out)
.map_err(|err| format!("{err:?}"))
},
)
.group(&unit, move |_key| PRIORITY_MAX),
);
// Accelerated matmuls
for (strategy, double_buf, group_extra, tile_group) in [
(

View File

@@ -1669,6 +1669,7 @@ pub struct MaxPool2dWithIndicesBackwardOpIr {
#[allow(missing_docs)]
pub enum InterpolateModeIr {
Nearest,
NearestExact,
Bilinear,
Bicubic,
Lanczos3,
@@ -1834,6 +1835,7 @@ impl From<InterpolateModeIr> for InterpolateMode {
fn from(val: InterpolateModeIr) -> Self {
match val {
InterpolateModeIr::Nearest => Self::Nearest,
InterpolateModeIr::NearestExact => Self::NearestExact,
InterpolateModeIr::Bilinear => Self::Bilinear,
InterpolateModeIr::Bicubic => Self::Bicubic,
InterpolateModeIr::Lanczos3 => Self::Lanczos3,
@@ -1851,6 +1853,7 @@ impl From<InterpolateMode> for InterpolateModeIr {
/// Converts a backend `InterpolateMode` into the matching `InterpolateModeIr` variant.
fn from(val: InterpolateMode) -> Self {
match val {
InterpolateMode::Nearest => Self::Nearest,
// Must stay distinct from `Nearest`: collapsing it here would silently turn
// nearest-exact requests into floor-nearest once they pass through the IR.
InterpolateMode::NearestExact => Self::NearestExact,
InterpolateMode::Bilinear => Self::Bilinear,
InterpolateMode::Bicubic => Self::Bicubic,
InterpolateMode::Lanczos3 => Self::Lanczos3,

View File

@@ -293,6 +293,9 @@ where
)
.into())
}
InterpolateMode::NearestExact => {
panic!("nearest exact interpolation is not supported for ndarray backend")
}
InterpolateMode::Bilinear => {
let align_corners = options.align_corners;
module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
@@ -333,6 +336,9 @@ where
InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
nearest_interpolate_backward::<E>(x, grad, output_size).into()
}),
InterpolateMode::NearestExact => {
panic!("nearest exact interpolation backward is not supported for ndarray backend")
}
InterpolateMode::Bilinear => {
panic!("bilinear interpolation backward is not supported for ndarray backend")
}

View File

@@ -14,11 +14,17 @@ use burn::tensor::ops::InterpolateMode as OpsInterpolateMode;
/// This enum defines different interpolation modes for resampling data.
#[derive(Config, Debug)]
pub enum InterpolateMode {
/// Nearest-neighbor interpolation
/// Nearest-neighbor floor interpolation
///
/// For each output pixel, this mode selects a sample point using floor mapping.
/// It is applicable for both temporal and spatial data.
Nearest,
/// Nearest-neighbor exact interpolation
///
/// This mode selects the value of the nearest sample point for each output pixel.
/// It is applicable for both temporal and spatial data.
Nearest,
NearestExact,
/// Linear interpolation
///
@@ -49,6 +55,7 @@ impl From<InterpolateMode> for OpsInterpolateMode {
fn from(mode: InterpolateMode) -> Self {
match mode {
InterpolateMode::Nearest => OpsInterpolateMode::Nearest,
InterpolateMode::NearestExact => OpsInterpolateMode::NearestExact,
InterpolateMode::Linear => OpsInterpolateMode::Bilinear,
InterpolateMode::Cubic => OpsInterpolateMode::Bicubic,
InterpolateMode::Lanczos => OpsInterpolateMode::Lanczos3,

View File

@@ -203,13 +203,18 @@ impl UnfoldOptions {
}
}
/// Algorithm used for upsampling.
/// Interpolation algorithm to use.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
/// Nearest-neighbor interpolation.
/// Nearest-neighbor floor interpolation.
/// Matches the legacy behavior of OpenCV's `INTER_NEAREST`. It results in a bottom-right shift when resizing.
/// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
Nearest,
/// Nearest-neighbor exact interpolation.
/// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
NearestExact,
/// Bilinear interpolation.
/// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
Bilinear,
@@ -226,7 +231,7 @@ pub enum InterpolateMode {
/// Interpolation options.
#[derive(Debug, Clone)]
pub struct InterpolateOptions {
/// Algorithm used for upsampling.
/// Interpolation algorithm to use.
pub mode: InterpolateMode,
/// If `true`, the input and output tensors are aligned by their corner pixels.
/// If `false`, half-pixel coordinate mapping is used instead.

View File

@@ -397,6 +397,9 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
InterpolateMode::Nearest => {
tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
}
InterpolateMode::NearestExact => {
panic!("nearest exact interpolation is not supported by PyTorch/tch backend")
}
InterpolateMode::Bilinear => {
tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None)
}
@@ -430,6 +433,11 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
None,
None,
),
InterpolateMode::NearestExact => {
panic!(
"nearest exact interpolation backward is not supported by PyTorch/tch backend"
)
}
InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
&grad.tensor,
output_size,