mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
feat(cube): interpolate nearest exact mode (#4982)
* add nearest exact interpolate mode * update main hash * Update cubecl dependencies to new revision * update cubek --------- Co-authored-by: louisfd <louisfd94@gmail.com>
This commit is contained in:
92
Cargo.lock
generated
92
Cargo.lock
generated
@@ -2117,7 +2117,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"cubecl-core",
|
||||
"cubecl-cpu",
|
||||
@@ -2133,7 +2133,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-common"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"bytemuck",
|
||||
@@ -2174,7 +2174,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-core"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"bitflags 2.11.1",
|
||||
"bytemuck",
|
||||
@@ -2203,7 +2203,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-cpp"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2219,7 +2219,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-cpu"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2240,7 +2240,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-cuda"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2258,7 +2258,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-hip"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2287,7 +2287,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-ir"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"cubecl-common",
|
||||
@@ -2311,7 +2311,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-macros"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"darling 0.23.0",
|
||||
@@ -2328,7 +2328,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-macros-internal"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"darling 0.23.0",
|
||||
"proc-macro2",
|
||||
@@ -2339,7 +2339,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-opt"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
@@ -2357,7 +2357,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-runtime"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"async-channel",
|
||||
@@ -2388,7 +2388,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-spirv"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"bitflags 2.11.1",
|
||||
"cubecl-common",
|
||||
@@ -2404,7 +2404,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-std"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
@@ -2420,7 +2420,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-wgpu"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"ash",
|
||||
"async-channel",
|
||||
@@ -2439,7 +2439,6 @@ dependencies = [
|
||||
"hashbrown 0.17.1",
|
||||
"itertools 0.14.0",
|
||||
"log",
|
||||
"renderdoc",
|
||||
"sanitize-filename",
|
||||
"tracel-ash",
|
||||
"tracing",
|
||||
@@ -2450,7 +2449,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-zspace"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=a8ddb49dc928d0ea6cfc86101013082e24a5ff18#a8ddb49dc928d0ea6cfc86101013082e24a5ff18"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=0751cf6a5005c1755c1cd44d1c9383e7b1de814c#0751cf6a5005c1755c1cd44d1c9383e7b1de814c"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"serde",
|
||||
@@ -2460,7 +2459,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubek-attention",
|
||||
@@ -2478,7 +2477,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-attention"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
@@ -2492,7 +2491,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-convolution"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
@@ -2508,7 +2507,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-fft"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
]
|
||||
@@ -2516,7 +2515,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-interpolate"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2526,7 +2525,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-matmul"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
@@ -2539,7 +2538,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-pool"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2549,7 +2548,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-quant"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2560,7 +2559,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-random"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2573,7 +2572,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-reduce"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubek-std",
|
||||
@@ -2586,7 +2585,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-std"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=629cec7f784a211e15290863ed71f68d4b59c25c#629cec7f784a211e15290863ed71f68d4b59c25c"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=0d259b53ff086b727dc1cdf001d064585bbc78bc#0d259b53ff086b727dc1cdf001d064585bbc78bc"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -3507,15 +3506,6 @@ dependencies = [
|
||||
"zlib-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "float-cmp"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "float-cmp"
|
||||
version = "0.10.0"
|
||||
@@ -6993,7 +6983,7 @@ version = "0.22.0-pre.1"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-store",
|
||||
"float-cmp 0.10.0",
|
||||
"float-cmp",
|
||||
"serde",
|
||||
]
|
||||
|
||||
@@ -7517,21 +7507,6 @@ version = "1.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
|
||||
|
||||
[[package]]
|
||||
name = "renderdoc"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac633a08f39bf3268714799bd8c4a1a19c19c203d817d3448bb8b91c97817cf0"
|
||||
dependencies = [
|
||||
"bitflags 2.11.1",
|
||||
"float-cmp 0.9.0",
|
||||
"libloading 0.8.9",
|
||||
"once_cell",
|
||||
"renderdoc-sys",
|
||||
"winapi",
|
||||
"wio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "renderdoc-sys"
|
||||
version = "1.1.0"
|
||||
@@ -7820,7 +7795,7 @@ dependencies = [
|
||||
"burn-autodiff",
|
||||
"burn-flex",
|
||||
"burn-store",
|
||||
"float-cmp 0.10.0",
|
||||
"float-cmp",
|
||||
"serde",
|
||||
]
|
||||
|
||||
@@ -10417,15 +10392,6 @@ version = "0.0.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904"
|
||||
|
||||
[[package]]
|
||||
name = "wio"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d129932f4644ac2396cb456385cbf9e63b5b30c6e8dc4820bdca4eb082037a5"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.51.0"
|
||||
|
||||
@@ -218,10 +218,10 @@ burn-vision = { path = "crates/burn-vision", version = "0.22.0-pre.1", default-f
|
||||
burn-wgpu = { path = "crates/burn-wgpu", version = "0.22.0-pre.1", default-features = false }
|
||||
|
||||
### For the main burn branch. ###
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a8ddb49dc928d0ea6cfc86101013082e24a5ff18" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a8ddb49dc928d0ea6cfc86101013082e24a5ff18" }
|
||||
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a8ddb49dc928d0ea6cfc86101013082e24a5ff18" }
|
||||
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "629cec7f784a211e15290863ed71f68d4b59c25c" }
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0751cf6a5005c1755c1cd44d1c9383e7b1de814c" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0751cf6a5005c1755c1cd44d1c9383e7b1de814c" }
|
||||
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0751cf6a5005c1755c1cd44d1c9383e7b1de814c" }
|
||||
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "0d259b53ff086b727dc1cdf001d064585bbc78bc" }
|
||||
### For local development. ###
|
||||
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
|
||||
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
|
||||
|
||||
@@ -294,6 +294,9 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
||||
.tensor
|
||||
.upsample_nearest2d(output_size[0], output_size[1])
|
||||
.unwrap(),
|
||||
InterpolateMode::NearestExact => {
|
||||
panic!("nearest exact interpolation is not implemented in Candle backend")
|
||||
}
|
||||
InterpolateMode::Bilinear => {
|
||||
panic!("bilinear interpolation is not supported by Candle")
|
||||
}
|
||||
|
||||
@@ -34,8 +34,8 @@ use cubek::{
|
||||
BlueprintStrategy, Routine,
|
||||
double_buffering::{CyclicDoubleBufferingAlgorithm, DoubleBufferingArgs},
|
||||
double_unit::DoubleUnitAlgorithm,
|
||||
gemm::GemmRoutine,
|
||||
gemv_innerproduct::{DoubleVecMatInnerProductAlgorithm, VecMatInnerProductAlgorithm},
|
||||
gemv_plane_parallel::GemvPlaneParallelRoutine,
|
||||
gemv_unit_perpendicular::GemvUnitPerpendicularRoutine,
|
||||
ordered_double_buffering::{OrderedDoubleBufferingAlgorithm, OrderedSelectionArgs},
|
||||
simple::{SimpleAlgorithm, SimpleArgs},
|
||||
@@ -225,7 +225,7 @@ pub enum FusedMatmulSelector {
|
||||
},
|
||||
SimpleVecMat,
|
||||
DoubleVecMat,
|
||||
GemvPlaneParallel,
|
||||
GemmNoStage,
|
||||
GemvUnitPerpendicular,
|
||||
SimpleUnit,
|
||||
DoubleUnit,
|
||||
@@ -254,7 +254,7 @@ impl FusedMatmulSelector {
|
||||
}
|
||||
FusedMatmulSelector::SimpleVecMat => "simple_vec_mat".into(),
|
||||
FusedMatmulSelector::DoubleVecMat => "double_buffering_vec_mat".into(),
|
||||
FusedMatmulSelector::GemvPlaneParallel => "gemv_plane_parallel".into(),
|
||||
FusedMatmulSelector::GemmNoStage => "gemm".into(),
|
||||
FusedMatmulSelector::GemvUnitPerpendicular => "gemv_unit_perpendicular".into(),
|
||||
FusedMatmulSelector::SimpleUnit => "simple_unit".into(),
|
||||
FusedMatmulSelector::DoubleUnit => "double_buffering_unit".into(),
|
||||
@@ -639,8 +639,8 @@ impl FusedMatmulLaunch<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
FusedMatmulSelector::GemvPlaneParallel => {
|
||||
match launch_inner_fix_dtype::<R, GemvPlaneParallelRoutine>(
|
||||
FusedMatmulSelector::GemmNoStage => {
|
||||
match launch_inner_fix_dtype::<R, GemmRoutine>(
|
||||
client,
|
||||
FusedMatmulInputLaunch::new(
|
||||
inputs,
|
||||
|
||||
@@ -139,7 +139,7 @@ pub fn fused_matmul_autotune<R: Runtime>(
|
||||
for (selector, double_buf) in [
|
||||
(FusedMatmulSelector::SimpleVecMat, false),
|
||||
(FusedMatmulSelector::DoubleVecMat, true),
|
||||
(FusedMatmulSelector::GemvPlaneParallel, false),
|
||||
(FusedMatmulSelector::GemmNoStage, false),
|
||||
(FusedMatmulSelector::GemvUnitPerpendicular, false),
|
||||
] {
|
||||
set = set.with(
|
||||
|
||||
@@ -7,7 +7,8 @@ use crate::{
|
||||
use burn_backend::{Shape, TensorMetadata, ops::InterpolateMode, ops::InterpolateOptions};
|
||||
use cubek::interpolate::{
|
||||
definition::InterpolateMode as CubekInterpolateMode,
|
||||
definition::InterpolateOptions as CubekInterpolateOptions, interpolate as cubek_interpolate,
|
||||
definition::InterpolateOptions as CubekInterpolateOptions,
|
||||
definition::NearestMode as CubekNearestMode, interpolate as cubek_interpolate,
|
||||
interpolate_backward as cubek_interpolate_backward,
|
||||
};
|
||||
|
||||
@@ -91,7 +92,10 @@ fn map_options(options: InterpolateOptions) -> CubekInterpolateOptions {
|
||||
CubekInterpolateOptions {
|
||||
mode: {
|
||||
match options.mode {
|
||||
InterpolateMode::Nearest => CubekInterpolateMode::Nearest,
|
||||
InterpolateMode::Nearest => CubekInterpolateMode::Nearest(CubekNearestMode::Floor),
|
||||
InterpolateMode::NearestExact => {
|
||||
CubekInterpolateMode::Nearest(CubekNearestMode::Exact)
|
||||
}
|
||||
InterpolateMode::Bilinear => CubekInterpolateMode::Bilinear,
|
||||
InterpolateMode::Bicubic => CubekInterpolateMode::Bicubic,
|
||||
InterpolateMode::Lanczos3 => CubekInterpolateMode::Lanczos3,
|
||||
|
||||
@@ -14,8 +14,9 @@ use cubek::matmul::{
|
||||
launch::{MatmulAutotuneKey, MatmulGlobalScale, Strategy, should_tune_double_buffering},
|
||||
routines::{
|
||||
BlueprintStrategy, TileSizeSelection, double_buffering::DoubleBufferingArgs,
|
||||
double_unit::DoubleUnitSelectionArgs, ordered_double_buffering::OrderedSelectionArgs,
|
||||
simple::SimpleArgs, simple_unit::SimpleUnitSelectionArgs,
|
||||
double_unit::DoubleUnitSelectionArgs, gemm::GemmStrategy,
|
||||
ordered_double_buffering::OrderedSelectionArgs, simple::SimpleArgs,
|
||||
simple_unit::SimpleUnitSelectionArgs,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -187,7 +188,7 @@ pub fn matmul_autotune<R: CubeRuntime>(
|
||||
false,
|
||||
),
|
||||
(
|
||||
Strategy::GemvPlaneParallel(BlueprintStrategy::Inferred(Default::default())),
|
||||
Strategy::Gemm(BlueprintStrategy::Inferred(Default::default())),
|
||||
false,
|
||||
),
|
||||
(
|
||||
@@ -239,6 +240,22 @@ pub fn matmul_autotune<R: CubeRuntime>(
|
||||
}
|
||||
}
|
||||
|
||||
// Gemm no stage
|
||||
// In unit because not accelerated
|
||||
let gemm_no_stage_strategy = Strategy::Gemm(BlueprintStrategy::Inferred(GemmStrategy {
|
||||
target_num_planes: None,
|
||||
}));
|
||||
set = set.with(
|
||||
Tunable::new(
|
||||
&gemm_no_stage_strategy.to_string(),
|
||||
move |(lhs, rhs, out)| {
|
||||
launch_matmul::<R>(&gemm_no_stage_strategy, lhs, rhs, out)
|
||||
.map_err(|err| format!("{err:?}"))
|
||||
},
|
||||
)
|
||||
.group(&unit, move |_key| PRIORITY_MAX),
|
||||
);
|
||||
|
||||
// Accelerated matmuls
|
||||
for (strategy, double_buf, group_extra, tile_group) in [
|
||||
(
|
||||
|
||||
@@ -1669,6 +1669,7 @@ pub struct MaxPool2dWithIndicesBackwardOpIr {
|
||||
#[allow(missing_docs)]
|
||||
pub enum InterpolateModeIr {
|
||||
Nearest,
|
||||
NearestExact,
|
||||
Bilinear,
|
||||
Bicubic,
|
||||
Lanczos3,
|
||||
@@ -1834,6 +1835,7 @@ impl From<InterpolateModeIr> for InterpolateMode {
|
||||
fn from(val: InterpolateModeIr) -> Self {
|
||||
match val {
|
||||
InterpolateModeIr::Nearest => Self::Nearest,
|
||||
InterpolateModeIr::NearestExact => Self::NearestExact,
|
||||
InterpolateModeIr::Bilinear => Self::Bilinear,
|
||||
InterpolateModeIr::Bicubic => Self::Bicubic,
|
||||
InterpolateModeIr::Lanczos3 => Self::Lanczos3,
|
||||
@@ -1851,6 +1853,7 @@ impl From<InterpolateMode> for InterpolateModeIr {
|
||||
fn from(val: InterpolateMode) -> Self {
|
||||
match val {
|
||||
InterpolateMode::Nearest => Self::Nearest,
|
||||
InterpolateMode::NearestExact => Self::Nearest,
|
||||
InterpolateMode::Bilinear => Self::Bilinear,
|
||||
InterpolateMode::Bicubic => Self::Bicubic,
|
||||
InterpolateMode::Lanczos3 => Self::Lanczos3,
|
||||
|
||||
@@ -293,6 +293,9 @@ where
|
||||
)
|
||||
.into())
|
||||
}
|
||||
InterpolateMode::NearestExact => {
|
||||
panic!("nearest exact interpolation is not supported for ndarray backend")
|
||||
}
|
||||
InterpolateMode::Bilinear => {
|
||||
let align_corners = options.align_corners;
|
||||
module_op!(inp(x), opt(), E, |x| bilinear_interpolate::<E>(
|
||||
@@ -333,6 +336,9 @@ where
|
||||
InterpolateMode::Nearest => module_op!(inp(x, grad), opt(), E, |x, grad| {
|
||||
nearest_interpolate_backward::<E>(x, grad, output_size).into()
|
||||
}),
|
||||
InterpolateMode::NearestExact => {
|
||||
panic!("nearest exact interpolation backward is not supported for ndarray backend")
|
||||
}
|
||||
InterpolateMode::Bilinear => {
|
||||
panic!("bilinear interpolation backward is not supported for ndarray backend")
|
||||
}
|
||||
|
||||
@@ -14,11 +14,17 @@ use burn::tensor::ops::InterpolateMode as OpsInterpolateMode;
|
||||
/// This enum defines different interpolation modes for resampling data.
|
||||
#[derive(Config, Debug)]
|
||||
pub enum InterpolateMode {
|
||||
/// Nearest-neighbor interpolation
|
||||
/// Nearest-neighbor floor interpolation
|
||||
///
|
||||
/// This mode selects the value with floor mapping to a sample point for each output pixel.
|
||||
/// It is applicable for both temporal and spatial data.
|
||||
Nearest,
|
||||
|
||||
/// Nearest-neighbor exact interpolation
|
||||
///
|
||||
/// This mode selects the value of the nearest sample point for each output pixel.
|
||||
/// It is applicable for both temporal and spatial data.
|
||||
Nearest,
|
||||
NearestExact,
|
||||
|
||||
/// Linear interpolation
|
||||
///
|
||||
@@ -49,6 +55,7 @@ impl From<InterpolateMode> for OpsInterpolateMode {
|
||||
fn from(mode: InterpolateMode) -> Self {
|
||||
match mode {
|
||||
InterpolateMode::Nearest => OpsInterpolateMode::Nearest,
|
||||
InterpolateMode::NearestExact => OpsInterpolateMode::NearestExact,
|
||||
InterpolateMode::Linear => OpsInterpolateMode::Bilinear,
|
||||
InterpolateMode::Cubic => OpsInterpolateMode::Bicubic,
|
||||
InterpolateMode::Lanczos => OpsInterpolateMode::Lanczos3,
|
||||
|
||||
@@ -203,13 +203,18 @@ impl UnfoldOptions {
|
||||
}
|
||||
}
|
||||
|
||||
/// Algorithm used for upsampling.
|
||||
/// Algorithm used.
|
||||
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
|
||||
pub enum InterpolateMode {
|
||||
/// Nearest-neighbor interpolation.
|
||||
/// Nearest-neighbor floor interpolation.
|
||||
/// Matches the legacy behavior of OpenCV’s INTER_NEAREST. It results in a bottom-right shift when resizing.
|
||||
/// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
|
||||
Nearest,
|
||||
|
||||
/// Nearest-neighbor exact interpolation.
|
||||
/// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
|
||||
NearestExact,
|
||||
|
||||
/// Bilinear interpolation.
|
||||
/// <https://en.wikipedia.org/wiki/Bilinear_interpolation>
|
||||
Bilinear,
|
||||
@@ -226,7 +231,7 @@ pub enum InterpolateMode {
|
||||
/// Interpolation options.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterpolateOptions {
|
||||
/// Algorithm used for upsampling.
|
||||
/// Algorithm used.
|
||||
pub mode: InterpolateMode,
|
||||
/// If `true`, the input and output tensors are aligned by their corner pixels.
|
||||
/// If `false`, half-pixel coordinate mapping is used instead.
|
||||
|
||||
@@ -397,6 +397,9 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
|
||||
InterpolateMode::Nearest => {
|
||||
tch::Tensor::upsample_nearest2d(&x.tensor, output_size, None, None)
|
||||
}
|
||||
InterpolateMode::NearestExact => {
|
||||
panic!("nearest exact interpolation is not supported by PyTorch/tch backend")
|
||||
}
|
||||
InterpolateMode::Bilinear => {
|
||||
tch::Tensor::upsample_bilinear2d(&x.tensor, output_size, align_corners, None, None)
|
||||
}
|
||||
@@ -430,6 +433,11 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
|
||||
None,
|
||||
None,
|
||||
),
|
||||
InterpolateMode::NearestExact => {
|
||||
panic!(
|
||||
"nearest exact interpolation backward is not supported by PyTorch/tch backend"
|
||||
)
|
||||
}
|
||||
InterpolateMode::Bilinear => tch::Tensor::upsample_bilinear2d_backward(
|
||||
&grad.tensor,
|
||||
output_size,
|
||||
|
||||
Reference in New Issue
Block a user