mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
feat(wgpu): use WgpuRuntime compiler generic to enable specialized aliases (#5001)
* WIP * Add docs * Fix feature flags * Fix device type * Fix compilation * Update revs * Fmt
This commit is contained in:
66
Cargo.lock
generated
66
Cargo.lock
generated
@@ -2119,7 +2119,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"cubecl-core",
|
||||
"cubecl-cpu",
|
||||
@@ -2135,7 +2135,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-common"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"bytemuck",
|
||||
@@ -2176,7 +2176,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-core"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"bitflags 2.11.1",
|
||||
"bytemuck",
|
||||
@@ -2205,7 +2205,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-cpp"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2221,7 +2221,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-cpu"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2242,7 +2242,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-cuda"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2260,7 +2260,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-hip"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl-common",
|
||||
@@ -2289,7 +2289,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-ir"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"cubecl-common",
|
||||
@@ -2313,7 +2313,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-macros"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"darling 0.23.0",
|
||||
@@ -2330,7 +2330,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-macros-internal"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"darling 0.23.0",
|
||||
"proc-macro2",
|
||||
@@ -2341,7 +2341,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-opt"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
@@ -2359,7 +2359,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-runtime"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"async-channel",
|
||||
@@ -2390,7 +2390,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-spirv"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"bitflags 2.11.1",
|
||||
"cubecl-common",
|
||||
@@ -2406,7 +2406,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-std"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"cubecl-common",
|
||||
"cubecl-core",
|
||||
@@ -2422,7 +2422,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-wgpu"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"ash",
|
||||
"async-channel",
|
||||
@@ -2451,7 +2451,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubecl-zspace"
|
||||
version = "0.11.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=9a783fdef4997d18c6333600c049f898b0b4fc7c#9a783fdef4997d18c6333600c049f898b0b4fc7c"
|
||||
source = "git+https://github.com/tracel-ai/cubecl?rev=b5983f200457271d7e336f0db903eda54834b924#b5983f200457271d7e336f0db903eda54834b924"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"serde",
|
||||
@@ -2461,7 +2461,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubek-attention",
|
||||
@@ -2479,7 +2479,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-attention"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
@@ -2493,7 +2493,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-convolution"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
@@ -2509,7 +2509,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-fft"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
]
|
||||
@@ -2517,7 +2517,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-interpolate"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2528,7 +2528,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-matmul"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"cubecl",
|
||||
@@ -2541,7 +2541,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-pool"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2551,7 +2551,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-quant"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2562,7 +2562,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-random"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -2575,7 +2575,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-reduce"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubek-std",
|
||||
@@ -2588,7 +2588,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "cubek-std"
|
||||
version = "0.3.0-pre.1"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=b4a62bd93b677f83f95108abb5bb893be75729d7#b4a62bd93b677f83f95108abb5bb893be75729d7"
|
||||
source = "git+https://github.com/tracel-ai/cubek?rev=deba7f003af37a231e38bebda293a89df338f467#deba7f003af37a231e38bebda293a89df338f467"
|
||||
dependencies = [
|
||||
"cubecl",
|
||||
"cubecl-common",
|
||||
@@ -4869,9 +4869,9 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682"
|
||||
|
||||
[[package]]
|
||||
name = "jiff"
|
||||
version = "0.2.24"
|
||||
version = "0.2.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d"
|
||||
checksum = "6835eea34fb6321b9b3aa7b685c2b433948c09447e389dc017fdf687d5d11e65"
|
||||
dependencies = [
|
||||
"jiff-static",
|
||||
"log",
|
||||
@@ -4882,9 +4882,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "jiff-static"
|
||||
version = "0.2.24"
|
||||
version = "0.2.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7"
|
||||
checksum = "3c22e04db9c58f5136eb1757f3d5c49a7b187f49e52185228cbd2f5acdfcc08c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -5132,9 +5132,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.29"
|
||||
version = "0.4.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
|
||||
checksum = "616ec5685824bcc94416c6d4a7a446eea774a31efd7062c8480ba6fd06d7a6e5"
|
||||
|
||||
[[package]]
|
||||
name = "loop9"
|
||||
|
||||
@@ -219,10 +219,10 @@ burn-vision = { path = "crates/burn-vision", version = "0.22.0-pre.1", default-f
|
||||
burn-wgpu = { path = "crates/burn-wgpu", version = "0.22.0-pre.1", default-features = false }
|
||||
|
||||
### For the main burn branch. ###
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a783fdef4997d18c6333600c049f898b0b4fc7c" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a783fdef4997d18c6333600c049f898b0b4fc7c" }
|
||||
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a783fdef4997d18c6333600c049f898b0b4fc7c" }
|
||||
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "b4a62bd93b677f83f95108abb5bb893be75729d7" }
|
||||
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b5983f200457271d7e336f0db903eda54834b924" }
|
||||
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b5983f200457271d7e336f0db903eda54834b924" }
|
||||
cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b5983f200457271d7e336f0db903eda54834b924" }
|
||||
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "deba7f003af37a231e38bebda293a89df338f467" }
|
||||
### For local development. ###
|
||||
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
|
||||
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
fn main() {
|
||||
println!("cargo::rustc-check-cfg=cfg(wgpu_metal)");
|
||||
println!("cargo::rustc-check-cfg=cfg(wgpu_vulkan)");
|
||||
println!("cargo::rustc-check-cfg=cfg(wgpu_webgpu)");
|
||||
println!("cargo::rustc-check-cfg=cfg(default_backend)");
|
||||
|
||||
// If you try to build with `--no-default-features`, we enable a cpu backend by default (Flex)
|
||||
@@ -11,47 +8,15 @@ fn main() {
|
||||
let ndarray = cfg!(feature = "ndarray");
|
||||
let tch = cfg!(feature = "tch");
|
||||
let cpu = cfg!(feature = "cpu");
|
||||
|
||||
let mut metal = cfg!(feature = "metal");
|
||||
let mut vulkan = cfg!(feature = "vulkan");
|
||||
let mut webgpu = cfg!(feature = "webgpu");
|
||||
let metal = cfg!(feature = "metal");
|
||||
let vulkan = cfg!(feature = "vulkan");
|
||||
let webgpu = cfg!(feature = "webgpu");
|
||||
let wgpu = cfg!(feature = "wgpu");
|
||||
|
||||
// Detect which single wgpu backend is enabled
|
||||
let wgpu_only = cfg!(all(
|
||||
feature = "wgpu",
|
||||
not(feature = "metal"),
|
||||
not(feature = "vulkan")
|
||||
));
|
||||
let enabled = [(metal, "metal"), (vulkan, "vulkan"), (webgpu, "webgpu")]
|
||||
.iter()
|
||||
.filter(|x| x.0)
|
||||
.map(|x| x.1)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// WGPU features are mutually exclusive, but we don't want to workspace to throw a compile error.
|
||||
// In workspace builds with multiple features, we emit a warning and fallback to WebGpu/Wgpu.
|
||||
if enabled.len() > 1 {
|
||||
webgpu = true;
|
||||
vulkan = false;
|
||||
metal = false;
|
||||
println!(
|
||||
"cargo:warning=Only one wgpu backend can be enabled at once. Detected: [{}]. Falling back to `wgpu`. For production, enable only one of: metal, vulkan, or webgpu.",
|
||||
enabled.join(", ")
|
||||
);
|
||||
}
|
||||
|
||||
let no_backend_enabled =
|
||||
!(cuda || flex || rocm || ndarray || tch || cpu || metal || vulkan || webgpu || wgpu_only);
|
||||
!(cuda || flex || rocm || ndarray || tch || cpu || metal || vulkan || webgpu || wgpu);
|
||||
|
||||
if metal {
|
||||
println!("cargo:rustc-cfg=wgpu_metal");
|
||||
}
|
||||
if vulkan {
|
||||
println!("cargo:rustc-cfg=wgpu_vulkan");
|
||||
}
|
||||
if webgpu || wgpu_only {
|
||||
println!("cargo:rustc-cfg=wgpu_webgpu");
|
||||
}
|
||||
if no_backend_enabled {
|
||||
println!("cargo:rustc-cfg=default_backend");
|
||||
}
|
||||
|
||||
@@ -92,8 +92,16 @@ impl Backend for Dispatch {
|
||||
DispatchDeviceId::Cpu => Cpu::device_count(backend_type_id),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchDeviceId::Cuda => Cuda::device_count(backend_type_id),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchDeviceId::Metal => Metal::device_count(backend_type_id),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchDeviceId::Rocm => Rocm::device_count(backend_type_id),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchDeviceId::Vulkan => Vulkan::device_count(backend_type_id),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchDeviceId::Wgpu => Wgpu::device_count(backend_type_id),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchDeviceId::WebGpu => WebGpu::device_count(backend_type_id),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchDeviceId::Flex => Flex::device_count(backend_type_id),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -150,10 +158,16 @@ impl AutodiffBackend for Dispatch {
|
||||
DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchTensorKind::Wgpu(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchTensorKind::WebGpu(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchTensorKind::Flex(tensor) => tensor.autodiff().backward(),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -187,16 +201,31 @@ impl AutodiffBackend for Dispatch {
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchTensorKind::Metal(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensorKind::Rocm(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchTensorKind::Vulkan(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchTensorKind::Wgpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchTensorKind::WebGpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad(grads)
|
||||
.map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchTensorKind::Flex(tensor) => tensor
|
||||
.as_autodiff()
|
||||
@@ -246,16 +275,31 @@ impl AutodiffBackend for Dispatch {
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchTensorKind::Metal(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensorKind::Rocm(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchTensorKind::Vulkan(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchTensorKind::Wgpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchTensorKind::WebGpu(tensor) => tensor
|
||||
.as_autodiff()
|
||||
.grad_remove(grads)
|
||||
.map(|t| DispatchTensorKind::WebGpu(crate::BackendTensor::Float(t))),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchTensorKind::Flex(tensor) => tensor
|
||||
.as_autodiff()
|
||||
@@ -309,14 +353,26 @@ impl AutodiffBackend for Dispatch {
|
||||
(DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
(DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "rocm")]
|
||||
(DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "vulkan")]
|
||||
(DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "wgpu")]
|
||||
(DispatchTensorKind::Wgpu(tensor), DispatchTensorKind::Wgpu(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "webgpu")]
|
||||
(DispatchTensorKind::WebGpu(tensor), DispatchTensorKind::WebGpu(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
}
|
||||
#[cfg(feature = "flex")]
|
||||
(DispatchTensorKind::Flex(tensor), DispatchTensorKind::Flex(grad)) => {
|
||||
tensor.as_autodiff().grad_replace(grads, grad.float())
|
||||
@@ -356,14 +412,26 @@ impl AutodiffBackend for Dispatch {
|
||||
DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Wgpu(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchTensorKind::WebGpu(tensor) => DispatchTensorKind::WebGpu(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchTensorKind::Flex(tensor) => DispatchTensorKind::Flex(
|
||||
crate::BackendTensor::Float(tensor.autodiff().primitive),
|
||||
@@ -423,18 +491,36 @@ impl AutodiffBackend for Dispatch {
|
||||
crate::BackendTensor::Autodiff(Autodiff::<Cuda>::from_inner(tensor.float())),
|
||||
)))
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchTensorKind::Metal(tensor) => {
|
||||
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Metal(
|
||||
crate::BackendTensor::Autodiff(Autodiff::<Metal>::from_inner(tensor.float())),
|
||||
)))
|
||||
}
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensorKind::Rocm(tensor) => {
|
||||
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Rocm(
|
||||
crate::BackendTensor::Autodiff(Autodiff::<Rocm>::from_inner(tensor.float())),
|
||||
)))
|
||||
}
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchTensorKind::Vulkan(tensor) => {
|
||||
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Vulkan(
|
||||
crate::BackendTensor::Autodiff(Autodiff::<Vulkan>::from_inner(tensor.float())),
|
||||
)))
|
||||
}
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchTensorKind::Wgpu(tensor) => {
|
||||
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Wgpu(
|
||||
crate::BackendTensor::Autodiff(Autodiff::<Wgpu>::from_inner(tensor.float())),
|
||||
)))
|
||||
}
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchTensorKind::WebGpu(tensor) => {
|
||||
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::WebGpu(
|
||||
crate::BackendTensor::Autodiff(Autodiff::<WebGpu>::from_inner(tensor.float())),
|
||||
)))
|
||||
}
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchTensorKind::Flex(tensor) => {
|
||||
DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Flex(
|
||||
@@ -548,10 +634,16 @@ impl DispatchTensorKind {
|
||||
DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchTensorKind::Wgpu(tensor) => DispatchDevice::Wgpu(tensor.device()),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchTensorKind::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchTensorKind::Flex(tensor) => DispatchDevice::Flex(tensor.device()),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -597,13 +689,25 @@ impl Dispatch {
|
||||
DispatchDeviceId::Cuda => (0..Cuda::device_count(0))
|
||||
.map(|i| CudaDevice::new(i).into())
|
||||
.collect(),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchDeviceId::Metal => (0..Metal::device_count(0))
|
||||
.map(|i| DispatchDevice::Metal(WgpuDevice::DiscreteGpu(i)))
|
||||
.collect(),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchDeviceId::Rocm => (0..Rocm::device_count(0))
|
||||
.map(|i| RocmDevice::new(i).into())
|
||||
.collect(),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchDeviceId::Vulkan => (0..Vulkan::device_count(0))
|
||||
.map(|i| DispatchDevice::Vulkan(WgpuDevice::DiscreteGpu(i)))
|
||||
.collect(),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchDeviceId::Wgpu => (0..Wgpu::device_count(0))
|
||||
.map(|i| WgpuDevice::DiscreteGpu(i).into())
|
||||
.map(|i| DispatchDevice::Wgpu(WgpuDevice::DiscreteGpu(i)))
|
||||
.collect(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchDeviceId::WebGpu => (0..WebGpu::device_count(0))
|
||||
.map(|i| DispatchDevice::WebGpu(WgpuDevice::DiscreteGpu(i)))
|
||||
.collect(),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchDeviceId::Flex => vec![FlexDevice.into()],
|
||||
|
||||
@@ -30,25 +30,26 @@ pub enum DispatchDevice {
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(CudaDevice),
|
||||
|
||||
/// The [Metal backend](Metal) device (via WGPU runtime).
|
||||
#[cfg(feature = "metal")]
|
||||
Metal(WgpuDevice),
|
||||
|
||||
/// The [ROCm backend](Rocm) device.
|
||||
#[cfg(feature = "rocm")]
|
||||
Rocm(RocmDevice),
|
||||
|
||||
#[cfg_attr(
|
||||
wgpu_webgpu,
|
||||
doc = r#"The [WebGPU backend](Wgpu) device (via WGPU runtime)"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
wgpu_vulkan,
|
||||
doc = r#"The [Vulkan backend](Vulkan) device (via WGPU runtime)"#
|
||||
)]
|
||||
#[cfg_attr(
|
||||
wgpu_metal,
|
||||
doc = r#"The [Metal backend](Metal) device (via WGPU runtime)"#
|
||||
)]
|
||||
/// The [Vulkan backend](Vulkan) device.
|
||||
#[cfg(feature = "vulkan")]
|
||||
Vulkan(WgpuDevice),
|
||||
|
||||
/// The [Wgpu backend](Wgpu) device (via WGPU runtime with auto-selected compiler).
|
||||
#[cfg(feature = "wgpu")]
|
||||
Wgpu(WgpuDevice),
|
||||
|
||||
/// The [WebGPU backend](WebGpu) device (via WGPU runtime).
|
||||
#[cfg(feature = "webgpu")]
|
||||
WebGpu(WgpuDevice),
|
||||
|
||||
/// The [Flex backend](Flex) device (CPU-only).
|
||||
#[cfg(feature = "flex")]
|
||||
Flex(FlexDevice),
|
||||
@@ -145,14 +146,16 @@ impl core::fmt::Debug for DispatchDevice {
|
||||
Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
|
||||
#[cfg(feature = "metal")]
|
||||
Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
|
||||
#[cfg(wgpu_metal)]
|
||||
Self::Wgpu(device) => f.debug_tuple("Metal").field(device).finish(),
|
||||
#[cfg(wgpu_vulkan)]
|
||||
Self::Wgpu(device) => f.debug_tuple("Vulkan").field(device).finish(),
|
||||
#[cfg(wgpu_webgpu)]
|
||||
#[cfg(feature = "vulkan")]
|
||||
Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
|
||||
#[cfg(feature = "wgpu")]
|
||||
Self::Wgpu(device) => f.debug_tuple("Wgpu").field(device).finish(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
Self::WebGpu(device) => f.debug_tuple("WebGpu").field(device).finish(),
|
||||
#[cfg(feature = "flex")]
|
||||
Self::Flex(device) => f.debug_tuple("Flex").field(device).finish(),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -189,8 +192,8 @@ impl Default for DispatchDevice {
|
||||
);
|
||||
}
|
||||
"metal" => {
|
||||
#[cfg(wgpu_metal)]
|
||||
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
|
||||
#[cfg(feature = "metal")]
|
||||
return Self::Metal(burn_wgpu::WgpuDevice::default());
|
||||
panic!(
|
||||
"BURN_DEVICE=metal requested, but the 'metal' feature is not enabled."
|
||||
);
|
||||
@@ -203,17 +206,24 @@ impl Default for DispatchDevice {
|
||||
);
|
||||
}
|
||||
"vulkan" => {
|
||||
#[cfg(wgpu_vulkan)]
|
||||
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
|
||||
#[cfg(feature = "vulkan")]
|
||||
return Self::Vulkan(burn_wgpu::WgpuDevice::default());
|
||||
panic!(
|
||||
"BURN_DEVICE=vulkan requested, but the 'vulkan' feature is not enabled."
|
||||
);
|
||||
}
|
||||
"webgpu" | "wgpu" => {
|
||||
#[cfg(wgpu_webgpu)]
|
||||
"webgpu" => {
|
||||
#[cfg(feature = "webgpu")]
|
||||
return Self::WebGpu(burn_wgpu::WgpuDevice::default());
|
||||
panic!(
|
||||
"BURN_DEVICE=webgpu requested, but the 'webgpu' feature is not enabled."
|
||||
);
|
||||
}
|
||||
"wgpu" => {
|
||||
#[cfg(feature = "wgpu")]
|
||||
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
|
||||
panic!(
|
||||
"BURN_DEVICE=wgpu requested, but the 'webgpu' or 'wgpu' feature is not enabled."
|
||||
"BURN_DEVICE=wgpu requested, but the 'wgpu' feature is not enabled."
|
||||
);
|
||||
}
|
||||
"cpu" => {
|
||||
@@ -255,12 +265,21 @@ impl Default for DispatchDevice {
|
||||
#[cfg(feature = "cuda")]
|
||||
return Self::Cuda(CudaDevice::default());
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
return Self::Metal(burn_wgpu::WgpuDevice::default());
|
||||
|
||||
#[cfg(feature = "rocm")]
|
||||
return Self::Rocm(RocmDevice::default());
|
||||
|
||||
#[cfg(feature = "vulkan")]
|
||||
return Self::Vulkan(burn_wgpu::WgpuDevice::default());
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
return Self::Wgpu(burn_wgpu::WgpuDevice::default());
|
||||
|
||||
#[cfg(feature = "webgpu")]
|
||||
return Self::WebGpu(burn_wgpu::WgpuDevice::default());
|
||||
|
||||
#[cfg(feature = "cpu")]
|
||||
return Self::Cpu(CpuDevice);
|
||||
|
||||
@@ -296,10 +315,16 @@ impl PartialEq for DispatchDevice {
|
||||
(Self::Cpu(a), Self::Cpu(b)) => a == b,
|
||||
#[cfg(feature = "cuda")]
|
||||
(Self::Cuda(a), Self::Cuda(b)) => a == b,
|
||||
#[cfg(feature = "metal")]
|
||||
(Self::Metal(a), Self::Metal(b)) => a == b,
|
||||
#[cfg(feature = "rocm")]
|
||||
(Self::Rocm(a), Self::Rocm(b)) => a == b,
|
||||
#[cfg(feature = "vulkan")]
|
||||
(Self::Vulkan(a), Self::Vulkan(b)) => a == b,
|
||||
#[cfg(feature = "wgpu")]
|
||||
(Self::Wgpu(a), Self::Wgpu(b)) => a == b,
|
||||
#[cfg(feature = "webgpu")]
|
||||
(Self::WebGpu(a), Self::WebGpu(b)) => a == b,
|
||||
#[cfg(feature = "flex")]
|
||||
(Self::Flex(a), Self::Flex(b)) => a == b,
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -350,10 +375,16 @@ impl DispatchDevice {
|
||||
Self::Cpu(_) => DispatchDeviceId::Cpu,
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(_) => DispatchDeviceId::Cuda,
|
||||
#[cfg(feature = "metal")]
|
||||
Self::Metal(_) => DispatchDeviceId::Metal,
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(_) => DispatchDeviceId::Rocm,
|
||||
#[cfg(feature = "vulkan")]
|
||||
Self::Vulkan(_) => DispatchDeviceId::Vulkan,
|
||||
#[cfg(feature = "wgpu")]
|
||||
Self::Wgpu(_) => DispatchDeviceId::Wgpu,
|
||||
#[cfg(feature = "webgpu")]
|
||||
Self::WebGpu(_) => DispatchDeviceId::WebGpu,
|
||||
#[cfg(feature = "flex")]
|
||||
Self::Flex(_) => DispatchDeviceId::Flex,
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -393,13 +424,15 @@ impl DispatchDevice {
|
||||
pub enum DispatchDeviceId {
|
||||
Cpu = 0,
|
||||
Cuda = 1,
|
||||
// webgpu, vulkan and device all point to WgpuDevice
|
||||
Wgpu = 2,
|
||||
Rocm = 3,
|
||||
Flex = 4,
|
||||
LibTorch = 5,
|
||||
NdArray = 6,
|
||||
Remote = 7,
|
||||
Metal = 7,
|
||||
Vulkan = 8,
|
||||
WebGpu = 9,
|
||||
Remote = 10,
|
||||
}
|
||||
|
||||
impl From<DispatchDeviceId> for u16 {
|
||||
@@ -427,8 +460,14 @@ impl TryFrom<u16> for DispatchDeviceId {
|
||||
5 => Ok(Self::LibTorch),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
6 => Ok(Self::NdArray),
|
||||
#[cfg(feature = "metal")]
|
||||
7 => Ok(Self::Metal),
|
||||
#[cfg(feature = "vulkan")]
|
||||
8 => Ok(Self::Vulkan),
|
||||
#[cfg(feature = "webgpu")]
|
||||
9 => Ok(Self::WebGpu),
|
||||
#[cfg(feature = "remote")]
|
||||
7 => Ok(Self::Remote),
|
||||
10 => Ok(Self::Remote),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
@@ -441,10 +480,16 @@ impl DeviceOps for DispatchDevice {
|
||||
Self::Cpu(device) => device.defaults(),
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(device) => device.defaults(),
|
||||
#[cfg(feature = "metal")]
|
||||
Self::Metal(device) => device.defaults(),
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(device) => device.defaults(),
|
||||
#[cfg(feature = "vulkan")]
|
||||
Self::Vulkan(device) => device.defaults(),
|
||||
#[cfg(feature = "wgpu")]
|
||||
Self::Wgpu(device) => device.defaults(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
Self::WebGpu(device) => device.defaults(),
|
||||
#[cfg(feature = "flex")]
|
||||
Self::Flex(device) => device.defaults(),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -469,10 +514,16 @@ impl burn_backend::Device for DispatchDevice {
|
||||
DispatchDeviceId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchDeviceId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchDeviceId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchDeviceId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchDeviceId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchDeviceId::Wgpu => Self::Wgpu(WgpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchDeviceId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)),
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchDeviceId::Flex => Self::Flex(FlexDevice::from_id(device_id)),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -491,10 +542,16 @@ impl burn_backend::Device for DispatchDevice {
|
||||
Self::Cpu(device) => device.to_id(),
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(device) => device.to_id(),
|
||||
#[cfg(feature = "metal")]
|
||||
Self::Metal(device) => device.to_id(),
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(device) => device.to_id(),
|
||||
#[cfg(feature = "vulkan")]
|
||||
Self::Vulkan(device) => device.to_id(),
|
||||
#[cfg(feature = "wgpu")]
|
||||
Self::Wgpu(device) => device.to_id(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
Self::WebGpu(device) => device.to_id(),
|
||||
#[cfg(feature = "flex")]
|
||||
Self::Flex(device) => device.to_id(),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -525,13 +582,6 @@ impl From<CudaDevice> for DispatchDevice {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(wgpu_metal)]
|
||||
impl From<WgpuDevice> for DispatchDevice {
|
||||
fn from(device: WgpuDevice) -> Self {
|
||||
DispatchDevice::Wgpu(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rocm")]
|
||||
impl From<RocmDevice> for DispatchDevice {
|
||||
fn from(device: RocmDevice) -> Self {
|
||||
@@ -539,14 +589,9 @@ impl From<RocmDevice> for DispatchDevice {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(wgpu_vulkan)]
|
||||
impl From<WgpuDevice> for DispatchDevice {
|
||||
fn from(device: WgpuDevice) -> Self {
|
||||
DispatchDevice::Wgpu(device)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(wgpu_webgpu)]
|
||||
// A bare `WgpuDevice` maps to the auto-compiler [`DispatchDevice::Wgpu`] variant. To target a
|
||||
// specific wgpu specialization (Metal, Vulkan, WebGpu) construct the variant explicitly.
|
||||
#[cfg(feature = "wgpu")]
|
||||
impl From<WgpuDevice> for DispatchDevice {
|
||||
fn from(device: WgpuDevice) -> Self {
|
||||
DispatchDevice::Wgpu(device)
|
||||
|
||||
@@ -22,21 +22,9 @@
|
||||
//! | `LibTorch` | `tch` | Libtorch backend via `tch` |
|
||||
//! | `Autodiff` | `autodiff` | Autodiff-enabled backend (used in combination with any of the backends above) |
|
||||
//!
|
||||
//! **Note:** WGPU-based backends (`metal`, `vulkan`, `webgpu`) are mutually exclusive.
|
||||
//! All other backends can be combined freely.
|
||||
//!
|
||||
//! ## WGPU Backend Exclusivity
|
||||
//!
|
||||
//! The WGPU-based backends (`metal`, `vulkan`, `webgpu`) are **mutually exclusive** due to
|
||||
//! the current automatic compile, which can only select one target at a time.
|
||||
//!
|
||||
//! Enable only **one** of these features in your `Cargo.toml`:
|
||||
//! - `metal`
|
||||
//! - `vulkan`
|
||||
//! - `webgpu`
|
||||
//!
|
||||
//! If multiple WGPU features are enabled, the build script will emit a warning and enable `webgpu` only
|
||||
//! to prevent unintended behavior.
|
||||
//! **Note:** All backends, including the WGPU-based ones (`wgpu`, `metal`, `vulkan`, `webgpu`),
|
||||
//! can be combined freely. Each enabled wgpu backend appears as its own
|
||||
//! [`DispatchDevice`] variant.
|
||||
|
||||
#[macro_use]
|
||||
mod macros;
|
||||
@@ -80,11 +68,11 @@ pub mod backends {
|
||||
pub use burn_rocm::Rocm;
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub use burn_wgpu as wgpu;
|
||||
#[cfg(wgpu_metal)]
|
||||
#[cfg(feature = "metal")]
|
||||
pub use burn_wgpu::Metal;
|
||||
#[cfg(wgpu_vulkan)]
|
||||
#[cfg(feature = "vulkan")]
|
||||
pub use burn_wgpu::Vulkan;
|
||||
#[cfg(all(wgpu_webgpu, feature = "webgpu"))]
|
||||
#[cfg(feature = "webgpu")]
|
||||
pub use burn_wgpu::WebGpu;
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub use burn_wgpu::Wgpu;
|
||||
|
||||
@@ -1,3 +1,44 @@
|
||||
/// Resolves a backend identifier to its concrete `Backend`-impl type.
|
||||
///
|
||||
/// Now that element type parameters have been removed project-wide, this is a simple
|
||||
/// passthrough — kept so the dispatch macros stay agnostic to future signature changes.
|
||||
#[allow(unused_macros)]
|
||||
macro_rules! inst {
|
||||
(Cpu) => {
|
||||
$crate::backends::Cpu
|
||||
};
|
||||
(Cuda) => {
|
||||
$crate::backends::Cuda
|
||||
};
|
||||
(Metal) => {
|
||||
$crate::backends::Metal
|
||||
};
|
||||
(Rocm) => {
|
||||
$crate::backends::Rocm
|
||||
};
|
||||
(Vulkan) => {
|
||||
$crate::backends::Vulkan
|
||||
};
|
||||
(Wgpu) => {
|
||||
$crate::backends::Wgpu
|
||||
};
|
||||
(WebGpu) => {
|
||||
$crate::backends::WebGpu
|
||||
};
|
||||
(Flex) => {
|
||||
$crate::backends::Flex
|
||||
};
|
||||
(NdArray) => {
|
||||
$crate::backends::NdArray
|
||||
};
|
||||
(LibTorch) => {
|
||||
$crate::backends::LibTorch
|
||||
};
|
||||
(Remote) => {
|
||||
$crate::backends::Remote
|
||||
};
|
||||
}
|
||||
|
||||
/// Supplies a list of all supported backends and their corresponding feature flags
|
||||
/// to a callback macro. This centralizes the backend registry.
|
||||
macro_rules! backend_list {
|
||||
@@ -6,8 +47,11 @@ macro_rules! backend_list {
|
||||
$($extra)*;
|
||||
[Cpu, feature = "cpu"],
|
||||
[Cuda, feature = "cuda"],
|
||||
[Metal, feature = "metal"],
|
||||
[Rocm, feature = "rocm"],
|
||||
[Vulkan, feature = "vulkan"],
|
||||
[Wgpu, feature = "wgpu"],
|
||||
[WebGpu, feature = "webgpu"],
|
||||
[Flex, feature = "flex"],
|
||||
[NdArray, any(feature = "ndarray", default_backend)],
|
||||
[LibTorch, feature = "tch"],
|
||||
@@ -21,14 +65,17 @@ macro_rules! backend_matrix {
|
||||
($callback:ident, $($extra:tt)*) => {
|
||||
$callback! {
|
||||
$($extra)*;
|
||||
[Cpu, feature = "cpu"] => [[Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Cuda, feature = "cuda"] => [[Cpu, feature = "cpu"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Rocm, feature = "rocm"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Wgpu, feature = "wgpu"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Flex, feature = "flex"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[NdArray, any(feature = "ndarray", default_backend)] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[LibTorch, feature = "tch"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [Remote, feature = "remote"]];
|
||||
[Remote, feature = "remote"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"]]
|
||||
[Cpu, feature = "cpu"] => [[Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Cuda, feature = "cuda"] => [[Cpu, feature = "cpu"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Metal, feature = "metal"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Rocm, feature = "rocm"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Vulkan, feature = "vulkan"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Wgpu, feature = "wgpu"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[WebGpu, feature = "webgpu"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[Flex, feature = "flex"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[NdArray, any(feature = "ndarray", default_backend)] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [LibTorch, feature = "tch"], [Remote, feature = "remote"]];
|
||||
[LibTorch, feature = "tch"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [Remote, feature = "remote"]];
|
||||
[Remote, feature = "remote"] => [[Cpu, feature = "cpu"], [Cuda, feature = "cuda"], [Metal, feature = "metal"], [Rocm, feature = "rocm"], [Vulkan, feature = "vulkan"], [Wgpu, feature = "wgpu"], [WebGpu, feature = "webgpu"], [Flex, feature = "flex"], [NdArray, any(feature = "ndarray", default_backend)], [LibTorch, feature = "tch"]]
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -19,10 +19,22 @@ pub fn start_websocket(device: DispatchDevice, port: u16) {
|
||||
DispatchDevice::Cpu(device) => burn_remote::server::start_websocket::<Cpu>(device, port),
|
||||
#[cfg(feature = "cuda")]
|
||||
DispatchDevice::Cuda(device) => burn_remote::server::start_websocket::<Cuda>(device, port),
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchDevice::Metal(device) => {
|
||||
burn_remote::server::start_websocket::<Metal>(device, port)
|
||||
}
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchDevice::Rocm(device) => burn_remote::server::start_websocket::<Rocm>(device, port),
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchDevice::Vulkan(device) => {
|
||||
burn_remote::server::start_websocket::<Vulkan>(device, port)
|
||||
}
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchDevice::Wgpu(device) => burn_remote::server::start_websocket::<Wgpu>(device, port),
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchDevice::WebGpu(device) => {
|
||||
burn_remote::server::start_websocket::<WebGpu>(device, port)
|
||||
}
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchDevice::Flex(device) => burn_remote::server::start_websocket::<Flex>(device, port),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -57,14 +69,26 @@ pub async fn start_websocket_async(device: DispatchDevice, port: u16) {
|
||||
DispatchDevice::Cuda(device) => {
|
||||
burn_remote::server::start_websocket_async::<Cuda>(device, port).await
|
||||
}
|
||||
#[cfg(feature = "metal")]
|
||||
DispatchDevice::Metal(device) => {
|
||||
burn_remote::server::start_websocket_async::<Metal>(device, port).await
|
||||
}
|
||||
#[cfg(feature = "rocm")]
|
||||
DispatchDevice::Rocm(device) => {
|
||||
burn_remote::server::start_websocket_async::<Rocm>(device, port).await
|
||||
}
|
||||
#[cfg(feature = "vulkan")]
|
||||
DispatchDevice::Vulkan(device) => {
|
||||
burn_remote::server::start_websocket_async::<Vulkan>(device, port).await
|
||||
}
|
||||
#[cfg(feature = "wgpu")]
|
||||
DispatchDevice::Wgpu(device) => {
|
||||
burn_remote::server::start_websocket_async::<Wgpu>(device, port).await
|
||||
}
|
||||
#[cfg(feature = "webgpu")]
|
||||
DispatchDevice::WebGpu(device) => {
|
||||
burn_remote::server::start_websocket_async::<WebGpu>(device, port).await
|
||||
}
|
||||
#[cfg(feature = "flex")]
|
||||
DispatchDevice::Flex(device) => {
|
||||
burn_remote::server::start_websocket_async::<Flex>(device, port).await
|
||||
|
||||
@@ -191,16 +191,26 @@ pub enum DispatchTensorKind {
|
||||
#[cfg(feature = "cuda")]
|
||||
Cuda(BackendTensor<Cuda>),
|
||||
|
||||
/// The [Metal backend](Metal) tensor.
|
||||
#[cfg(feature = "metal")]
|
||||
Metal(BackendTensor<Metal>),
|
||||
|
||||
/// The [ROCm backend](Rocm) tensor.
|
||||
#[cfg(feature = "rocm")]
|
||||
Rocm(BackendTensor<Rocm>),
|
||||
|
||||
#[cfg_attr(wgpu_webgpu, doc = r#"The [WebGPU backend](Wgpu) tensor"#)]
|
||||
#[cfg_attr(wgpu_vulkan, doc = r#"The [Vulkan backend](Vulkan) tensor"#)]
|
||||
#[cfg_attr(wgpu_metal, doc = r#"The [Metal backend](Metal) tensor"#)]
|
||||
/// The [Vulkan backend](Vulkan) tensor.
|
||||
#[cfg(feature = "vulkan")]
|
||||
Vulkan(BackendTensor<Vulkan>),
|
||||
|
||||
/// The [Wgpu backend](Wgpu) tensor.
|
||||
#[cfg(feature = "wgpu")]
|
||||
Wgpu(BackendTensor<Wgpu>),
|
||||
|
||||
/// The [WebGPU backend](Wgpu) tensor.
|
||||
#[cfg(feature = "webgpu")]
|
||||
WebGpu(BackendTensor<WebGpu>),
|
||||
|
||||
/// The [Flex backend](Flex) tensor.
|
||||
#[cfg(feature = "flex")]
|
||||
Flex(BackendTensor<Flex>),
|
||||
@@ -229,10 +239,16 @@ impl TensorMetadata for DispatchTensorKind {
|
||||
Self::Cpu(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "metal")]
|
||||
Self::Metal(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "vulkan")]
|
||||
Self::Vulkan(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "wgpu")]
|
||||
Self::Wgpu(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
Self::WebGpu(tensor) => tensor.dtype(),
|
||||
#[cfg(feature = "flex")]
|
||||
Self::Flex(tensor) => tensor.dtype(),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
@@ -252,10 +268,16 @@ impl TensorMetadata for DispatchTensorKind {
|
||||
Self::Cpu(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "cuda")]
|
||||
Self::Cuda(tensor) => tensor.shape(),
|
||||
#[cfg(wgpu_metal)]
|
||||
#[cfg(feature = "metal")]
|
||||
Self::Metal(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "rocm")]
|
||||
Self::Rocm(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "vulkan")]
|
||||
Self::Vulkan(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "wgpu")]
|
||||
Self::Wgpu(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "webgpu")]
|
||||
Self::WebGpu(tensor) => tensor.shape(),
|
||||
#[cfg(feature = "flex")]
|
||||
Self::Flex(tensor) => tensor.shape(),
|
||||
#[cfg(any(feature = "ndarray", default_backend))]
|
||||
|
||||
@@ -41,10 +41,10 @@ rocm = ["burn-dispatch/rocm"]
|
||||
flex = ["burn-dispatch/flex"]
|
||||
ndarray = ["burn-dispatch/ndarray"]
|
||||
tch = ["burn-dispatch/tch"]
|
||||
vulkan = ["burn-dispatch/vulkan"]
|
||||
webgpu = ["burn-dispatch/webgpu"]
|
||||
vulkan = ["wgpu", "burn-dispatch/vulkan"]
|
||||
webgpu = ["wgpu", "burn-dispatch/webgpu"]
|
||||
wgpu = ["burn-dispatch/wgpu"]
|
||||
metal = ["burn-dispatch/metal"]
|
||||
metal = ["wgpu", "burn-dispatch/metal"]
|
||||
cpu = ["burn-dispatch/cpu"]
|
||||
autodiff = ["burn-dispatch/autodiff"]
|
||||
remote = ["std", "burn-dispatch/remote"]
|
||||
|
||||
@@ -350,40 +350,34 @@ impl Device {
|
||||
/// equivalent calls (they exist only to make the intended backend
|
||||
/// explicit at the call site — the underlying adapter is still picked by
|
||||
/// enabled Cargo features).
|
||||
#[cfg(any(
|
||||
feature = "wgpu",
|
||||
feature = "vulkan",
|
||||
feature = "metal",
|
||||
feature = "webgpu"
|
||||
))]
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub fn wgpu(device_kind: DeviceKind) -> Self {
|
||||
Self::new(wgpu_device(device_kind))
|
||||
Self::new(DispatchDevice::Wgpu(wgpu_device(device_kind)))
|
||||
}
|
||||
|
||||
/// Vulkan-backed WGPU device, selected via [`DeviceKind`].
|
||||
///
|
||||
/// Equivalent to [`Device::wgpu`]; provided so the caller's choice of
|
||||
/// graphics API reads clearly at the call site. The actual adapter is
|
||||
/// still determined by enabled Cargo features.
|
||||
/// Pins the wgpu shader compiler to SPIR-V at compile time, avoiding
|
||||
/// the runtime [`AutoCompiler`](burn_dispatch::backends::wgpu::AutoCompiler) dispatch.
|
||||
#[cfg(feature = "vulkan")]
|
||||
pub fn vulkan(device_kind: DeviceKind) -> Self {
|
||||
Self::new(wgpu_device(device_kind))
|
||||
Self::new(DispatchDevice::Vulkan(wgpu_device(device_kind)))
|
||||
}
|
||||
|
||||
/// Metal-backed WGPU device, selected via [`DeviceKind`].
|
||||
///
|
||||
/// See [`Device::vulkan`] — same shape, different feature gate.
|
||||
/// Pins the wgpu shader compiler to MSL at compile time.
|
||||
#[cfg(feature = "metal")]
|
||||
pub fn metal(device_kind: DeviceKind) -> Self {
|
||||
Self::new(wgpu_device(device_kind))
|
||||
Self::new(DispatchDevice::Metal(wgpu_device(device_kind)))
|
||||
}
|
||||
|
||||
/// WebGPU-backed device, selected via [`DeviceKind`].
|
||||
///
|
||||
/// See [`Device::vulkan`] — same shape, different feature gate.
|
||||
/// Pins the wgpu shader compiler to WGSL at compile time.
|
||||
#[cfg(feature = "webgpu")]
|
||||
pub fn webgpu(device_kind: DeviceKind) -> Self {
|
||||
Self::new(wgpu_device(device_kind))
|
||||
Self::new(DispatchDevice::WebGpu(wgpu_device(device_kind)))
|
||||
}
|
||||
|
||||
/// Enables autodiff on this device.
|
||||
@@ -585,31 +579,34 @@ impl Device {
|
||||
#[allow(clippy::never_loop)] // at least one backend is expected to be enabled.
|
||||
for device_type in filter.into() {
|
||||
#[allow(unused)]
|
||||
let type_id = match device_type {
|
||||
let type_ids: &[DispatchDeviceId] = match device_type {
|
||||
#[cfg(feature = "cpu")]
|
||||
DeviceType::Cpu => DispatchDeviceId::Cpu,
|
||||
DeviceType::Cpu => &[DispatchDeviceId::Cpu],
|
||||
#[cfg(feature = "cuda")]
|
||||
DeviceType::Cuda => DispatchDeviceId::Cuda,
|
||||
DeviceType::Cuda => &[DispatchDeviceId::Cuda],
|
||||
#[cfg(feature = "rocm")]
|
||||
DeviceType::Rocm => DispatchDeviceId::Rocm,
|
||||
#[cfg(any(
|
||||
feature = "wgpu",
|
||||
feature = "metal",
|
||||
feature = "vulkan",
|
||||
feature = "webgpu"
|
||||
))]
|
||||
DeviceType::Wgpu => DispatchDeviceId::Wgpu,
|
||||
DeviceType::Rocm => &[DispatchDeviceId::Rocm],
|
||||
#[cfg(feature = "wgpu")]
|
||||
DeviceType::Wgpu => &[DispatchDeviceId::Wgpu],
|
||||
#[cfg(feature = "metal")]
|
||||
DeviceType::Metal => &[DispatchDeviceId::Metal],
|
||||
#[cfg(feature = "vulkan")]
|
||||
DeviceType::Vulkan => &[DispatchDeviceId::Vulkan],
|
||||
#[cfg(feature = "webgpu")]
|
||||
DeviceType::WebGpu => &[DispatchDeviceId::WebGpu],
|
||||
#[cfg(feature = "flex")]
|
||||
DeviceType::Flex => DispatchDeviceId::Flex,
|
||||
DeviceType::Flex => &[DispatchDeviceId::Flex],
|
||||
#[cfg(feature = "ndarray")]
|
||||
DeviceType::NdArray => DispatchDeviceId::NdArray,
|
||||
DeviceType::NdArray => &[DispatchDeviceId::NdArray],
|
||||
#[cfg(feature = "tch")]
|
||||
DeviceType::LibTorch => DispatchDeviceId::LibTorch,
|
||||
DeviceType::LibTorch => &[DispatchDeviceId::LibTorch],
|
||||
};
|
||||
|
||||
#[allow(unreachable_code)] // need to have one backend enabled, so it is reachable
|
||||
for device in Dispatch::enumerate(type_id) {
|
||||
devices.push(Device::new(device))
|
||||
for type_id in type_ids {
|
||||
for device in Dispatch::enumerate(*type_id) {
|
||||
devices.push(Device::new(device))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -621,12 +618,7 @@ impl Device {
|
||||
///
|
||||
/// Shared by [`Device::wgpu`], [`Device::vulkan`], [`Device::metal`], and
|
||||
/// [`Device::webgpu`], which differ only in which Cargo feature gates them.
|
||||
#[cfg(any(
|
||||
feature = "wgpu",
|
||||
feature = "vulkan",
|
||||
feature = "metal",
|
||||
feature = "webgpu"
|
||||
))]
|
||||
#[cfg(feature = "wgpu")]
|
||||
fn wgpu_device(device_kind: DeviceKind) -> burn_dispatch::devices::WgpuDevice {
|
||||
use burn_dispatch::devices::WgpuDevice;
|
||||
match device_kind {
|
||||
@@ -653,13 +645,14 @@ pub enum DeviceType {
|
||||
Cuda,
|
||||
#[cfg(feature = "rocm")]
|
||||
Rocm,
|
||||
#[cfg(any(
|
||||
feature = "wgpu",
|
||||
feature = "metal",
|
||||
feature = "vulkan",
|
||||
feature = "webgpu"
|
||||
))]
|
||||
#[cfg(feature = "wgpu")]
|
||||
Wgpu,
|
||||
#[cfg(feature = "metal")]
|
||||
Metal,
|
||||
#[cfg(feature = "vulkan")]
|
||||
Vulkan,
|
||||
#[cfg(feature = "webgpu")]
|
||||
WebGpu,
|
||||
#[cfg(feature = "flex")]
|
||||
Flex,
|
||||
#[cfg(feature = "ndarray")]
|
||||
|
||||
@@ -14,6 +14,13 @@ pub use burn_cubecl::{CubeBackend, tensor::CubeTensor};
|
||||
pub use cubecl::CubeDim;
|
||||
pub use cubecl::flex32;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use cubecl::wgpu::MslCompiler;
|
||||
#[cfg(feature = "vulkan")]
|
||||
use cubecl::wgpu::SpirvCompiler;
|
||||
#[cfg(feature = "webgpu")]
|
||||
use cubecl::wgpu::WgslCompiler;
|
||||
|
||||
pub use cubecl::wgpu::{
|
||||
AutoCompiler, MemoryConfiguration, RuntimeOptions, WgpuDevice, WgpuResource, WgpuRuntime,
|
||||
WgpuSetup, WgpuStorage, init_device, init_setup, init_setup_async,
|
||||
@@ -23,12 +30,12 @@ pub mod graphics {
|
||||
pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
|
||||
}
|
||||
|
||||
#[cfg(feature = "webgpu")]
|
||||
pub use cubecl::wgpu::WgslCompiler;
|
||||
#[cfg(feature = "vulkan")]
|
||||
pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
|
||||
|
||||
#[cfg(feature = "fusion")]
|
||||
type WgpuInner<C> = burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime<C>>>;
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
type WgpuInner<C> = CubeBackend<cubecl::wgpu::WgpuRuntime<C>>;
|
||||
|
||||
/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
|
||||
///
|
||||
/// This backend can target multiple graphics APIs, including:
|
||||
@@ -38,39 +45,11 @@ pub use cubecl::wgpu::vulkan::VkSpirvCompiler;
|
||||
/// - [Metal][crate::graphics::Metal] on Apple hardware.
|
||||
/// - [WebGPU](crate::graphics::WebGpu) on supported browsers and `wasm` runtimes.
|
||||
///
|
||||
/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
|
||||
/// you have to manually initialize the runtime. For example:
|
||||
///
|
||||
/// ```rust, ignore
|
||||
/// fn custom_init() {
|
||||
/// let device = Default::default();
|
||||
/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
|
||||
/// &device,
|
||||
/// Default::default(),
|
||||
/// );
|
||||
/// }
|
||||
/// ```
|
||||
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
|
||||
/// It's also possible to use an existing wgpu device, by using `init_device`.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This version of the wgpu backend uses [burn_fusion] to compile and optimize streams of tensor
|
||||
/// operations for improved performance.
|
||||
///
|
||||
/// You can disable the `fusion` feature flag to remove that functionality, which might be
|
||||
/// necessary on `wasm` for now.
|
||||
pub type Wgpu = burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime>>;
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
|
||||
///
|
||||
/// This backend can target multiple graphics APIs, including:
|
||||
/// - [Vulkan] on Linux, Windows, and Android.
|
||||
/// - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.
|
||||
/// - [DirectX 12](crate::Dx12) on Windows.
|
||||
/// - [Metal] on Apple hardware.
|
||||
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
|
||||
/// The selected graphics API is chosen automatically at runtime, and the appropriate shader
|
||||
/// compiler (WGSL, SPIR-V or MSL) is dispatched via [`AutoCompiler`]. When the target API is
|
||||
/// known ahead of time, prefer the dedicated `Vulkan`, `WebGpu` or `Metal` backend aliases
|
||||
/// (enabled by their respective Cargo features), which lock the compiler at compile time and
|
||||
/// avoid the runtime dispatch.
|
||||
///
|
||||
/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
|
||||
/// you have to manually initialize the runtime. For example:
|
||||
@@ -89,24 +68,34 @@ pub type Wgpu = burn_fusion::Fusion<CubeBackend<cubecl::wgpu::WgpuRuntime>>;
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This version of the wgpu backend doesn't use [burn_fusion] to compile and optimize streams of tensor
|
||||
/// operations.
|
||||
///
|
||||
/// You can enable the `fusion` feature flag to add that functionality, which might improve
|
||||
/// performance.
|
||||
pub type Wgpu = CubeBackend<cubecl::wgpu::WgpuRuntime>;
|
||||
/// When the `fusion` feature flag is enabled (the default), this backend uses [burn_fusion] to
|
||||
/// compile and optimize streams of tensor operations for improved performance. You can disable
|
||||
/// the `fusion` feature flag to remove that functionality, which might be necessary on `wasm`
|
||||
/// for now.
|
||||
pub type Wgpu = WgpuInner<AutoCompiler>;
|
||||
|
||||
#[cfg(feature = "vulkan")]
|
||||
/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V.
|
||||
pub type Vulkan = Wgpu;
|
||||
///
|
||||
/// This is a specialization of [`Wgpu`] that pins the shader compiler to SPIR-V at compile time,
|
||||
/// removing the runtime [`AutoCompiler`] dispatch. Enable the `vulkan` feature to use it.
|
||||
/// Multiple wgpu backend aliases (`Vulkan`, `WebGpu`, `Metal`) can be enabled simultaneously
|
||||
/// since each is a distinct type parameterized by its own compiler.
|
||||
#[cfg(feature = "vulkan")]
|
||||
pub type Vulkan = WgpuInner<SpirvCompiler>;
|
||||
|
||||
#[cfg(feature = "webgpu")]
|
||||
/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.
|
||||
pub type WebGpu = Wgpu;
|
||||
///
|
||||
/// This is a specialization of [`Wgpu`] that pins the shader compiler to WGSL at compile time,
|
||||
/// removing the runtime [`AutoCompiler`] dispatch. Enable the `webgpu` feature to use it.
|
||||
#[cfg(feature = "webgpu")]
|
||||
pub type WebGpu = WgpuInner<WgslCompiler>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
/// Tensor backend that leverages the Metal graphics API to execute GPU compute shaders compiled to MSL.
|
||||
pub type Metal = Wgpu;
|
||||
///
|
||||
/// This is a specialization of [`Wgpu`] that pins the shader compiler to MSL at compile time,
|
||||
/// removing the runtime [`AutoCompiler`] dispatch. Enable the `metal` feature to use it.
|
||||
#[cfg(feature = "metal")]
|
||||
pub type Metal = WgpuInner<MslCompiler>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
Reference in New Issue
Block a user