Compare commits

...

1 Commits

Author SHA1 Message Date
Joe Fioti
75e4e6be0a Simplify example mains and trim CUDA profiling output (#339)
* Simplify example mains and trim CUDA profiling output

* Simplify model examples and adjust CUDA profiling output

* Simplify example model setup and CUDA profiling output
2026-05-29 23:37:13 -04:00
45 changed files with 1343 additions and 1385 deletions

View File

@@ -21,8 +21,7 @@ let b = cx.tensor((1, 4));
let c = a.matmul(b).output();
// Compile
cx.build_search_space::<NativeRuntime>();
let mut rt = cx.search(NativeRuntime::default(), 1);
let mut rt = cx.compile(NativeRuntime::default(), CompileOptions::default());
// Set input tensors
rt.set_data(a, vec![1.0, 2.0, 3.0]);

View File

@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
}
}
let mut rt = cx.search(rt, CompileOptions::new(5));
let mut rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
let mut bench_metrics = None;

View File

@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
rt.set_data(*node, &data);
}
let rt = cx.search(rt, CompileOptions::new(5));
let rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
Some(PreparedBench {
rt,

View File

@@ -499,7 +499,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result1 = rt.get_f32(c);
rt.execute(&cx.dyn_map);
@@ -531,7 +531,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
let mut results = Vec::new();
for _ in 0..5 {
rt.execute(&cx.dyn_map);
@@ -569,7 +569,7 @@ mod tests {
rt.set_data(b, data_b.clone());
cx.set_dim('s', size);
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a
.iter()
@@ -611,7 +611,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
let eps = dtype_epsilon(luminal::dtype::DType::F32);
@@ -642,7 +642,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
for _ in 0..10 {
rt.execute(&cx.dyn_map);
}
@@ -675,7 +675,7 @@ mod tests {
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Initial execution
rt.execute(&cx.dyn_map);

View File

@@ -1556,37 +1556,9 @@ impl Runtime for CudaRuntime {
self.profiling = false;
let duration = durations.iter().sum::<std::time::Duration>() / durations.len() as u32;
let total_bytes: usize = self
.last_kernel_stats
.iter()
.map(|s| s.bytes_loaded + s.bytes_stored)
.sum::<usize>();
let total_flops: usize = self
.last_kernel_stats
.iter()
.map(|s| s.flops)
.sum::<usize>();
let aggregate_bw = if self.last_total_time_us > 0.0 {
(total_bytes as f64) / (self.last_total_time_us * 1e-6) / 1e9
} else {
0.0
};
let aggregate_tf = if self.last_total_time_us > 0.0 {
(total_flops as f64) / (self.last_total_time_us * 1e-6) / 1e12
} else {
0.0
};
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
let mbu = peak_bw.map(|p| aggregate_bw / p as f64);
let mfu = peak_tf.map(|p| aggregate_tf / p as f64);
let duration_str = format_duration_precise(&duration);
let mbu_str = mbu.map_or("-".to_string(), |v| format!("{:.1}%", v * 100.0));
let mfu_str = mfu.map_or("-".to_string(), |v| format!("{:.1}%", v * 100.0));
let display = format!(
"{duration_str} | MBU: {mbu_str} | MFU: {mfu_str} [KRN: {} HOST: {}]",
"{duration_str} | [KRN: {} HOST: {}]",
llir_graph
.node_weights()
.filter(|n| n.to_dialect::<dyn KernelOp>().is_some())
@@ -2079,18 +2051,10 @@ impl CudaRuntime {
if !self.last_kernel_stats.is_empty() {
println!("\n=== Kernel Execution Statistics ===\n");
println!(
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
"Kernel",
"Time (us)",
"Loaded",
"Stored",
"Agg FLOPS",
"BW (GB/s)",
"TFLOPS",
"MBU",
"MFU"
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12}",
"Kernel", "Time (us)", "Loaded", "Stored", "Agg FLOPS", "BW (GB/s)", "TFLOPS"
);
println!("{}", "-".repeat(116));
println!("{}", "-".repeat(92));
for s in &self.last_kernel_stats {
self.print_stat_row(
s.name,
@@ -2101,29 +2065,20 @@ impl CudaRuntime {
s.flops,
s.bandwidth_gbps,
s.tflops,
peak_bw,
peak_tf,
);
}
println!("{}", "-".repeat(116));
println!("{}", "-".repeat(92));
}
// Print aggregate stats
println!("\n=== Aggregate Statistics ===\n");
println!(
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
"", "Time (us)", "Loaded", "Stored", "Agg FLOPS", "BW (GB/s)", "TFLOPS", "MBU", "MFU"
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12}",
"", "Time (us)", "Loaded", "Stored", "Agg FLOPS", "BW (GB/s)", "TFLOPS"
);
println!("{}", "-".repeat(116));
let (mbu, mfu) = match (peak_bw, peak_tf) {
(Some(pb), Some(pt)) => (
format!("{:.1}%", aggregate_bw / pb as f64 * 100.0),
format!("{:.1}%", aggregate_tf / pt as f64 * 100.0),
),
_ => ("-".into(), "-".into()),
};
println!("{}", "-".repeat(92));
println!(
"{:<20} {:>12.2} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
"{:<20} {:>12.2} {:>12} {:>12} {:>12} {:>12} {:>12}",
"Total",
self.last_total_time_us,
format_size(total_bytes_loaded),
@@ -2131,8 +2086,6 @@ impl CudaRuntime {
format_flops(total_flops),
format!("{:.2}", aggregate_bw),
format!("{:.4}", aggregate_tf),
mbu,
mfu
);
if let (Some(pb), Some(pt)) = (peak_bw, peak_tf) {
@@ -2152,8 +2105,6 @@ impl CudaRuntime {
flops: usize,
bw: f64,
tf: f64,
peak_bw: Option<usize>,
peak_tf: Option<usize>,
) {
let total = loaded + stored;
let ld = if loaded > 0 {
@@ -2181,21 +2132,13 @@ impl CudaRuntime {
} else {
"-".into()
};
let mbu = peak_bw
.filter(|_| total > 0)
.map(|p| format!("{:.1}%", bw / p as f64 * 100.0))
.unwrap_or("-".into());
let mfu = peak_tf
.filter(|_| flops > 0)
.map(|p| format!("{:.1}%", tf / p as f64 * 100.0))
.unwrap_or("-".into());
match count {
Some(c) => println!(
"{name:<20} {time_us:>12.2} {c:>8} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12} {mbu:>8} {mfu:>8}"
"{name:<20} {time_us:>12.2} {c:>8} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12}"
),
None => println!(
"{name:<20} {time_us:>12.2} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12} {mbu:>8} {mfu:>8}"
"{name:<20} {time_us:>12.2} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12}"
),
}
}

View File

@@ -46,7 +46,11 @@ fn test_bucket_dispatch_simple() {
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Test bucket 1: s=1
cx.set_dim('s', 1);
@@ -91,7 +95,11 @@ fn test_bucket_matmul_dynamic() {
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Execute at s=1
cx.set_dim('s', 1);
@@ -146,7 +154,11 @@ fn test_bucket_results_match_unbucketed() {
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
rt1.set_data(a1, input_data.clone());
let mut rng1 = SmallRng::seed_from_u64(seed);
rt1 = cx1.search_with_rng(rt1, CompileOptions::new(5), &mut rng1);
rt1 = cx1.search_with_rng(
rt1,
CompileOptions::default().search_graph_limit(5),
&mut rng1,
);
rt1.set_data(a1, input_data.clone());
rt1.execute(&cx1.dyn_map);
let result_unbucketed = rt1.get_f32(b1);
@@ -158,7 +170,11 @@ fn test_bucket_results_match_unbucketed() {
let mut rt2 = CudaRuntime::initialize(stream.clone());
rt2.set_data(a2, input_data.clone());
let mut rng2 = SmallRng::seed_from_u64(seed);
rt2 = cx2.search_with_rng(rt2, CompileOptions::new(5), &mut rng2);
rt2 = cx2.search_with_rng(
rt2,
CompileOptions::default().search_graph_limit(5),
&mut rng2,
);
rt2.set_data(a2, input_data.clone());
rt2.execute(&cx2.dyn_map);
let result_bucketed = rt2.get_f32(b2);
@@ -186,7 +202,11 @@ fn test_bucket_out_of_range_panics() {
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
// s=10 is outside all buckets — should panic
cx.set_dim('s', 10);
@@ -211,7 +231,11 @@ fn test_bucket_no_buckets_backward_compat() {
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
rt.set_data(a, input_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
rt.set_data(a, input_data.clone());
rt.execute(&cx.dyn_map);
@@ -257,7 +281,11 @@ fn test_bucket_switch_preserves_weights() {
rt.set_data(b_tensor, b_data.clone());
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Execute with bucket 1 (s=1)
cx.set_dim('s', 1);
@@ -311,7 +339,11 @@ fn test_bucket_multiple_executions_same_bucket() {
cx.set_dim('s', 1);
rt.set_data(a, vec![1.0f32; 4]);
let mut rng = SmallRng::seed_from_u64(42);
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(3),
&mut rng,
);
// Execute at different sizes within the same bucket
for s in [1, 2, 4, 8] {

View File

@@ -307,7 +307,7 @@ fn test_scatter_kv_cache_roundtrip() {
rt.set_data(src, vec![10.0f32]);
rt.set_data(indexes, vec![0i32]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Print and verify which scatter variant was selected
let scatter_names: Vec<_> = rt
@@ -427,7 +427,11 @@ fn test_scatter_dual_cache() {
// Use seeded search for deterministic variant selection.
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(5),
&mut rng,
);
// Print and verify selected variants
let scatter_names: Vec<_> = rt
@@ -554,7 +558,11 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
rt.set_data(gather_idx, scatter);
rt.set_data(cache, (0..SLOTS * D).map(|i| i as f32).collect::<Vec<_>>());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
assert_eq!(rt.get_f32(gathered), expected_gather);
@@ -763,7 +771,11 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
rt.set_data(k_cache, zero_k.clone());
rt.set_data(v_cache, zero_v.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let batched_attn = rt.get_f32(attn_out);
let batched_k = rt.get_f32(k_out);
@@ -865,7 +877,11 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
rt.set_data(k_cache, zero_k.clone());
rt.set_data(v_cache, zero_v.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let batched_attn = rt.get_f32(attn_out);
let batched_k = rt.get_f32(k_out);
@@ -937,7 +953,11 @@ fn test_dynamic_expanded_causal_mask_softmax() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(mask, mask_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(weights);
@@ -1007,7 +1027,11 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
rt.set_data(v_full, v_data.clone());
rt.set_data(mask, mask_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1073,7 +1097,11 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
rt.set_data(v_full, v_data.clone());
rt.set_data(weights, weights_data);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1124,7 +1152,11 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(x, x_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(roundtrip);
@@ -1171,7 +1203,11 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
rt.set_data(x, x_data.clone());
rt.set_data(w, w_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1220,7 +1256,11 @@ fn test_batched_topk_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(topk);
@@ -1259,7 +1299,11 @@ fn test_batched_argsort_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(argsort);
@@ -1299,7 +1343,11 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(input, data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);
@@ -1356,7 +1404,11 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(routing, routing_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(ranks);
@@ -1395,7 +1447,11 @@ fn test_dynamic_3d_flat_index_iota_rows() {
let mut rt = CudaRuntime::initialize(stream);
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(idx);
@@ -1438,7 +1494,11 @@ fn test_dynamic_2d_to_3d_gather_rows() {
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(data, data_values.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_i32(out);
@@ -1490,7 +1550,11 @@ fn test_batched_gather_experts_matches_cpu() {
rt.set_data(topk, topk_data.clone());
rt.set_data(weights, weights_data.clone());
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
rt = cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(10),
&mut rng,
);
rt.execute(&cx.dyn_map);
let got = rt.get_f32(out);

View File

@@ -89,7 +89,7 @@ fn run_reference_attention(
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());
rt.set_data(v_t, v.to_vec());
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.set_data(q_t, q.to_vec());
rt.set_data(k_t, k.to_vec());

View File

@@ -61,7 +61,7 @@ fn generic_matmul_executes_noncontiguous_merged_head_projection() {
rt.set_data(attn, attn_data.as_slice());
rt.set_data(weight, weight_data.as_slice());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
assert!(
rt.kernel_names().contains(&"GenericMatmul"),
"expected GenericMatmul to be selected, kernels: {:?}",

View File

@@ -95,7 +95,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -156,7 +156,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(proj_w, proj_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -245,7 +245,7 @@ fn fuzz_layer_no_attn(
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -330,7 +330,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -518,7 +518,7 @@ mod gemma {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -655,7 +655,7 @@ mod qwen {
rt.set_data(input, input_data.clone());
rt.set_data(norm_w, norm_data.clone());
rt.set_data(embedding, emb_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);

View File

@@ -259,7 +259,7 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(stream);
rt.set_data(input, data);
rt = cx.search(rt, CompileOptions::new(10));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&cx.dyn_map);
let out_dim0 = rt.get_i32(sorted_dim0.id);
let out_dim1 = rt.get_i32(sorted_dim1.id);
@@ -600,7 +600,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
rt.set_data(token_ids, token_data.clone());
rt.set_data(embed_table, embed_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(output);

View File

@@ -31,7 +31,7 @@ pub fn kernel_add_bandwidth_test() {
let mut rt = CudaRuntime::initialize(stream.clone());
rt.set_data(a, data_a.clone());
rt.set_data(b, data_b.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
// Warm up
rt.execute(&cx.dyn_map);

View File

@@ -233,7 +233,9 @@ fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
rt.set_data(model.router, router_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, CompileOptions::new(10));
rt = model
.graph
.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&model.graph.dyn_map);
rt.get_f32(model.output.id)
@@ -278,7 +280,9 @@ fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
rt.set_data(model.per_expert_scale, per_expert_scale_data);
rt.set_data(model.gate_up_weights, gate_up_data);
rt.set_data(model.down_weights, down_data);
rt = model.graph.search(rt, CompileOptions::new(10));
rt = model
.graph
.search(rt, CompileOptions::default().search_graph_limit(10));
rt.execute(&model.graph.dyn_map);
rt.get_f32(model.output.id)

View File

@@ -50,7 +50,7 @@ fn rope_matches_cpu_reference() {
rt.set_data(x, x_data.clone());
rt.set_data(cos, cos_data.clone());
rt.set_data(sin, sin_data.clone());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let got = rt.get_f32(y.id);
@@ -98,7 +98,7 @@ fn rope_flux2_shape() {
rt.set_data(x, x_data.clone());
rt.set_data(cos, cos_data.clone());
rt.set_data(sin, sin_data.clone());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let got = rt.get_f32(y.id);

View File

@@ -280,7 +280,7 @@ fn test_mini_transformer_layer() {
// Use minimal search iterations to avoid excessive graph rewriting
// which can cause float drift through softmax/RMSNorm reordering
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -316,7 +316,7 @@ fn test_mini_transformer_two_layers() {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -372,7 +372,7 @@ fn test_transformer_multi_seed() {
rt.set_data(*tensor, data.clone());
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -404,7 +404,7 @@ fn test_rms_norm_cuda() {
.collect();
rt.set_data(input, input_data.clone());
rt.set_data(weight, weight_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -447,7 +447,7 @@ fn test_self_attention_cuda() {
rt.set_data(wk, wk_data.clone());
rt.set_data(wv, wv_data.clone());
rt.set_data(wo, wo_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -491,7 +491,7 @@ fn test_swiglu_mlp_cuda() {
rt.set_data(w_gate, gate_data.clone());
rt.set_data(w_up, up_data.clone());
rt.set_data(w_down, down_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);
@@ -530,7 +530,7 @@ fn test_rolled_chained_scalar_muls() {
let mut rt = CudaRuntime::initialize(stream);
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
rt.set_data(x, x_data.clone());
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(out);

View File

@@ -331,7 +331,7 @@ pub fn fuzz_cuda_search_space_equivalence(
let mut native_rng = StdRng::seed_from_u64(config.seed);
let mut native_rt = cx.search_with_rng(
NativeRuntime::default(),
CompileOptions::new(1),
CompileOptions::default().search_graph_limit(1),
&mut native_rng,
);
for input in inputs {
@@ -701,7 +701,7 @@ pub fn test_unary_cuda<T: TestDType>(
let input_data = generator(n_elements, seed);
rt.set_data(a, input_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, b.id);
@@ -776,7 +776,7 @@ pub fn test_binary_cuda<T: TestDType>(
let b_data = b_generator(b_elements, seed.wrapping_add(1));
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = T::get_from_runtime(&rt, c.id);
@@ -844,7 +844,7 @@ pub fn test_mod(
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
rt.set_data(a, a_data.clone());
rt.set_data(b, b_data.clone());
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.execute(&cx.dyn_map);
let result = rt.get_f32(c);

View File

@@ -503,7 +503,8 @@ fn main() -> Result<(), Box<dyn Error>> {
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, CompileOptions::new(SEARCH_GRAPHS));
let search_options = CompileOptions::default().search_graph_limit(SEARCH_GRAPHS);
runtime = cx.search(runtime, search_options);
println!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()

View File

@@ -41,7 +41,11 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
let mut rng = StdRng::seed_from_u64(0);
cx.search_with_rng(rt, CompileOptions::new(limit), &mut rng)
cx.search_with_rng(
rt,
CompileOptions::default().search_graph_limit(limit),
&mut rng,
)
}
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
@@ -301,7 +305,7 @@ fn dynamic_dim_sum_reduce_runs() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -322,7 +326,7 @@ fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, vec![1.0f32; 4]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
cx.set_dim('s', 1);
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
@@ -350,7 +354,7 @@ fn metal_int_arithmetic_preserves_large_values() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(token, &[16_385i32]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -373,7 +377,7 @@ proptest! {
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -395,7 +399,7 @@ proptest! {
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -417,7 +421,7 @@ proptest! {
let mut rt = MetalRuntime::initialize(());
let input_values: Vec<f32> = values.into_iter().take(len).collect();
rt.set_data(input, &input_values);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -449,7 +453,7 @@ fn metal_simple_add() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -469,7 +473,7 @@ fn metal_simple_mul() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -487,7 +491,7 @@ fn metal_simple_exp2() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[0.0, 1.0, 2.0, 3.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -504,7 +508,7 @@ fn metal_simple_log2() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 4.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -529,7 +533,7 @@ fn metal_simple_sin() {
3.0 * std::f32::consts::FRAC_PI_2,
],
);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -546,7 +550,7 @@ fn metal_simple_sqrt() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 4.0, 9.0, 16.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -563,7 +567,7 @@ fn metal_simple_recip() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, 2.0, 4.0, 5.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -582,7 +586,7 @@ fn metal_simple_mod() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[7.0, 10.0, 15.0, 8.5]);
rt.set_data(b, &[3.0, 4.0, 6.0, 2.5]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -601,7 +605,7 @@ fn metal_simple_less_than() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 5.0, 3.0, 4.0]);
rt.set_data(b, &[2.0, 3.0, 3.0, 5.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -621,7 +625,7 @@ fn metal_simple_sum_reduce() {
let mut rt = MetalRuntime::initialize(());
// [[1,2,3,4], [5,6,7,8]] -> [10, 26]
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -640,7 +644,7 @@ fn metal_simple_max_reduce() {
let mut rt = MetalRuntime::initialize(());
// [[1,4,2,3], [8,5,7,6]] -> [4, 8]
rt.set_data(input, &[1.0, 4.0, 2.0, 3.0, 8.0, 5.0, 7.0, 6.0]);
rt = cx.search(rt, CompileOptions::new(5));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -657,7 +661,7 @@ fn metal_f16_cast_roundtrip() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -678,7 +682,7 @@ fn metal_f16_intermediate_add_roundtrip() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
rt = cx.search(rt, CompileOptions::new(3));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1026,7 +1030,7 @@ fn metal_rms_norm() {
rt.set_data(input, &input_data);
rt.set_data(weight, &weight_data);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1066,7 +1070,7 @@ fn metal_self_attention() {
rt.set_data(wk, &wk_data);
rt.set_data(wv, &wv_data);
rt.set_data(wo, &wo_data);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1125,7 +1129,7 @@ fn metal_self_attention_f16_weights() {
rt.set_data(wk, to_f16_vec(&wk_data));
rt.set_data(wv, to_f16_vec(&wv_data));
rt.set_data(wo, to_f16_vec(&wo_data));
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1169,7 +1173,7 @@ fn metal_swiglu_mlp() {
rt.set_data(w_gate, &gate_data);
rt.set_data(w_up, &up_data);
rt.set_data(w_down, &down_data);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1219,7 +1223,7 @@ fn metal_mini_transformer_layer() {
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1285,7 +1289,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
for (tensor, data) in &weight_data {
rt.set_data(*tensor, data);
}
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1327,7 +1331,7 @@ fn test_scatter_basic() {
rt.set_data(src, &[10.0, 20.0, 30.0]);
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1349,7 +1353,7 @@ fn test_scatter_buffer_roundtrip() {
rt.set_data(src, &[0.0]);
rt.set_data(indexes, &[0.0]);
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
for (pos, value, expected) in [
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
@@ -1383,7 +1387,7 @@ fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
rt.set_data(weights, &[99.0, 99.0, 99.0]);
rt.set_data(bias, &[0.5, 1.0, -1.5]);
rt.load_safetensors(&cx, path.to_str().unwrap());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1448,7 +1452,7 @@ fn test_load_safetensors_converts_supported_float_dtypes() {
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
let mut rt = MetalRuntime::initialize(());
rt.load_safetensors(&cx, path.to_str().unwrap());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1475,7 +1479,7 @@ fn test_gather_noncontiguous_data_uses_data_shape() {
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
);
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1495,7 +1499,7 @@ fn test_scatter_into_nonzero_dest() {
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[2f32]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1522,7 +1526,7 @@ fn test_scatter_no_copy_remove_buffer_aliases_dest() {
rt.set_data(src, &[7.0, 8.0]);
rt.set_data(indexes, &[1.0, 3.0]);
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1551,7 +1555,7 @@ fn test_scatter_no_copy_handles_2d_destination() {
rt.set_data(src, &[9.0, 8.0]);
rt.set_data(indexes, &[2.0, 4.0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1578,7 +1582,7 @@ fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
rt.set_data(src, &[99.0]);
rt.set_data(indexes, &[1.0]);
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
let kernels = rt.debug_kernel_ops();
assert!(
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
@@ -1605,7 +1609,7 @@ fn test_scatter_all_positions() {
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);
@@ -1624,7 +1628,7 @@ fn test_gather_preserves_data_dtype() {
let mut rt = MetalRuntime::initialize(());
rt.set_data(data, &[1.25, 2.5]);
rt.set_data(indexes, &[1.0]);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.allocate_intermediate_buffers(&cx.dyn_map);
rt.execute(&cx.dyn_map);

View File

@@ -167,7 +167,10 @@ mod tests {
let result = gather_rows(data, indices, 3).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
rt.set_data(
@@ -193,7 +196,10 @@ mod tests {
let result = scatter_rows(src, indices, dest, 3).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
rt.set_data(indices.id, vec![1, 3]);
@@ -219,7 +225,10 @@ mod tests {
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
@@ -272,7 +281,10 @@ mod tests {
let v_cache_new = v_cache_new.output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
rt.set_data(q.id, vec![1., 0., 1., 0.]);
@@ -345,7 +357,10 @@ mod tests {
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
// K cached at slot 0: [1, 0]
@@ -417,7 +432,10 @@ mod tests {
let attn_out = attn_out.output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Cache has 1 token at slot 0
let mut k_cache_data = vec![0.; num_slots * kv_dim];

View File

@@ -184,7 +184,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![1.0, 2.0, 3.0];
// Router strongly favors expert 0
@@ -239,7 +242,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![1.0, 1.0];
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
@@ -293,7 +299,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = vec![
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
@@ -350,7 +359,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = random_vec(in_dim);
let router_data = random_vec(in_dim * n_experts);
@@ -395,7 +407,10 @@ mod tests {
let output = moe.forward(input).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let input_data = random_vec(batch * in_dim);
let router_data = random_vec(in_dim * n_experts);

View File

@@ -53,6 +53,10 @@ fn env_usize(name: &str, default: usize) -> usize {
.unwrap_or(default)
}
fn search_options() -> CompileOptions {
CompileOptions::default().search_graph_limit(env_usize("SEARCH_ITERS", 5))
}
fn env_f32(name: &str, default: f32) -> f32 {
std::env::var(name)
.ok()
@@ -187,7 +191,7 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
println!("Compiling text encoder...");
let t0 = Instant::now();
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
runtime = cx.search(runtime, search_options());
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
println!("Encoding prompt...");
@@ -345,10 +349,9 @@ fn run_full_pipeline(
{
use rand::SeedableRng;
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
let opts = luminal::graph::CompileOptions::new(env_usize("SEARCH_ITERS", 5));
runtime = cx.search_with_rng(runtime, opts, &mut rng);
runtime = cx.search_with_rng(runtime, search_options(), &mut rng);
} else {
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
runtime = cx.search(runtime, search_options());
}
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
@@ -415,7 +418,7 @@ fn run_full_pipeline(
let mut runtime = CudaRuntime::initialize(stream);
runtime.load_safetensors(&cx, vae_path.to_str().unwrap());
runtime.set_data(latent_in, vae_input);
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
runtime = cx.search(runtime, search_options());
runtime.execute(&cx.dyn_map);
let img = runtime.get_f32(out);
// VaeDecoder output is in roughly [-1, 1] range. Diffusers'

View File

@@ -720,6 +720,10 @@ mod tests {
out
}
fn one_search() -> CompileOptions {
CompileOptions::default().search_graph_limit(1)
}
#[test]
fn conv2d_bias_matches_reference() {
let mut cx = Graph::default();
@@ -747,7 +751,7 @@ mod tests {
);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
@@ -766,7 +770,7 @@ mod tests {
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.execute(&cx.dyn_map);
@@ -790,7 +794,7 @@ mod tests {
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
let mut rt = CudaRuntime::initialize(ctx.default_stream());
rt.set_data(input_t, input);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, one_search());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected);
@@ -821,7 +825,7 @@ mod tests {
);
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), one_search());
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
@@ -864,7 +868,7 @@ mod tests {
rt.set_data(input_t, input);
rt.set_data(weight_t, weight);
rt.set_data(bias_t, bias);
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, one_search());
rt.execute(&cx.dyn_map);
assert_close(&rt.get_f32(out.id), &expected);

View File

@@ -84,7 +84,8 @@ fn main() {
cx.set_dim('p', 0);
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(token_ids, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);

View File

@@ -36,14 +36,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -68,114 +70,11 @@ pub struct Gemma {
impl Gemma {
pub fn init(cx: &mut Graph) -> Self {
let mut w = vec![];
for l in 0..LAYERS {
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
let up = cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let gate = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
w.push(GemmaLayer {
up,
gate,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
input_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.input_layernorm.weight"),
cx,
),
post_attention_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.post_attention_layernorm.weight"),
cx,
),
pre_feedforward_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.pre_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm: gemma_norm(
HIDDEN,
&format!("model.layers.{l}.post_feedforward_layernorm.weight"),
cx,
),
is_local,
rope_theta: if is_local {
ROPE_THETA_LOCAL
} else {
ROPE_THETA_GLOBAL
},
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
});
}
let lm_norm = gemma_norm(HIDDEN, "model.norm.weight", cx);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
lm_head,
layers: w,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| GemmaLayer::init(cx, l)).collect(),
lm_norm: gemma_norm(HIDDEN, "model.norm.weight", cx),
}
}
@@ -185,11 +84,7 @@ impl Gemma {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -226,6 +121,114 @@ struct GemmaLayer {
rope_scaling_factor: f32,
}
impl GemmaLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
let is_local = !(l + 1).is_multiple_of(SLIDING_WINDOW_PATTERN);
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
input_layernorm: layer_norm(cx, l, "input_layernorm"),
post_attention_layernorm: layer_norm(cx, l, "post_attention_layernorm"),
pre_feedforward_layernorm: layer_norm(cx, l, "pre_feedforward_layernorm"),
post_feedforward_layernorm: layer_norm(cx, l, "post_feedforward_layernorm"),
is_local,
rope_theta: if is_local {
ROPE_THETA_LOCAL
} else {
ROPE_THETA_GLOBAL
},
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
}
}
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = gemma_rotary_embeddings(
qk_norm(q, self.q_norm, N_HEADS),
pos_ids,
N_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let k_rope = gemma_rotary_embeddings(
qk_norm(k, self.k_norm, N_KV_HEADS),
pos_ids,
N_KV_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope,
k_rope,
v,
k_cache_in,
v_cache_in,
max_seq,
self.is_local,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = x + self.post_attention_layernorm.forward(attn_proj);
let x_ff = self.pre_feedforward_layernorm.forward(x);
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
.matmul(self.down.t());
(
x + self.post_feedforward_layernorm.forward(mlp_out),
k_cache_out,
v_cache_out,
)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
gemma_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
/// GELU using the identity: 0.5*x*(1+tanh(a)) = x*sigmoid(2*a)
/// This produces far fewer e-graph nodes than the tanh-based expansion.
#[allow(clippy::excessive_precision)]
@@ -363,59 +366,3 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl GemmaLayer {
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// QK-norm + RoPE
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
let q_rope = gemma_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let k_rope = gemma_rotary_embeddings(
k_normed,
pos_ids,
N_KV_HEADS,
self.rope_theta,
self.rope_scaling_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope,
k_rope,
v,
k_cache_in,
v_cache_in,
max_seq,
self.is_local,
);
// O projection + post-attention norm + residual
let attn_proj = attn_out.matmul(self.o_proj.t());
let attn_normed = self.post_attention_layernorm.forward(attn_proj);
let x = x + attn_normed;
// Pre-feedforward norm + MLP + post-feedforward norm + residual
let x_ff = self.pre_feedforward_layernorm.forward(x);
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
.matmul(self.down.t());
let mlp_normed = self.post_feedforward_layernorm.forward(mlp_out);
(x + mlp_normed, k_cache_out, v_cache_out)
}
}

View File

@@ -81,11 +81,10 @@ fn main() {
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_with_rng(
runtime,
CompileOptions::new(search_graphs).profile_timeout(Duration::from_secs(2)),
&mut rng,
);
let search_options = CompileOptions::default()
.search_graph_limit(search_graphs)
.profile_timeout(Duration::from_secs(2));
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
for layer in 0..LAYERS {
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);

View File

@@ -83,20 +83,16 @@ impl KVCache {
let mut v_caches = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let k = cx
.named_tensor(
format!("kv_cache.{layer}.k"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
let v = cx
.named_tensor(
format!("kv_cache.{layer}.v"),
(spec.num_kv_heads, max_seq, spec.head_dim),
)
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{layer}.k"),
(spec.num_kv_heads, max_seq, spec.head_dim),
));
v_caches.push(persist(
cx,
format!("kv_cache.{layer}.v"),
(spec.num_kv_heads, max_seq, spec.head_dim),
));
}
Self {
k_caches,
@@ -115,169 +111,13 @@ pub struct Gemma4MoE {
impl Gemma4MoE {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = Vec::with_capacity(LAYERS);
for layer in 0..LAYERS {
let spec = layer_spec(layer);
let gate = cx
.named_tensor(
format!("model.layers.{layer}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let up = cx
.named_tensor(
format!("model.layers.{layer}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{layer}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_proj.weight"),
(spec.q_dim, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist();
let v_proj = spec.has_v_proj.then(|| {
cx.named_tensor(
format!("model.layers.{layer}.self_attn.v_proj.weight"),
(spec.kv_dim, HIDDEN),
)
.persist()
});
let o_proj = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.o_proj.weight"),
(HIDDEN, spec.q_dim),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.q_norm.weight"),
spec.head_dim,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{layer}.self_attn.k_norm.weight"),
spec.head_dim,
)
.persist();
let layer_scalar = cx
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
.persist();
let router_scale = cx
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
.persist();
let router_proj = cx
.named_tensor(
format!("model.layers.{layer}.router.proj.weight"),
(NUM_EXPERTS, HIDDEN),
)
.persist();
let per_expert_scale = cx
.named_tensor(
format!("model.layers.{layer}.router.per_expert_scale"),
NUM_EXPERTS,
)
.persist();
let gate_up_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.gate_up_proj"),
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.persist()
.as_dtype(DType::Bf16);
let down_weights = cx
.named_tensor(
format!("model.layers.{layer}.experts.down_proj"),
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.persist()
.as_dtype(DType::Bf16);
layers.push(Gemma4Layer {
spec,
gate,
up,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
layer_scalar,
input_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.input_layernorm.weight"),
cx,
),
post_attention_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
cx,
),
pre_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
cx,
),
post_feedforward_layernorm_1: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
cx,
),
post_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
cx,
),
pre_feedforward_layernorm_2: gemma4_norm(
HIDDEN,
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
cx,
),
moe: Gemma4SparseMoE {
router_scale,
router_proj,
per_expert_scale,
gate_up_weights,
down_weights,
},
});
}
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
Self {
embedding,
lm_head,
layers,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS)
.map(|layer| Gemma4Layer::init(cx, layer))
.collect(),
lm_norm: gemma4_norm(HIDDEN, "model.norm.weight", cx),
}
}
@@ -287,11 +127,7 @@ impl Gemma4MoE {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (layer_idx, layer) in self.layers.iter().enumerate() {
@@ -342,6 +178,164 @@ struct Gemma4SparseMoE {
down_weights: GraphTensor,
}
impl Gemma4Layer {
fn init(cx: &mut Graph, layer: usize) -> Self {
let spec = layer_spec(layer);
Self {
spec,
gate: layer_weight(cx, layer, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
up: layer_weight(cx, layer, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, layer, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, layer, "self_attn.q_proj", (spec.q_dim, HIDDEN)),
k_proj: layer_weight(cx, layer, "self_attn.k_proj", (spec.kv_dim, HIDDEN)),
v_proj: spec
.has_v_proj
.then(|| layer_weight(cx, layer, "self_attn.v_proj", (spec.kv_dim, HIDDEN))),
o_proj: layer_weight(cx, layer, "self_attn.o_proj", (HIDDEN, spec.q_dim)),
q_norm: layer_weight(cx, layer, "self_attn.q_norm", spec.head_dim),
k_norm: layer_weight(cx, layer, "self_attn.k_norm", spec.head_dim),
layer_scalar: layer_tensor(cx, layer, "layer_scalar", HIDDEN),
input_layernorm: layer_norm(cx, layer, "input_layernorm"),
post_attention_layernorm: layer_norm(cx, layer, "post_attention_layernorm"),
pre_feedforward_layernorm: layer_norm(cx, layer, "pre_feedforward_layernorm"),
post_feedforward_layernorm: layer_norm(cx, layer, "post_feedforward_layernorm"),
post_feedforward_layernorm_1: layer_norm(cx, layer, "post_feedforward_layernorm_1"),
post_feedforward_layernorm_2: layer_norm(cx, layer, "post_feedforward_layernorm_2"),
pre_feedforward_layernorm_2: layer_norm(cx, layer, "pre_feedforward_layernorm_2"),
moe: Gemma4SparseMoE {
router_scale: layer_tensor(cx, layer, "router.scale", HIDDEN),
router_proj: layer_weight(cx, layer, "router.proj", (NUM_EXPERTS, HIDDEN)),
per_expert_scale: layer_tensor(cx, layer, "router.per_expert_scale", NUM_EXPERTS),
gate_up_weights: layer_tensor(
cx,
layer,
"experts.gate_up_proj",
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.as_dtype(DType::Bf16),
down_weights: layer_tensor(
cx,
layer,
"experts.down_proj",
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.as_dtype(DType::Bf16),
},
}
}
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let residual = x;
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k_base = x_attn.matmul(self.k_proj.t());
let v_base = if let Some(v_proj) = self.v_proj {
x_attn.matmul(v_proj.t())
} else {
k_base
};
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
let k_normed = qk_norm(
k_base,
self.k_norm,
self.spec.num_kv_heads,
self.spec.head_dim,
);
let v_normed = value_norm(v_base, self.spec.head_dim);
let q_rope = gemma4_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let k_rope = gemma4_rotary_embeddings(
k_normed,
pos_ids,
self.spec.num_kv_heads,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = residual + self.post_attention_layernorm.forward(attn_proj);
let dense_ff = dense_ffn(
self.pre_feedforward_layernorm.forward(x),
self.gate,
self.up,
self.down,
);
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
let moe_out = self
.moe
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
let x = x + ff_out;
let x = x * self
.layer_scalar
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
(x, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_tensor(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
}
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
gemma4_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
}
@@ -505,81 +499,6 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl Gemma4Layer {
pub fn forward(
&self,
x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let residual = x;
let x_attn = self.input_layernorm.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k_base = x_attn.matmul(self.k_proj.t());
let v_base = if let Some(v_proj) = self.v_proj {
x_attn.matmul(v_proj.t())
} else {
k_base
};
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
let k_normed = qk_norm(
k_base,
self.k_norm,
self.spec.num_kv_heads,
self.spec.head_dim,
);
let v_normed = value_norm(v_base, self.spec.head_dim);
let q_rope = gemma4_rotary_embeddings(
q_normed,
pos_ids,
N_HEADS,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let k_rope = gemma4_rotary_embeddings(
k_normed,
pos_ids,
self.spec.num_kv_heads,
self.spec.head_dim,
self.spec.rope_theta,
self.spec.partial_rotary_factor,
);
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
);
let attn_proj = attn_out.matmul(self.o_proj.t());
let x = residual + self.post_attention_layernorm.forward(attn_proj);
let dense_ff = dense_ffn(
self.pre_feedforward_layernorm.forward(x),
self.gate,
self.up,
self.down,
);
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
let moe_out = self
.moe
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
let x = x + ff_out;
let x = x * self
.layer_scalar
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
(x, k_cache_out, v_cache_out)
}
}
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
}

View File

@@ -338,13 +338,11 @@ fn main() {
println!(" Search trials: {SEARCH_TRIALS}");
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_with_rng(
runtime,
CompileOptions::new(SEARCH_GRAPHS)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST),
&mut rng,
);
let search_options = CompileOptions::default()
.search_graph_limit(SEARCH_GRAPHS)
.trials(SEARCH_TRIALS)
.keep_best(SEARCH_KEEP_BEST);
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
println!(
" Search/compile: {:.2} s",
compile_start.elapsed().as_secs_f64()

View File

@@ -111,125 +111,18 @@ impl Llama {
config: LlamaConfig,
fp8_linears: bool,
) -> Self {
let mut layers = Vec::with_capacity(config.layers);
for l in 0..config.layers {
layers.push(LlamaLayer {
config,
up: linear_weight(
cx,
format!("model.layers.{l}.mlp.up_proj"),
(config.intermediate, config.hidden),
fp8_linears,
),
up_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.up_proj"),
fp8_linears,
),
gate: linear_weight(
cx,
format!("model.layers.{l}.mlp.gate_proj"),
(config.intermediate, config.hidden),
fp8_linears,
),
gate_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.gate_proj"),
fp8_linears,
),
down: linear_weight(
cx,
format!("model.layers.{l}.mlp.down_proj"),
(config.hidden, config.intermediate),
fp8_linears,
),
down_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.mlp.down_proj"),
fp8_linears,
),
q_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.q_proj"),
(config.hidden, config.hidden),
fp8_linears,
),
q_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.q_proj"),
fp8_linears,
),
k_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.k_proj"),
(config.kv_dim(), config.hidden),
fp8_linears,
),
k_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.k_proj"),
fp8_linears,
),
v_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.v_proj"),
(config.kv_dim(), config.hidden),
fp8_linears,
),
v_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.v_proj"),
fp8_linears,
),
o_proj: linear_weight(
cx,
format!("model.layers.{l}.self_attn.o_proj"),
(config.hidden, config.hidden),
fp8_linears,
),
o_proj_scales: fp8_linear_scales(
cx,
format!("model.layers.{l}.self_attn.o_proj"),
fp8_linears,
),
attn_rms: LayerNorm::new(
config.hidden,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
1e-5,
cx,
),
mlp_rms: LayerNorm::new(
config.hidden,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
1e-5,
cx,
),
});
}
Self {
config,
embedding: cx
.named_tensor(
"model.embed_tokens.weight",
(config.vocab_size, config.hidden),
)
.persist(),
layers,
lm_head: cx
.named_tensor("lm_head.weight", (config.vocab_size, config.hidden))
.persist(),
lm_norm: LayerNorm::new(
config.hidden,
Some("model.norm.weight"),
None,
false,
1e-5,
embedding: persist(
cx,
"model.embed_tokens.weight",
(config.vocab_size, config.hidden),
),
layers: (0..config.layers)
.map(|l| LlamaLayer::init(cx, l, config, fp8_linears))
.collect(),
lm_head: persist(cx, "lm_head.weight", (config.vocab_size, config.hidden)),
lm_norm: rms_norm(cx, config.hidden, "model.norm.weight"),
}
}
@@ -243,12 +136,7 @@ impl Llama {
attn_mask: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = input.dims1();
let hidden = self.config.hidden;
let mut x = self.embedding.gather(
(input * hidden).expand_dim(1, hidden)
+ input.graph().arange(hidden).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, input, self.config.hidden);
let mut cache_outputs = Vec::with_capacity(self.config.layers);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -311,6 +199,170 @@ struct Fp8LinearScales {
weight: GraphTensor,
}
impl LlamaLayer {
fn init(cx: &mut Graph, l: usize, config: LlamaConfig, fp8: bool) -> Self {
Self {
config,
up: layer_linear_weight(
cx,
l,
"mlp.up_proj",
(config.intermediate, config.hidden),
fp8,
),
up_scales: layer_linear_scales(cx, l, "mlp.up_proj", fp8),
gate: layer_linear_weight(
cx,
l,
"mlp.gate_proj",
(config.intermediate, config.hidden),
fp8,
),
gate_scales: layer_linear_scales(cx, l, "mlp.gate_proj", fp8),
down: layer_linear_weight(
cx,
l,
"mlp.down_proj",
(config.hidden, config.intermediate),
fp8,
),
down_scales: layer_linear_scales(cx, l, "mlp.down_proj", fp8),
q_proj: layer_linear_weight(
cx,
l,
"self_attn.q_proj",
(config.hidden, config.hidden),
fp8,
),
q_proj_scales: layer_linear_scales(cx, l, "self_attn.q_proj", fp8),
k_proj: layer_linear_weight(
cx,
l,
"self_attn.k_proj",
(config.kv_dim(), config.hidden),
fp8,
),
k_proj_scales: layer_linear_scales(cx, l, "self_attn.k_proj", fp8),
v_proj: layer_linear_weight(
cx,
l,
"self_attn.v_proj",
(config.kv_dim(), config.hidden),
fp8,
),
v_proj_scales: layer_linear_scales(cx, l, "self_attn.v_proj", fp8),
o_proj: layer_linear_weight(
cx,
l,
"self_attn.o_proj",
(config.hidden, config.hidden),
fp8,
),
o_proj_scales: layer_linear_scales(cx, l, "self_attn.o_proj", fp8),
attn_rms: rms_norm(
cx,
config.hidden,
format!("model.layers.{l}.input_layernorm.weight"),
),
mlp_rms: rms_norm(
cx,
config.hidden,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
let (attn_out, k_cache_out, v_cache_out) = attention(
AttentionInputs {
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
},
self.config,
);
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
let x_mlp = self.mlp_rms.forward(x);
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
* linear_matmul(x_mlp, self.up, self.up_scales);
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
(x + mlp_out, k_cache_out, v_cache_out)
}
#[allow(dead_code)]
fn parameter_tensors(&self) -> Vec<GraphTensor> {
let mut tensors = vec![
self.up,
self.gate,
self.down,
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj,
];
for scales in [
self.up_scales,
self.gate_scales,
self.down_scales,
self.q_proj_scales,
self.k_proj_scales,
self.v_proj_scales,
self.o_proj_scales,
]
.into_iter()
.flatten()
{
tensors.push(scales.input);
tensors.push(scales.weight);
}
if let Some(weight) = self.attn_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.attn_rms.bias {
tensors.push(bias);
}
if let Some(weight) = self.mlp_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.mlp_rms.bias {
tensors.push(bias);
}
tensors
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn linear_weight(
cx: &mut Graph,
prefix: impl ToString,
@@ -325,6 +377,16 @@ fn linear_weight(
}
}
fn layer_linear_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
fp8: bool,
) -> GraphTensor {
linear_weight(cx, format!("model.layers.{layer}.{suffix}"), shape, fp8)
}
fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option<Fp8LinearScales> {
if !fp8 {
return None;
@@ -340,6 +402,27 @@ fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option
})
}
fn layer_linear_scales(
cx: &mut Graph,
layer: usize,
suffix: &str,
fp8: bool,
) -> Option<Fp8LinearScales> {
fp8_linear_scales(cx, format!("model.layers.{layer}.{suffix}"), fp8)
}
fn rms_norm(cx: &mut Graph, dim: usize, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(dim, Some(&weight_name.to_string()), None, false, 1e-5, cx)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor, hidden: usize) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * hidden).expand_dim(1, hidden)
+ token_ids.graph().arange(hidden).expand_dim(0, seq),
)
}
fn expand_scalar(scale: GraphTensor, like: GraphTensor) -> GraphTensor {
scale.expand_rhs(like.dims())
}
@@ -443,87 +526,3 @@ fn attention(
(attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
let (attn_out, k_cache_out, v_cache_out) = attention(
AttentionInputs {
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
},
self.config,
);
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
let x_mlp = self.mlp_rms.forward(x);
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
* linear_matmul(x_mlp, self.up, self.up_scales);
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
(x + mlp_out, k_cache_out, v_cache_out)
}
#[allow(dead_code)]
fn parameter_tensors(&self) -> Vec<GraphTensor> {
let mut tensors = vec![
self.up,
self.gate,
self.down,
self.q_proj,
self.k_proj,
self.v_proj,
self.o_proj,
];
for scales in [
self.up_scales,
self.gate_scales,
self.down_scales,
self.q_proj_scales,
self.k_proj_scales,
self.v_proj_scales,
self.o_proj_scales,
]
.into_iter()
.flatten()
{
tensors.push(scales.input);
tensors.push(scales.weight);
}
if let Some(weight) = self.attn_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.attn_rms.bias {
tensors.push(bias);
}
if let Some(weight) = self.mlp_rms.weight {
tensors.push(weight);
}
if let Some(bias) = self.mlp_rms.bias {
tensors.push(bias);
}
tensors
}
}

View File

@@ -241,7 +241,8 @@ fn main() {
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
// Re-initialize KV cache after search (search consumes buffers)
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();

View File

@@ -25,8 +25,8 @@ pub struct PagedKVCache {
impl PagedKVCache {
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
let mut k_caches = vec![];
let mut v_caches = vec![];
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM)));
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM)));
@@ -44,78 +44,11 @@ pub struct Llama {
impl Llama {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = vec![];
for l in 0..LAYERS {
layers.push(LlamaLayer {
up: cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
gate: cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist(),
down: cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist(),
q_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
k_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
v_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist(),
o_proj: cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, HIDDEN),
)
.persist(),
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
1e-5,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
1e-5,
cx,
),
});
}
Self {
embedding: cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist(),
layers,
lm_head: cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist(),
lm_norm: LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx),
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| LlamaLayer::init(cx, l)).collect(),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
lm_norm: rms_norm(cx, "model.norm.weight"),
}
}
@@ -141,12 +74,8 @@ impl Llama {
attn_mask: GraphTensor,
kv_cache: &PagedKVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = input.dims1();
let mut x = self.embedding.gather(
(input * HIDDEN).expand_dim(1, HIDDEN)
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut cache_outputs = vec![];
let mut x = token_embedding(self.embedding, input);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
x,
@@ -177,6 +106,99 @@ struct LlamaLayer {
mlp_rms: LayerNorm,
}
impl LlamaLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (HIDDEN, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, HIDDEN)),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = llama_rotary_embeddings(q, q_pos);
let k_rope = llama_rotary_embeddings(k, q_pos);
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
1e-5,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
@@ -264,44 +286,3 @@ fn paged_attention(
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
(attn_out, k_cache_out, v_cache_out)
}
impl LlamaLayer {
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
mut x: GraphTensor,
q_pos: GraphTensor,
scatter_idx: GraphTensor,
gather_idx: GraphTensor,
attn_mask: GraphTensor,
k_cache: GraphTensor,
v_cache: GraphTensor,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// Apply RoPE before scattering into cache
let q_rope = llama_rotary_embeddings(q, q_pos);
let k_rope = llama_rotary_embeddings(k, q_pos);
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
q_rope,
k_rope,
v,
k_cache,
v_cache,
scatter_idx,
gather_idx,
attn_mask,
);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}

View File

@@ -188,7 +188,8 @@ where
cx.set_dim('p', 0);
runtime.set_i32_data(input.id, vec![1; search_s]);
runtime.set_i32_data(token_ids.id, (0..search_s as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, CompileOptions::new(config.search_graphs));
let search_options = CompileOptions::default().search_graph_limit(config.search_graphs);
runtime = cx.search(runtime, search_options);
for i in 0..config.layers {
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);

View File

@@ -34,14 +34,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(layers);
let mut v_caches = Vec::with_capacity(layers);
for l in 0..layers {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -63,105 +65,10 @@ impl Qwen {
layers <= LAYERS,
"requested {layers} layers, but model has {LAYERS}"
);
let mut w = vec![];
for l in 0..layers {
let up = cx
.named_tensor(
format!("model.layers.{l}.mlp.up_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let gate = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_proj.weight"),
(INTERMEDIATE, HIDDEN),
)
.persist();
let down = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_proj.weight"),
(HIDDEN, INTERMEDIATE),
)
.persist();
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
w.push(QwenLayer {
up,
gate,
down,
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
});
}
let lm_norm = LayerNorm::new(
HIDDEN,
Some("model.norm.weight"),
None,
false,
RMS_NORM_EPS,
cx,
);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
layers: w,
lm_norm,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..layers).map(|l| QwenLayer::init(cx, l)).collect(),
lm_norm: rms_norm(cx, "model.norm.weight"),
}
}
@@ -172,11 +79,7 @@ impl Qwen {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(self.layers.len());
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -209,6 +112,90 @@ struct QwenLayer {
mlp_rms: LayerNorm,
}
impl QwenLayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
}
}
pub fn forward(
&self,
mut x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
let q_rope = qwen_rotary_embeddings(qk_norm(q, self.q_norm, N_HEADS), pos_ids, N_HEADS);
let k_rope =
qwen_rotary_embeddings(qk_norm(k, self.k_norm, N_KV_HEADS), pos_ids, N_KV_HEADS);
let (attn_out, k_cache_out, v_cache_out) =
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
RMS_NORM_EPS,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
/// Per-head RMS normalization for QK-norm.
/// Input: [seq, dim] where dim = n_heads * HEAD_DIM
/// split_dims to [seq, n_heads, HEAD_DIM], RMS norm over last axis, multiply by weight, merge back.
@@ -331,36 +318,3 @@ fn hlir_attention(
(out, k_cache_out, v_cache_out)
}
impl QwenLayer {
pub fn forward(
&self,
mut x: GraphTensor,
pos_ids: GraphTensor,
k_cache_in: GraphTensor,
v_cache_in: GraphTensor,
max_seq: usize,
) -> (GraphTensor, GraphTensor, GraphTensor) {
let x_attn = self.attn_rms.forward(x);
let q = x_attn.matmul(self.q_proj.t());
let k = x_attn.matmul(self.k_proj.t());
let v = x_attn.matmul(self.v_proj.t());
// QK-norm: per-head RMS normalization
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
// RoPE
let q_rope = qwen_rotary_embeddings(q_normed, pos_ids, N_HEADS);
let k_rope = qwen_rotary_embeddings(k_normed, pos_ids, N_KV_HEADS);
let (attn_out, k_cache_out, v_cache_out) =
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
x += attn_out.matmul(self.o_proj.t());
let x_mlp = self.mlp_rms.forward(x);
let mlp_out =
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
(x + mlp_out, k_cache_out, v_cache_out)
}
}

View File

@@ -82,7 +82,8 @@ fn main() {
runtime.set_data(input, vec![1; search_s]);
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
runtime = cx.search_with_rng(runtime, CompileOptions::new(search_graphs), &mut rng);
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
for i in 0..LAYERS {
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);

View File

@@ -29,17 +29,19 @@ pub struct KVCache {
impl KVCache {
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
let mut k_caches = vec![];
let mut v_caches = vec![];
let mut k_caches = Vec::with_capacity(LAYERS);
let mut v_caches = Vec::with_capacity(LAYERS);
for l in 0..LAYERS {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_KV_HEADS, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -58,111 +60,11 @@ pub struct Qwen3MoE {
impl Qwen3MoE {
pub fn init(cx: &mut Graph) -> Self {
let mut layers = vec![];
for l in 0..LAYERS {
let q_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_proj.weight"),
(Q_DIM, HIDDEN),
)
.persist();
let k_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let v_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.v_proj.weight"),
(KV_DIM, HIDDEN),
)
.persist();
let o_proj = cx
.named_tensor(
format!("model.layers.{l}.self_attn.o_proj.weight"),
(HIDDEN, Q_DIM),
)
.persist();
let q_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.q_norm.weight"),
HEAD_DIM,
)
.persist();
let k_norm = cx
.named_tensor(
format!("model.layers.{l}.self_attn.k_norm.weight"),
HEAD_DIM,
)
.persist();
let router = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate.weight"),
(NUM_EXPERTS, HIDDEN),
)
.persist();
let gate_up_weights = cx
.named_tensor(
format!("model.layers.{l}.mlp.gate_up_weights"),
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.persist();
let down_weights = cx
.named_tensor(
format!("model.layers.{l}.mlp.down_weights"),
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.persist();
layers.push(Qwen3MoELayer {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
attn_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.input_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
mlp_rms: LayerNorm::new(
HIDDEN,
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
None,
false,
RMS_NORM_EPS,
cx,
),
moe: QwenMoE {
router,
gate_up_weights: gate_up_weights.as_dtype(DType::Bf16),
down_weights: down_weights.as_dtype(DType::Bf16),
},
});
}
let lm_norm = LayerNorm::new(
HIDDEN,
Some("model.norm.weight"),
None,
false,
RMS_NORM_EPS,
cx,
);
let embedding = cx
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
.persist();
let lm_head = cx
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
.persist();
Self {
embedding,
layers,
lm_norm,
lm_head,
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
layers: (0..LAYERS).map(|l| Qwen3MoELayer::init(cx, l)).collect(),
lm_norm: rms_norm(cx, "model.norm.weight"),
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
}
}
@@ -172,11 +74,7 @@ impl Qwen3MoE {
pos_ids: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
let mut x = self.embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
);
let mut x = token_embedding(self.embedding, token_ids);
let mut cache_outputs = Vec::with_capacity(LAYERS);
for (i, layer) in self.layers.iter().enumerate() {
let (x_new, k_out, v_out) = layer.forward(
@@ -214,6 +112,39 @@ struct QwenMoE {
}
impl Qwen3MoELayer {
fn init(cx: &mut Graph, l: usize) -> Self {
Self {
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
mlp_rms: rms_norm(
cx,
format!("model.layers.{l}.post_attention_layernorm.weight"),
),
moe: QwenMoE {
router: layer_weight(cx, l, "mlp.gate", (NUM_EXPERTS, HIDDEN)),
gate_up_weights: layer_tensor(
cx,
l,
"mlp.gate_up_weights",
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
)
.as_dtype(DType::Bf16),
down_weights: layer_tensor(
cx,
l,
"mlp.down_weights",
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
)
.as_dtype(DType::Bf16),
},
}
}
pub fn forward(
&self,
mut x: GraphTensor,
@@ -247,6 +178,51 @@ impl Qwen3MoELayer {
}
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
fn layer_tensor(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
}
fn layer_weight(
cx: &mut Graph,
layer: usize,
suffix: &str,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
}
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
LayerNorm::new(
HIDDEN,
Some(&weight_name.to_string()),
None,
false,
RMS_NORM_EPS,
cx,
)
}
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
let seq = token_ids.dims1();
embedding.gather(
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
)
}
impl QwenMoE {
fn forward(&self, x: GraphTensor) -> GraphTensor {
let n = x.dims().len(); // 2 for [s, H]

View File

@@ -12,7 +12,7 @@ fn main() {
// Compile
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default());
// Set input tensors
rt.set_data(a, vec![1.0, 2.0, 3.0]);

View File

@@ -96,7 +96,8 @@ fn main() {
cx.set_dim('p', 0);
runtime.set_data(input, vec![1i32; max_prefill]);
runtime.set_data(pos_ids, (0..max_prefill as i32).collect::<Vec<_>>());
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
// Reset the KV caches and re-set the mel after search (which executes test runs).
for i in 0..N_TEXT_LAYER {

View File

@@ -33,6 +33,14 @@ fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
x.matmul(w.t())
}
fn persist(
cx: &mut Graph,
name: impl ToString,
shape: impl luminal::prelude::ToShape,
) -> GraphTensor {
cx.named_tensor(name, shape).persist()
}
/// 1D convolution with bias. Input: (ch_in, length). Weight: (ch_out, ch_in*kernel)
/// (HF stores it as (ch_out, ch_in, kernel) which flat-loads identically). Output: (ch_out, out_length).
fn conv1d_bias(
@@ -90,27 +98,13 @@ struct AttentionWeights {
impl AttentionWeights {
fn new(prefix: &str, dim: usize, cx: &mut Graph) -> Self {
Self {
q_proj: cx
.named_tensor(format!("{prefix}.q_proj.weight"), (dim, dim))
.persist(),
q_bias: cx
.named_tensor(format!("{prefix}.q_proj.bias"), dim)
.persist(),
k_proj: cx
.named_tensor(format!("{prefix}.k_proj.weight"), (dim, dim))
.persist(),
v_proj: cx
.named_tensor(format!("{prefix}.v_proj.weight"), (dim, dim))
.persist(),
v_bias: cx
.named_tensor(format!("{prefix}.v_proj.bias"), dim)
.persist(),
out_proj: cx
.named_tensor(format!("{prefix}.out_proj.weight"), (dim, dim))
.persist(),
out_bias: cx
.named_tensor(format!("{prefix}.out_proj.bias"), dim)
.persist(),
q_proj: persist(cx, format!("{prefix}.q_proj.weight"), (dim, dim)),
q_bias: persist(cx, format!("{prefix}.q_proj.bias"), dim),
k_proj: persist(cx, format!("{prefix}.k_proj.weight"), (dim, dim)),
v_proj: persist(cx, format!("{prefix}.v_proj.weight"), (dim, dim)),
v_bias: persist(cx, format!("{prefix}.v_proj.bias"), dim),
out_proj: persist(cx, format!("{prefix}.out_proj.weight"), (dim, dim)),
out_bias: persist(cx, format!("{prefix}.out_proj.bias"), dim),
}
}
}
@@ -125,6 +119,14 @@ fn merge_heads(x: GraphTensor) -> GraphTensor {
x.transpose(0, 1).merge_dims(1, 2)
}
fn embedding_lookup(embedding: GraphTensor, ids: GraphTensor) -> GraphTensor {
let seq = ids.dims1();
embedding.gather(
(ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
)
}
/// Encoder self-attention (full, non-causal). Input/output shape (seq, dim).
fn encoder_self_attention(x: GraphTensor, w: &AttentionWeights) -> GraphTensor {
let q = linear_with_bias(x, w.q_proj, w.q_bias);
@@ -239,18 +241,10 @@ impl EncoderLayer {
N_AUDIO_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_AUDIO_STATE)
.persist(),
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE)),
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM)),
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_AUDIO_STATE),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_AUDIO_STATE, cx),
}
}
@@ -295,18 +289,10 @@ impl DecoderLayer {
N_TEXT_STATE,
cx,
),
fc1: cx
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE))
.persist(),
fc1_b: cx
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
.persist(),
fc2: cx
.named_tensor(format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM))
.persist(),
fc2_b: cx
.named_tensor(format!("{prefix}.fc2.bias"), N_TEXT_STATE)
.persist(),
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE)),
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM)),
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_TEXT_STATE),
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_TEXT_STATE, cx),
}
}
@@ -346,14 +332,16 @@ impl KVCache {
let mut k_caches = Vec::with_capacity(N_TEXT_LAYER);
let mut v_caches = Vec::with_capacity(N_TEXT_LAYER);
for l in 0..N_TEXT_LAYER {
let k = cx
.named_tensor(format!("kv_cache.{l}.k"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
let v = cx
.named_tensor(format!("kv_cache.{l}.v"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
.persist();
k_caches.push(k);
v_caches.push(v);
k_caches.push(persist(
cx,
format!("kv_cache.{l}.k"),
(N_TEXT_HEAD, max_seq, HEAD_DIM),
));
v_caches.push(persist(
cx,
format!("kv_cache.{l}.v"),
(N_TEXT_HEAD, max_seq, HEAD_DIM),
));
}
Self {
k_caches,
@@ -376,27 +364,23 @@ pub struct WhisperEncoder {
impl WhisperEncoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
conv1_w: cx
.named_tensor("model.encoder.conv1.weight", (N_AUDIO_STATE, N_MELS * 3))
.persist(),
conv1_b: cx
.named_tensor("model.encoder.conv1.bias", N_AUDIO_STATE)
.persist(),
conv2_w: cx
.named_tensor(
"model.encoder.conv2.weight",
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
)
.persist(),
conv2_b: cx
.named_tensor("model.encoder.conv2.bias", N_AUDIO_STATE)
.persist(),
positional_embedding: cx
.named_tensor(
"model.encoder.embed_positions.weight",
(N_AUDIO_CTX, N_AUDIO_STATE),
)
.persist(),
conv1_w: persist(
cx,
"model.encoder.conv1.weight",
(N_AUDIO_STATE, N_MELS * 3),
),
conv1_b: persist(cx, "model.encoder.conv1.bias", N_AUDIO_STATE),
conv2_w: persist(
cx,
"model.encoder.conv2.weight",
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
),
conv2_b: persist(cx, "model.encoder.conv2.bias", N_AUDIO_STATE),
positional_embedding: persist(
cx,
"model.encoder.embed_positions.weight",
(N_AUDIO_CTX, N_AUDIO_STATE),
),
layers: (0..N_AUDIO_LAYER)
.map(|i| EncoderLayer::new(i, cx))
.collect(),
@@ -427,15 +411,16 @@ pub struct WhisperDecoder {
impl WhisperDecoder {
pub fn init(cx: &mut Graph) -> Self {
Self {
embed_tokens: cx
.named_tensor("model.decoder.embed_tokens.weight", (N_VOCAB, N_TEXT_STATE))
.persist(),
embed_positions: cx
.named_tensor(
"model.decoder.embed_positions.weight",
(N_TEXT_CTX, N_TEXT_STATE),
)
.persist(),
embed_tokens: persist(
cx,
"model.decoder.embed_tokens.weight",
(N_VOCAB, N_TEXT_STATE),
),
embed_positions: persist(
cx,
"model.decoder.embed_positions.weight",
(N_TEXT_CTX, N_TEXT_STATE),
),
layers: (0..N_TEXT_LAYER)
.map(|i| DecoderLayer::new(i, cx))
.collect(),
@@ -450,18 +435,8 @@ impl WhisperDecoder {
xa: GraphTensor,
kv_cache: &KVCache,
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
let seq = token_ids.dims1();
// Token embedding gather
let mut x = self.embed_tokens.gather(
(token_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ token_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
// Positional embedding gather (using pos_ids)
let pos_emb = self.embed_positions.gather(
(pos_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
+ pos_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
);
x += pos_emb;
let mut x = embedding_lookup(self.embed_tokens, token_ids);
x += embedding_lookup(self.embed_positions, pos_ids);
let mut cache_outputs = Vec::with_capacity(N_TEXT_LAYER);
for (i, layer) in self.layers.iter().enumerate() {

View File

@@ -335,7 +335,8 @@ fn main() {
println!("Compiling (search_graphs={search_graphs})...");
let t0 = Instant::now();
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
runtime = cx.search(runtime, search_options);
println!(" search took {:?}", t0.elapsed());
// Re-set anchors/strides/dfl/img after search (search may consume the inputs)

View File

@@ -174,7 +174,10 @@ pub fn compile_backend<Rt: Runtime + 'static>(
}
// Search
let mut rt = graph.search(rt, CompileOptions::new(args.search_iters));
let mut rt = graph.search(
rt,
CompileOptions::default().search_graph_limit(args.search_iters),
);
// Rebuild label map after search (graph may have changed)
let label_map = build_label_map(graph);

View File

@@ -463,7 +463,10 @@ pub(super) mod tests {
let c = func(a, b).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let lhs_values = lhs_transform(random_vec(a_shape.iter().copied().product()));
let rhs_values = rhs_transform(random_vec(b_shape.iter().copied().product()));

View File

@@ -968,7 +968,10 @@ mod tests {
let zeros = cx.iota(Expression::from(0usize), 6);
let inv = values.scatter(perm, zeros).cast(DType::F32).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(data.id, vec![0., 1., 2., 3., 4., 5.]);
rt.set_data(indexes.id, vec![5, 0, 3, 2]);
rt.set_data(perm.id, vec![3, 2, 4, 1, 5, 0]);
@@ -985,7 +988,10 @@ mod tests {
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![10., 20., 30.]);
rt.set_data(indexes.id, vec![1, 3, 4]);
rt.set_data(dest.id, vec![0., 0., 0., 0., 0.]);
@@ -1001,7 +1007,10 @@ mod tests {
let dest = cx.tensor(5);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![99.]);
rt.set_data(indexes.id, vec![2]);
rt.set_data(dest.id, vec![1., 2., 3., 4., 5.]);
@@ -1017,7 +1026,10 @@ mod tests {
let dest = cx.tensor(4);
let result = src.scatter(indexes, dest).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(src.id, vec![40., 30., 20., 10.]);
rt.set_data(indexes.id, vec![3, 2, 1, 0]);
rt.set_data(dest.id, vec![1., 2., 3., 4.]);
@@ -1045,7 +1057,10 @@ mod tests {
let repeated = (a.repeat((2, 2)) * 1.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(a.id, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
rt.execute(&cx.dyn_map);

View File

@@ -126,7 +126,10 @@ mod tests {
let b = func(&mut cx).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.execute(&cx.dyn_map);
@@ -211,7 +214,10 @@ mod tests {
let stacked = cx.stack(&[a, b, c], 0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let a_data = random_vec(6);
let b_data = random_vec(6);

View File

@@ -493,7 +493,10 @@ pub(super) mod tests {
let b = func(a).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
let v = random_vec(shape.iter().copied().product());
rt.set_data(a.id, v.clone());

View File

@@ -137,7 +137,9 @@ impl DimBucket {
/// Use the builder pattern to configure search parameters:
/// ```
/// use luminal::prelude::CompileOptions;
/// let opts = CompileOptions::new(5)
/// let opts = CompileOptions::default()
/// .search_graph_limit(5)
/// .search_time_limit(std::time::Duration::from_secs(30))
/// .generation_size(50)
/// .mutations(40)
/// .trials(15);
@@ -146,6 +148,8 @@ impl DimBucket {
pub struct CompileOptions {
/// Maximum number of graphs to evaluate during search.
pub limit: usize,
/// Maximum wall-clock time to spend searching.
pub search_time_limit: std::time::Duration,
/// Optional maximum runtime-specific intermediate memory, in bytes.
///
/// When this is `None`, search does not apply a memory cap. Runtimes that
@@ -163,8 +167,6 @@ pub struct CompileOptions {
/// Per-candidate profiling timeout. If a profile call reaches this budget,
/// that candidate is discarded and search continues.
pub profile_timeout: Option<std::time::Duration>,
/// Optional per-group search timeout.
pub group_timeout: Option<std::time::Duration>,
/// Optional profiling dimension overrides.
pub profile_dims: FxHashMap<char, usize>,
/// Bucket definitions per dynamic dimension. Dimensions without buckets use
@@ -173,20 +175,16 @@ pub struct CompileOptions {
}
impl CompileOptions {
/// Create compile options with the given search limit. Other fields use defaults.
pub fn new(limit: usize) -> Self {
Self {
limit,
max_memory_bytes: None,
generation_size: 10,
mutations: 10,
trials: 3,
keep_best: 1,
profile_timeout: Some(std::time::Duration::from_secs(1)),
group_timeout: None,
profile_dims: FxHashMap::default(),
dim_buckets: FxHashMap::default(),
}
/// Set the maximum number of graphs to evaluate during search.
pub fn search_graph_limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
/// Set the maximum wall-clock time to spend searching.
pub fn search_time_limit(mut self, search_time_limit: std::time::Duration) -> Self {
self.search_time_limit = search_time_limit;
self
}
/// Set a maximum intermediate memory budget in bytes.
@@ -235,12 +233,6 @@ impl CompileOptions {
self
}
/// Set an optional per-group search timeout.
pub fn group_timeout(mut self, group_timeout: std::time::Duration) -> Self {
self.group_timeout = Some(group_timeout);
self
}
/// Override a dynamic dimension value used during search profiling.
pub fn profile_dim(mut self, dim: char, value: usize) -> Self {
self.profile_dims.insert(dim, value);
@@ -261,7 +253,18 @@ impl CompileOptions {
impl Default for CompileOptions {
fn default() -> Self {
Self::new(1)
Self {
limit: 100,
search_time_limit: std::time::Duration::MAX,
max_memory_bytes: None,
generation_size: 10,
mutations: 10,
trials: 3,
keep_best: 1,
profile_timeout: Some(std::time::Duration::from_secs(1)),
profile_dims: FxHashMap::default(),
dim_buckets: FxHashMap::default(),
}
}
}
@@ -1302,10 +1305,18 @@ impl Graph {
"dim buckets must be configured in CompileOptions before build_search_space; search cannot change buckets after build",
);
let search_started_at = std::time::Instant::now();
if self.search_space_dim_buckets.is_empty() {
// No buckets: existing single-search path
let stitched =
self.search_single(&mut runtime, &options, rng, &self.dyn_map.clone(), None, 0);
let stitched = self.search_single(
&mut runtime,
&options,
rng,
&self.dyn_map.clone(),
None,
0,
search_started_at,
);
runtime.clear_intermediate_buffers();
runtime.load_llir(&stitched);
@@ -1339,6 +1350,7 @@ impl Graph {
&context.representative_dyn_map,
Some((combo_idx, n_combos)),
combo_idx,
search_started_at,
);
bucket_llirs.push((
context.bucket_indices,
@@ -1425,6 +1437,7 @@ impl Graph {
/// Run the genetic search and return the unrolled LLIR for the winning
/// genome. `bucket_progress`: if `Some((current_bucket_idx, total_buckets))`
/// adds a second "Bucket" progress bar.
#[allow(clippy::too_many_arguments)]
fn search_single<R: Runtime + 'static, G: rand::Rng>(
&mut self,
runtime: &mut R,
@@ -1433,6 +1446,7 @@ impl Graph {
dyn_map: &FxHashMap<char, usize>,
bucket_progress: Option<(usize, usize)>,
egraph_index: usize,
search_started_at: std::time::Instant,
) -> LLIRGraph {
let mut profile_dyn_map = dyn_map.clone();
for (&dim, &value) in &options.profile_dims {
@@ -1482,11 +1496,11 @@ impl Graph {
}
};
let group_start = std::time::Instant::now();
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
let mut list_cache = FxHashMap::default();
let mut expr_cache = FxHashMap::default();
runtime.clear_intermediate_buffers();
let search_time_limit_reached = || search_started_at.elapsed() >= options.search_time_limit;
let profile_timed_out = |elapsed: std::time::Duration| {
options
.profile_timeout
@@ -1561,11 +1575,8 @@ impl Graph {
break;
}
Ok(_) | Err(_) => {
if options
.group_timeout
.is_some_and(|timeout| group_start.elapsed() >= timeout)
{
panic!("Failed to find a viable initial genome before timeout");
if search_time_limit_reached() {
panic!("Failed to find a viable initial genome before search time limit");
}
list_cache.clear();
expr_cache.clear();
@@ -1586,10 +1597,7 @@ impl Graph {
let mut resample_generation = false;
while n_graphs < search_limit {
if options
.group_timeout
.is_some_and(|timeout| group_start.elapsed() >= timeout)
{
if search_time_limit_reached() {
break;
}
@@ -1623,10 +1631,7 @@ impl Graph {
let mut generation_found_non_timeout = false;
for genome in all_offspring {
if options
.group_timeout
.is_some_and(|timeout| group_start.elapsed() >= timeout)
{
if search_time_limit_reached() {
break;
}
n_graphs += 1;
@@ -2808,6 +2813,7 @@ mod tests {
use super::*;
use crate::egglog_utils::hash_egglog_normalized;
use crate::tests::{assert_close, random_vec};
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Default)]
struct TestMemoryRuntime;
@@ -2845,6 +2851,69 @@ mod tests {
}
}
static PROFILE_CALLS: AtomicUsize = AtomicUsize::new(0);
#[derive(Default)]
struct CountingRuntime;
impl Runtime for CountingRuntime {
type Ops = ();
type CompileArg = ();
type ExecReturn = ();
type ProfileMetric = usize;
fn initialize(_: Self::CompileArg) -> Self {
Self
}
fn load_llir(&mut self, _: &LLIRGraph) {}
fn execute(&mut self, _: &FxHashMap<char, usize>) -> Self::ExecReturn {}
fn profile(
&mut self,
_: &LLIRGraph,
_: &FxHashMap<char, usize>,
_: usize,
_: Option<std::time::Duration>,
) -> (Self::ProfileMetric, String) {
let count = PROFILE_CALLS.fetch_add(1, Ordering::SeqCst);
(count, format!("{count} ms"))
}
}
#[test]
fn compile_options_defaults_and_search_time_limit_builder() {
let opts = CompileOptions::default();
assert_eq!(opts.limit, 100);
assert_eq!(opts.search_time_limit, std::time::Duration::MAX);
let time_limit = std::time::Duration::from_millis(25);
let opts = CompileOptions::default()
.search_graph_limit(7)
.search_time_limit(time_limit);
assert_eq!(opts.limit, 7);
assert_eq!(opts.search_time_limit, time_limit);
}
#[test]
fn search_time_limit_stops_after_initial_viable_candidate() {
let mut cx = Graph::new();
let a = cx.tensor((4, 8));
let b = cx.tensor((8, 4));
let c = cx.tensor((4, 4));
let _ = (a.matmul(b) + c).relu().softmax(1).output();
cx.build_search_space::<CountingRuntime>(CompileOptions::default());
PROFILE_CALLS.store(0, Ordering::SeqCst);
let _ = cx.search(
CountingRuntime,
CompileOptions::default().search_time_limit(std::time::Duration::ZERO),
);
assert_eq!(PROFILE_CALLS.load(Ordering::SeqCst), 1);
}
#[test]
fn build_search_space_without_explicit_max_memory_has_no_cap() {
let mut cx = Graph::new();
@@ -2862,7 +2931,10 @@ mod tests {
let _ = cx.tensor(1).output();
cx.build_search_space::<TestMemoryRuntime>(CompileOptions::default().max_memory_bytes(0));
let _ = cx.search(TestMemoryRuntime, CompileOptions::new(1));
let _ = cx.search(
TestMemoryRuntime,
CompileOptions::default().search_graph_limit(1),
);
}
#[test]
@@ -2870,7 +2942,10 @@ mod tests {
let mut cx = Graph::new();
let _ = cx.tensor(1).output();
let _ = cx.compile(TestMemoryRuntime, CompileOptions::new(1));
let _ = cx.compile(
TestMemoryRuntime,
CompileOptions::default().search_graph_limit(1),
);
assert_eq!(cx.egraphs.len(), 1);
assert_eq!(cx.search_space_max_memory_bytes, None);
@@ -2908,7 +2983,7 @@ mod tests {
let mut rng = rand::rng();
let _ = cx.search_with_rng(
TestMemoryRuntime,
CompileOptions::new(1).dim_buckets(
CompileOptions::default().search_graph_limit(1).dim_buckets(
's',
&[DimBucket::new(1, 1), DimBucket::new(2, 4).representative(4)],
),
@@ -3019,7 +3094,7 @@ mod tests {
let vals = random_vec(8);
let mut rt = NativeRuntime::default();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.set_data(x.id, vals.clone());
rt.execute(&cx.dyn_map);
@@ -3056,7 +3131,7 @@ mod tests {
let vals = random_vec(8);
let mut rt = NativeRuntime::default();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
rt = cx.search(rt, CompileOptions::new(1));
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
rt.set_data(x.id, vals.clone());
rt.execute(&cx.dyn_map);

View File

@@ -24,7 +24,7 @@ proptest! {
let d = (b * c / e).sin().output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
rt.set_data(b.id, vals.clone());
rt.set_data(c.id, vals.clone());
@@ -58,7 +58,7 @@ proptest! {
let a = b.matmul(c).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
let lhs = lhs.into_iter().take(m * k).collect::<Vec<f32>>();
let rhs = rhs.into_iter().take(k * n).collect::<Vec<f32>>();
rt.set_data(b.id, lhs.clone());
@@ -82,7 +82,7 @@ proptest! {
let a = cx.tensor((2, 2));
let b = (a.permute((1, 0)) * 1.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
rt.set_data(a.id, values.clone());
rt.execute(&cx.dyn_map);
@@ -99,7 +99,7 @@ proptest! {
let mask = a.ge(kth_largest.expand_dim(1, cols)).cast(crate::dtype::DType::F32);
let filtered = (a * mask).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
let values = values.into_iter().take(rows * cols).collect::<Vec<f32>>();
rt.set_data(a.id, values.clone());
rt.execute(&cx.dyn_map);
@@ -465,7 +465,10 @@ fn test_inputs_consumed_after_execute() {
let a = cx.tensor(3);
let _b = (a * 2.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
rt.execute(&cx.dyn_map);
// Second execute should panic — input 'a' was consumed
@@ -481,7 +484,10 @@ fn test_passthrough_preserves_weights() {
w.persist();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
// Iteration 1
rt.set_data(w.id, vec![1.0, 2.0, 3.0]);
@@ -503,7 +509,10 @@ fn test_only_outputs_remain() {
let a = cx.tensor(3);
let _b = (a * 2.0).output();
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = cx.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
rt.execute(&cx.dyn_map);
let output_count = rt
@@ -556,7 +565,10 @@ fn integration_auto_loop_rolling_matches_reference_native_runtime() {
let (mut graph, input_id, weight_ids, output_id) = build_repeated_block_graph(layers, width);
graph.build_search_space::<NativeRuntime>(CompileOptions::default());
let mut rt = graph.search(NativeRuntime::default(), CompileOptions::new(1));
let mut rt = graph.search(
NativeRuntime::default(),
CompileOptions::default().search_graph_limit(1),
);
rt.set_data(input_id, input);
for (node, data) in weight_ids.iter().zip(weights.iter()) {
rt.set_data(*node, data.clone());