Merge remote-tracking branch 'origin/main' into worktree-respectingdatatypes_removingonnx

# Conflicts:
#	crates/luminal_python/rust/src/ops_parse/convolution.rs
#	crates/luminal_python/tests/test_hlir_ops.py
This commit is contained in:
Tucker Morgan
2026-04-08 20:32:25 +00:00
10 changed files with 727 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ name: Modal Examples
on:
push:
branches: ["main"]
pull_request:
pull_request_target:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
@@ -13,7 +13,7 @@ jobs:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
|| (github.event_name == 'pull_request_target'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: "${{ matrix.example }} (Modal ${{ matrix.gpu.type }})"
runs-on: ubuntu-latest
@@ -30,6 +30,8 @@ jobs:
steps:
- uses: actions/checkout@v6
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- name: Set up Python
uses: actions/setup-python@v5
with:

View File

@@ -3,7 +3,7 @@ name: Test CUDA
on:
push:
branches: ["main"]
pull_request:
pull_request_target:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
@@ -13,7 +13,7 @@ jobs:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
|| (github.event_name == 'pull_request_target'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Cuda Unit Tests
runs-on: ubuntu-latest
@@ -22,6 +22,8 @@ jobs:
steps:
- uses: actions/checkout@v6
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- name: Set up Python
uses: actions/setup-python@v5
with:

View File

@@ -3,7 +3,7 @@ name: Test Python CUDA
on:
push:
branches: ["main"]
pull_request:
pull_request_target:
branches: ["main"]
types: [labeled, synchronize]
workflow_dispatch:
@@ -13,7 +13,7 @@ jobs:
if: >-
github.event_name == 'push'
|| github.event_name == 'workflow_dispatch'
|| (github.event_name == 'pull_request'
|| (github.event_name == 'pull_request_target'
&& contains(github.event.pull_request.labels.*.name, 'modal-ready'))
name: Python CUDA Tests
runs-on: ubuntu-latest
@@ -25,6 +25,8 @@ jobs:
steps:
- uses: actions/checkout@v6
with:
ref: ${{ github.event.pull_request.head.sha || github.sha }}
- name: Set up Python
uses: actions/setup-python@v5
with:

View File

@@ -1,7 +1,6 @@
import modal
import subprocess
import os
import sys
gpu_type = os.environ.get("GPU_TYPE", "T4")
CUDARC_CUDA_VERSION = "12080"
@@ -46,8 +45,10 @@ def run_cargo_test():
subprocess.run(
[
"cargo", "test",
"-p", "luminal_cuda_lite",
"cargo",
"test",
"-p",
"luminal_cuda_lite",
"--verbose",
"--",
"--test-threads=1",

View File

@@ -0,0 +1,231 @@
use anyhow::Result;
use luminal::prelude::*;
use crate::ops_parse::convolution::{conv_unfold, depthwise_conv};
use crate::pt2_schema::*;
use super::Translator;
const CONV_INPUT_ARG: usize = 0;
const CONV_WEIGHT_ARG: usize = 1;
const CONV_BIAS_ARG: usize = 2;
const CONV_STRIDE_ARG: usize = 3;
const CONV_PADDING_ARG: usize = 4;
const CONV_DILATION_ARG: usize = 5;
const CONV_GROUPS_ARG: usize = 6;
const CONVOLUTION_TRANSPOSED_ARG: usize = 6;
const CONVOLUTION_OUTPUT_PADDING_ARG: usize = 7;
const CONVOLUTION_GROUPS_ARG: usize = 8;
impl<'a> Translator<'a> {
/// Translate aten.conv{1,2,3}d.default and aten.convolution.default.
///
/// The PT2 export may omit defaulted trailing arguments entirely. In practice this means
/// conv{N}d.default can show up as just `(input, weight)` for the no-bias, stride=1,
/// padding=0, dilation=1, groups=1 case.
pub(crate) fn translate_conv(&mut self, node: &Node) -> Result<GraphTensor> {
let input = self.get_input_tensor(node, CONV_INPUT_ARG)?;
let weight = self.get_input_tensor(node, CONV_WEIGHT_ARG)?;
let bias = self.get_input_tensor(node, CONV_BIAS_ARG).ok();
let x_dims = input.dims();
let w_dims = weight.dims();
let rank = x_dims.len();
let spatial = rank - 2;
let stride = self
.get_ints_arg(node, CONV_STRIDE_ARG)
.unwrap_or_else(|_| vec![1; spatial]);
let padding = self
.get_ints_arg(node, CONV_PADDING_ARG)
.unwrap_or_else(|_| vec![0; spatial]);
let mut dilation = self
.get_ints_arg(node, CONV_DILATION_ARG)
.unwrap_or_else(|_| vec![1; spatial]);
let groups = if node.target == "torch.ops.aten.convolution.default" {
let transposed = self
.get_bool_arg(node, CONVOLUTION_TRANSPOSED_ARG)
.unwrap_or(false);
anyhow::ensure!(
!transposed,
"conv: ConvTranspose / transposed=true is not supported yet"
);
let output_padding = self
.get_ints_arg(node, CONVOLUTION_OUTPUT_PADDING_ARG)
.unwrap_or_else(|_| vec![0; spatial]);
anyhow::ensure!(
output_padding.iter().all(|&v| v == 0),
"conv: output_padding is not supported for non-transposed convolution"
);
self.get_int_arg(node, CONVOLUTION_GROUPS_ARG).unwrap_or(1) as usize
} else {
self.get_int_arg(node, CONV_GROUPS_ARG).unwrap_or(1) as usize
};
if dilation.len() != spatial {
dilation = vec![1; spatial];
}
let ch_out = w_dims[0]
.to_usize()
.ok_or_else(|| anyhow::anyhow!("conv: weight C_out must be concrete"))?;
let ch_in = x_dims[1]
.to_usize()
.ok_or_else(|| anyhow::anyhow!("conv: input C_in must be concrete"))?;
anyhow::ensure!(
stride.len() == spatial && padding.len() == spatial && dilation.len() == spatial,
"conv: stride/padding/dilation rank must match spatial rank {spatial}"
);
anyhow::ensure!(
groups > 0 && ch_in % groups == 0 && ch_out % groups == 0,
"conv: invalid group configuration (C_in={ch_in}, C_out={ch_out}, groups={groups})"
);
let ch_per_group = ch_in / groups;
let kernel_shape: Vec<usize> = w_dims[2..]
.iter()
.map(|d| {
d.to_usize()
.ok_or_else(|| anyhow::anyhow!("conv: kernel dims must be concrete"))
})
.collect::<Result<_>>()?;
let kernel_product: usize = kernel_shape.iter().product();
// ATen uses symmetric padding (same begin/end)
let stride_u: Vec<usize> = stride.iter().map(|&v| v as usize).collect();
let padding_u: Vec<usize> = padding.iter().map(|&v| v as usize).collect();
let dilation_u: Vec<usize> = dilation.iter().map(|&v| v as usize).collect();
let mut out = if groups > 1 {
let group_out = ch_out / groups;
if ch_per_group == 1 {
// Depthwise (including channel multiplier > 1): avoid per-channel slicing.
depthwise_conv(
input,
weight,
&kernel_shape,
&stride_u,
&dilation_u,
&padding_u,
&padding_u,
ch_in,
group_out,
kernel_product,
spatial,
)
} else {
// General grouped: pre-pad full input then slice per group
let padded_input = {
let mut pad_spec: Vec<(Expression, Expression)> =
vec![(0.into(), 0.into()); 2 + spatial];
for i in 0..spatial {
pad_spec[2 + i] = (padding_u[i].into(), padding_u[i].into());
}
input.pad(pad_spec, 0.0)
};
let no_pad = vec![0usize; spatial];
let mut group_outputs = Vec::with_capacity(groups);
for g in 0..groups {
let x_g = slice_channel_group(padded_input, g, ch_per_group, spatial);
let w_g =
slice_weight_group(weight, g, group_out, ch_per_group * kernel_product);
group_outputs.push(conv_unfold(
x_g,
w_g,
&kernel_shape,
&stride_u,
&dilation_u,
&no_pad,
&no_pad,
ch_per_group,
group_out,
spatial,
));
}
let mut result = group_outputs[0];
for g_out in &group_outputs[1..] {
result = result.concat_along(*g_out, 1);
}
result
}
} else {
let mut w_flat = weight;
w_flat.shape = ShapeTracker::new_with_element_bits(
vec![ch_out, ch_in * kernel_product],
weight.dtype.bits(),
);
conv_unfold(
input,
w_flat,
&kernel_shape,
&stride_u,
&dilation_u,
&padding_u,
&padding_u,
ch_in,
ch_out,
spatial,
)
};
if let Some(b) = bias {
let out_dims = out.dims();
let mut b_expanded = b.expand_dim(0, 1);
for i in 0..spatial {
b_expanded = b_expanded.expand_dim(2 + i, out_dims[2 + i]);
}
out += b_expanded;
}
Ok(out)
}
}
/// Slice input channels for one group.
/// Caller must pre-pad `x` so no additional padding is applied to the slice.
fn slice_channel_group(
x: GraphTensor,
g: usize,
ch_per_group: usize,
spatial: usize,
) -> GraphTensor {
let start = g * ch_per_group;
let end = start + ch_per_group;
let dims = x.dims();
let rank = 2 + spatial;
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(rank);
slices.push((0.into(), dims[0]));
slices.push((start.into(), end.into()));
for dim in dims.iter().take(rank).skip(2) {
slices.push((0.into(), *dim));
}
x.slice(slices)
}
/// Slice and flatten weight for one group.
fn slice_weight_group(
w: GraphTensor,
g: usize,
group_out: usize,
flat_inner: usize,
) -> GraphTensor {
let start = g * group_out;
let end = start + group_out;
let w_dims = w.dims();
let mut slices: Vec<(Expression, Expression)> = Vec::with_capacity(w_dims.len());
slices.push((start.into(), end.into()));
for dim in w_dims.iter().skip(1) {
slices.push((0.into(), *dim));
}
// Materialize through Add: binary op outputs are contiguous in Luminal, which makes the
// following flatten safe for the sliced weight buffer.
let w_sliced = w.slice(slices) + 0.0;
let mut w_flat = w_sliced;
w_flat.shape =
ShapeTracker::new_with_element_bits(vec![group_out, flat_inner], w_sliced.dtype.bits());
w_flat
}

View File

@@ -135,6 +135,12 @@ impl<'a> Translator<'a> {
// Linear
"torch.ops.aten.linear.default" => self.translate_linear(node)?,
// Convolution
"torch.ops.aten.conv1d.default"
| "torch.ops.aten.conv2d.default"
| "torch.ops.aten.conv3d.default"
| "torch.ops.aten.convolution.default" => self.translate_conv(node)?,
// Reduction ops
"torch.ops.aten.sum.dim_IntList" => self.translate_reduction(node, ReductionOp::Sum)?,
"torch.ops.aten.mean.dim" => self.translate_reduction(node, ReductionOp::Mean)?,
@@ -410,10 +416,11 @@ impl<'a> Translator<'a> {
return Ok(());
}
// Split
// Split / Chunk
"torch.ops.aten.split.Tensor" | "torch.ops.aten.split_with_sizes.default" => {
self.translate_split(node)?
}
"torch.ops.aten.chunk.default" => self.translate_chunk(node)?,
// One-hot
"torch.ops.aten.one_hot.default" => self.translate_one_hot(node)?,

View File

@@ -3,6 +3,7 @@
//! Walks the parsed PT2 graph and constructs an equivalent Luminal computation graph.
mod binary;
mod conv;
mod dispatch;
mod matmul;
mod movement;

View File

@@ -6,6 +6,10 @@ use crate::pt2_util::*;
use super::Translator;
const CHUNK_INPUT_ARG: usize = 0;
const CHUNK_NUM_CHUNKS_ARG: usize = 1;
const CHUNK_DIM_ARG: usize = 2;
impl<'a> Translator<'a> {
pub(crate) fn translate_reshape(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
@@ -430,6 +434,41 @@ impl<'a> Translator<'a> {
}
}
/// chunk(tensor, n_chunks, dim) -> splits tensor into n_chunks equal parts
pub(crate) fn translate_chunk(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, CHUNK_INPUT_ARG)?;
let n_chunks = self.get_int_arg(node, CHUNK_NUM_CHUNKS_ARG)? as usize;
let dim = if node.inputs.len() > CHUNK_DIM_ARG {
self.get_int_arg(node, CHUNK_DIM_ARG).unwrap_or(0)
} else {
0
};
let dim = normalize_dim(dim, a.shape.len());
let total = a.shape.dims[dim]
.to_usize()
.ok_or_else(|| anyhow::anyhow!("chunk requires concrete dim size"))?;
let chunk_size = total.div_ceil(n_chunks);
let output_names: Vec<String> = node
.outputs
.first()
.and_then(|o| o.as_tensors.as_ref())
.map(|ts| ts.iter().map(|t| t.name.clone()).collect())
.unwrap_or_default();
for (i, out_name) in output_names.iter().enumerate() {
let start = i * chunk_size;
let end = ((i + 1) * chunk_size).min(total);
if start < total {
let chunk = a.slice_along(start..end, dim);
self.tensors.insert(out_name.clone(), chunk);
}
}
Ok(a.slice_along(0..chunk_size.min(total), dim))
}
pub(crate) fn translate_split(&mut self, node: &Node) -> Result<GraphTensor> {
let a = self.get_input_tensor(node, 0)?;
let split_size = self.get_int_arg(node, 1)? as usize;

View File

@@ -215,11 +215,39 @@ from test_models import (
WhereWithConstantModel,
# Xor model
XorTestModel,
# Conv models
Conv1dNoPadModel,
Conv1dSamePadModel,
Conv1dBiasModel,
Conv2dNoPadModel,
Conv2dSamePadModel,
Conv2dBiasModel,
Conv2dStrideModel,
Conv2dDilationModel,
Conv3dSamePadModel,
DepthwiseConv1dModel,
DepthwiseConv2dModel,
DepthwiseMultiplierConv2dModel,
GroupedConv2dModel,
GroupedConv2dGroups3Model,
MambaConvBlockModel,
)
from luminal import luminal_backend
def _compile_for_export_mode(
model: torch.nn.Module, export_mode: str | None = None
) -> Callable:
if export_mode is None:
return torch.compile(model, backend=luminal_backend)
return torch.compile(
model,
backend=luminal_backend,
options={"export_mode": export_mode},
)
def test_add(device: torch.device):
add_test_model: torch.nn.Module = AddTestModel().to(device)
add_test_mode_compiled: Callable = torch.compile(
@@ -2015,3 +2043,208 @@ def test_dtype_float32(device: torch.device):
output: torch.Tensor = model_compiled(x)
assert output.dtype == torch.float32, f"Expected float32 output, got {output.dtype}"
assert torch.allclose(output, original)
# ========== Convolution Tests ==========
def _run_conv1d_no_pad(device: torch.device, export_mode: str | None = None):
"""Conv1d without padding: output length = input - (kernel-1)."""
model: torch.nn.Module = Conv1dNoPadModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_conv1d_no_pad(device: torch.device):
_run_conv1d_no_pad(device)
def test_conv1d_no_pad_pt2(device: torch.device):
_run_conv1d_no_pad(device, "pt2")
def test_conv1d_same_pad(device: torch.device):
"""Conv1d with padding=1: output length == input length."""
model: torch.nn.Module = Conv1dSamePadModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_conv1d_bias(device: torch.device):
"""Conv1d with bias term."""
model: torch.nn.Module = Conv1dBiasModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 8, 32, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def _run_conv2d_no_pad(device: torch.device, export_mode: str | None = None):
"""Conv2d without padding: output spatial = input - (kernel-1)."""
model: torch.nn.Module = Conv2dNoPadModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_conv2d_no_pad(device: torch.device):
_run_conv2d_no_pad(device)
def test_conv2d_no_pad_pt2(device: torch.device):
_run_conv2d_no_pad(device, "pt2")
def test_conv2d_same_pad(device: torch.device):
"""Conv2d with padding=1: output spatial == input spatial."""
model: torch.nn.Module = Conv2dSamePadModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_conv2d_bias(device: torch.device):
"""Conv2d with bias term."""
model: torch.nn.Module = Conv2dBiasModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_conv2d_stride(device: torch.device):
"""Conv2d with stride=2: output spatial dims halved."""
model: torch.nn.Module = Conv2dStrideModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(1, 3, 8, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def _run_conv2d_dilation(device: torch.device, export_mode: str | None = None):
"""Conv2d with dilation=2 preserves the expected spatial shape and values."""
model: torch.nn.Module = Conv2dDilationModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
x: torch.Tensor = torch.randn(2, 8, 17, 19, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_conv2d_dilation(device: torch.device):
_run_conv2d_dilation(device)
def test_conv2d_dilation_pt2(device: torch.device):
_run_conv2d_dilation(device, "pt2")
def _run_conv3d_same_pad(device: torch.device, export_mode: str | None = None):
"""Conv3d exercises the spatial=3 unfold/permute/split path."""
model: torch.nn.Module = Conv3dSamePadModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
x: torch.Tensor = torch.randn(2, 4, 6, 7, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_conv3d_same_pad(device: torch.device):
_run_conv3d_same_pad(device)
def test_conv3d_same_pad_pt2(device: torch.device):
_run_conv3d_same_pad(device, "pt2")
def test_depthwise_conv1d(device: torch.device):
"""Depthwise Conv1d with groups=in_channels, as used in Mamba."""
model: torch.nn.Module = DepthwiseConv1dModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 16, 32, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_depthwise_conv2d(device: torch.device):
"""Depthwise Conv2d with groups=in_channels."""
model: torch.nn.Module = DepthwiseConv2dModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(1, 8, 8, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def _run_depthwise_multiplier_conv2d(
device: torch.device, export_mode: str | None = None
):
"""Depthwise Conv2d with multiplier > 1 should preserve both output channels per input channel."""
model: torch.nn.Module = DepthwiseMultiplierConv2dModel().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
x: torch.Tensor = torch.randn(2, 8, 9, 9, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_depthwise_multiplier_conv2d(device: torch.device):
_run_depthwise_multiplier_conv2d(device)
def test_depthwise_multiplier_conv2d_pt2(device: torch.device):
_run_depthwise_multiplier_conv2d(device, "pt2")
def test_grouped_conv2d(device: torch.device):
"""Conv2d with groups=4 (grouped, not depthwise)."""
model: torch.nn.Module = GroupedConv2dModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(1, 16, 8, 8, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def _run_grouped_conv2d_groups3_batch4(
device: torch.device, export_mode: str | None = None
):
"""Grouped Conv2d with groups=3 and batch>1 exercises the pre-pad + slice path."""
model: torch.nn.Module = GroupedConv2dGroups3Model().to(device)
model_compiled: Callable = _compile_for_export_mode(model, export_mode)
x: torch.Tensor = torch.randn(4, 12, 11, 9, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)
def test_grouped_conv2d_groups3_batch4(device: torch.device):
_run_grouped_conv2d_groups3_batch4(device)
def test_grouped_conv2d_groups3_batch4_pt2(device: torch.device):
_run_grouped_conv2d_groups3_batch4(device, "pt2")
def test_mamba_conv_block(device: torch.device):
"""Minimal Mamba-style block: depthwise Conv1d with causal gating (end-to-end)."""
model: torch.nn.Module = MambaConvBlockModel().to(device)
model_compiled: Callable = torch.compile(model, backend=luminal_backend)
x: torch.Tensor = torch.randn(2, 64, 16, device=device)
original: torch.Tensor = model(x)
output: torch.Tensor = model_compiled(x)
assert torch.allclose(output, original, atol=1e-4)

View File

@@ -1830,3 +1830,202 @@ class LlamaTransformerBlockModel(torch.nn.Module):
h = x + self.attn(self.input_norm(x))
out = h + self.mlp(self.post_attn_norm(h))
return out
# ---------------------------------------------------------------------------
# Convolution models
# ---------------------------------------------------------------------------
class Conv1dNoPadModel(torch.nn.Module):
"""Conv1d with no padding: output length shrinks by (kernel-1)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv1dSamePadModel(torch.nn.Module):
"""Conv1d with same-size padding (output length == input length)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv1dBiasModel(torch.nn.Module):
"""Conv1d with bias."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv2dNoPadModel(torch.nn.Module):
"""Conv2d with no padding: output spatial dims shrink by (kernel-1)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=0, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv2dSamePadModel(torch.nn.Module):
"""Conv2d with same-size padding."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv2dBiasModel(torch.nn.Module):
"""Conv2d with bias."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv2dStrideModel(torch.nn.Module):
"""Conv2d with stride=2 (output dims halved)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
3, 16, kernel_size=3, stride=2, padding=1, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv2dDilationModel(torch.nn.Module):
"""Conv2d with dilation=2 and padding chosen to preserve spatial size."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
8, 16, kernel_size=3, dilation=2, padding=2, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class Conv3dSamePadModel(torch.nn.Module):
"""Conv3d with padding=1 to preserve spatial dimensions."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv3d(4, 8, kernel_size=3, padding=1, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class DepthwiseConv1dModel(torch.nn.Module):
"""Depthwise Conv1d as used in Mamba (groups == in_channels)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv1d(
16, 16, kernel_size=4, groups=16, padding=3, bias=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Causal truncation: keep only the first L positions
return self.conv(x)[:, :, : x.shape[2]]
class DepthwiseConv2dModel(torch.nn.Module):
"""Depthwise Conv2d (groups == in_channels)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
8, 8, kernel_size=3, groups=8, padding=1, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class DepthwiseMultiplierConv2dModel(torch.nn.Module):
"""Depthwise Conv2d with channel multiplier 2 (out_channels = 2 * in_channels)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
8, 16, kernel_size=3, groups=8, padding=1, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class GroupedConv2dModel(torch.nn.Module):
"""Conv2d with groups=4 (not depthwise, but grouped)."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
16, 32, kernel_size=3, groups=4, padding=1, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class GroupedConv2dGroups3Model(torch.nn.Module):
"""Conv2d with groups=3 and ch_per_group=4."""
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
12, 12, kernel_size=3, groups=3, padding=1, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class MambaConvBlockModel(torch.nn.Module):
"""Minimal Mamba-style SSM block: Linear -> split -> depthwise Conv1d -> SiLU gate -> Linear.
This is the core conv pattern used in Mamba / Mamba-2 models.
"""
def __init__(self, d_model: int = 16, d_conv: int = 4, expand: int = 2) -> None:
super().__init__()
d_inner = d_model * expand
self.in_proj = torch.nn.Linear(d_model, d_inner * 2, bias=False)
self.conv1d = torch.nn.Conv1d(
d_inner, d_inner, d_conv, groups=d_inner, padding=d_conv - 1, bias=True
)
self.out_proj = torch.nn.Linear(d_inner, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, seq_len, _ = x.shape
xz = self.in_proj(x)
x_part, z = xz.chunk(2, dim=-1)
x_part = self.conv1d(x_part.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
return self.out_proj(
torch.nn.functional.silu(x_part) * torch.nn.functional.silu(z)
)