Register cuda_lite only under "cuda_lite", not "cuda" or "gpu"

This avoids confusion with the cuda_heavy backend. Auto-detection now returns
"cuda_lite" for CUDA tensors, and the test scripts have been updated to match.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Tucker Morgan
2026-04-15 18:23:33 +00:00
parent e6d13a3979
commit 52b2a45c62
4 changed files with 5 additions and 7 deletions

View File

@@ -15,7 +15,7 @@ pub struct CudaLiteDynBackend {
}
impl DynBackend for CudaLiteDynBackend {
fn name(&self) -> &str { "cuda" }
fn name(&self) -> &str { "cuda_lite" }
fn device_type(&self) -> &str { "cuda" }
fn set_data_bytes(&mut self, node: NodeIndex, bytes: Vec<u8>, _dtype: DType) {
@@ -58,9 +58,7 @@ fn cuda_lite_factory(graph: &mut Graph, args: BackendCompileArgs) -> Result<Box<
)
}
/// Register under `"cuda_lite"`, `"cuda"`, and `"gpu"`.
/// Register under `"cuda_lite"`.
pub fn register() {
register_backend("cuda_lite", cuda_lite_factory);
register_backend("cuda", cuda_lite_factory);
register_backend("gpu", cuda_lite_factory);
}

View File

@@ -28,7 +28,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
echo ""
echo "--- 2a: CUDA ---"
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest $CUDA_TESTS -m "not slow" -v
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda_lite uv run pytest $CUDA_TESTS -m "not slow" -v
echo ""
echo "=========================================="

View File

@@ -14,7 +14,7 @@ uv run maturin develop --manifest-path rust/Cargo.toml --features cuda -r
# Run pytest with CUDA backend
echo "Step 3: Running pytest with CUDA backend..."
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
RUST_BACKTRACE=1 LUMINAL_BACKEND=cuda_lite uv run pytest tests/test_llama3.py tests/test_hlir_ops.py tests/test_unary.py -v
echo ""
echo "=== Tests Complete ==="

View File

@@ -16,7 +16,7 @@ def _detect_backend(example_inputs):
if env_backend:
return env_backend
device = example_inputs[0].device if example_inputs else torch.device("cpu")
return "cuda" if device.type == "cuda" else "native"
return "cuda_lite" if device.type == "cuda" else "native"
def _collect_weight_pointers(weights, backend):