perf(flex): bulk-copy inner-contiguous run in to_contiguous (#5019)

This commit is contained in:
John Kaczman
2026-05-29 08:24:00 -04:00
committed by GitHub
parent 1b6dd60361
commit cc9664db43

View File

@@ -377,6 +377,11 @@ impl FlexTensor {
copy_2d_tiled(
&mut dst, src, offset, shape[0], shape[1], strides[0], strides[1],
);
} else if all_positive && strides.last().copied() == Some(1) {
// Since the inner stride is 1, we can bulk-copy
// the contiguous inner run for each outer index.
unsafe { dst.set_len(n) };
copy_inner_contiguous_run(&mut dst, src, offset, shape, strides);
} else {
// General fallback: covers negative strides (flipped
// tensors) and ND layouts that can't collapse to ≤2D.
@@ -546,6 +551,90 @@ fn collapse_for_copy(shape: &[usize], strides: &[isize]) -> CollapsedLayout {
out
}
/// Copy a strided source into a contiguous destination by treating the
/// layout as a set of strided outer dims wrapping a contiguous inner run.
///
/// Precondition: strides are all positive and the innermost stride is 1
#[inline]
fn copy_inner_contiguous_run<E: Copy>(
dst: &mut [E],
src: &[E],
offset: isize,
shape: &[usize],
strides: &[isize],
) {
debug_assert!(!shape.is_empty());
debug_assert_eq!(*strides.last().unwrap(), 1, "innermost stride must be 1");
// We want the maximal trailing run that is contiguous in src.
// Walk inward-out, a dim dim_idx extends the contiguous run if AND ONLY IF ("iff")
// strides[dim_idx] == product of inner shapes seen so far.
let mut expected_stride: isize = 1;
let mut split = shape.len(); // first OUTER index (dims [split..] are inner)
while split > 0 {
let dim_idx = split - 1;
if strides[dim_idx] == expected_stride {
expected_stride = expected_stride
.checked_mul(shape[dim_idx] as isize)
.expect("contiguous run length overflows isize");
split -= 1;
} else {
break;
}
}
let outer_shape = &shape[..split];
let outer_strides = &strides[..split];
let inner_run_length: usize = shape[split..].iter().product();
// Whole tensor is one contiguous run! We can bulk-copy the whole thing.
if outer_shape.is_empty() {
let src_index = offset as usize;
dst[..inner_run_length].copy_from_slice(&src[src_index..src_index + inner_run_length]);
return;
}
if inner_run_length == 0 || outer_shape.contains(&0) {
return;
}
// Walk the outer index space in row-major order via an odometer:
// each iteration ticks the rightmost counter, and so on overflow we
// reset it and carry +1 into the dim to its left.
// Row-major is what we want because dst is contiguous (consecutive outer-index
// tuples land in consecutive inner_run_length slots, so we just bump
// dst_position each step instead of recomputing it)
// We use this instead of N nested for loops because the rank isn't
// known at compile time.
let mut counter = [0usize; COLLAPSE_MAX_RANK];
let outer_dim_count = outer_shape.len();
let mut dst_position = 0usize;
loop {
let mut src_start = offset;
for dim_idx in 0..outer_dim_count {
src_start += counter[dim_idx] as isize * outer_strides[dim_idx];
}
let src_index = src_start as usize;
dst[dst_position..dst_position + inner_run_length]
.copy_from_slice(&src[src_index..src_index + inner_run_length]);
dst_position += inner_run_length;
let mut dim_idx = outer_dim_count;
loop {
if dim_idx == 0 {
return;
}
dim_idx -= 1;
counter[dim_idx] += 1;
if counter[dim_idx] < outer_shape[dim_idx] {
break;
}
counter[dim_idx] = 0;
}
}
}
/// Tiled 2D copy from a strided source into a contiguous destination.
/// The loop nesting is chosen so the innermost read walks whichever
/// source stride is smaller, which keeps the hot loop in cache even
@@ -763,6 +852,43 @@ mod tests {
assert_eq!(values, expected.as_slice());
}
/// Regression: the ShuffleNet channel shuffle (reshape -> permute([0,2,1,3,4]) -> reshape)
/// collapses to a 3D layout with stride-1 innermost.
/// Before this fix that went into the StridedIter scalar walk in the final else in copy_contiguous
/// Now, we use a new inner-contiguous branch that bulk-copies the trailing run
#[test]
fn test_to_contiguous_3d_inner_stride1_matches_naive() {
let dims = [1, 2, 4, 3, 3]; // [N, G, C, H, W]
let n: usize = dims.iter().product();
let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
let t = FlexTensor::from_data(TensorData::new(data.clone(), dims.to_vec()));
let permuted = t.permute(&[0, 2, 1, 3, 4]); // swap G and C from the original dims in the permute
assert!(!permuted.is_contiguous());
let contiguous_data = permuted.to_contiguous();
assert!(contiguous_data.is_contiguous());
assert_eq!(contiguous_data.shape().to_vec(), vec![1, 4, 2, 3, 3]);
// Now we manually do a strided walk to be able to check that against a naive solution.
// output[0, c, g, h, w] = input[0, g, c, h, w].
let mut expected = Vec::with_capacity(n);
for c in 0..4 {
for g in 0..2 {
for h in 0..3 {
for w in 0..3 {
let idx = g * 36 + c * 9 + h * 3 + w;
expected.push(data[idx]);
}
}
}
}
// Verify the naive solution matches (the manual strided walk).
let result_data = contiguous_data.into_data();
let values = result_data.as_slice::<f32>().unwrap();
assert_eq!(values, expected.as_slice());
}
/// Exercise the `row_stride > col_stride` branch of the 2D tiled
/// copy (the ConvNeXt case hits the other branch).
#[test]
@@ -832,7 +958,7 @@ mod tests {
// Both point to same data
assert!(core::ptr::eq(
tensor.bytes().as_ptr(),
cloned.bytes().as_ptr()
cloned.bytes().as_ptr(),
));
}