mirror of
https://github.com/jafioti/luminal.git
synced 2026-06-01 21:49:47 +09:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75e4e6be0a |
@@ -21,8 +21,7 @@ let b = cx.tensor((1, 4));
|
||||
let c = a.matmul(b).output();
|
||||
|
||||
// Compile
|
||||
cx.build_search_space::<NativeRuntime>();
|
||||
let mut rt = cx.search(NativeRuntime::default(), 1);
|
||||
let mut rt = cx.compile(NativeRuntime::default(), CompileOptions::default());
|
||||
|
||||
// Set input tensors
|
||||
rt.set_data(a, vec![1.0, 2.0, 3.0]);
|
||||
|
||||
@@ -50,7 +50,7 @@ fn run_metal_pattern_benchmark(
|
||||
}
|
||||
}
|
||||
|
||||
let mut rt = cx.search(rt, CompileOptions::new(5));
|
||||
let mut rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
|
||||
let mut bench_metrics = None;
|
||||
|
||||
@@ -50,7 +50,7 @@ fn prepare_and_search(cx: &mut Graph, input_sizes: &[(NodeIndex, usize)]) -> Opt
|
||||
rt.set_data(*node, &data);
|
||||
}
|
||||
|
||||
let rt = cx.search(rt, CompileOptions::new(5));
|
||||
let rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
Some(PreparedBench {
|
||||
rt,
|
||||
|
||||
@@ -499,7 +499,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result1 = rt.get_f32(c);
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -531,7 +531,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
let mut results = Vec::new();
|
||||
for _ in 0..5 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -569,7 +569,7 @@ mod tests {
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.set_dim('s', size);
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a
|
||||
.iter()
|
||||
@@ -611,7 +611,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let expected: Vec<f32> = data_a.iter().zip(&data_b).map(|(a, b)| a + b).collect();
|
||||
let eps = dtype_epsilon(luminal::dtype::DType::F32);
|
||||
@@ -642,7 +642,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
for _ in 0..10 {
|
||||
rt.execute(&cx.dyn_map);
|
||||
}
|
||||
@@ -675,7 +675,7 @@ mod tests {
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Initial execution
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1556,37 +1556,9 @@ impl Runtime for CudaRuntime {
|
||||
self.profiling = false;
|
||||
let duration = durations.iter().sum::<std::time::Duration>() / durations.len() as u32;
|
||||
|
||||
let total_bytes: usize = self
|
||||
.last_kernel_stats
|
||||
.iter()
|
||||
.map(|s| s.bytes_loaded + s.bytes_stored)
|
||||
.sum::<usize>();
|
||||
let total_flops: usize = self
|
||||
.last_kernel_stats
|
||||
.iter()
|
||||
.map(|s| s.flops)
|
||||
.sum::<usize>();
|
||||
let aggregate_bw = if self.last_total_time_us > 0.0 {
|
||||
(total_bytes as f64) / (self.last_total_time_us * 1e-6) / 1e9
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let aggregate_tf = if self.last_total_time_us > 0.0 {
|
||||
(total_flops as f64) / (self.last_total_time_us * 1e-6) / 1e12
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let peak_bw = crate::cuda_bandwidth_gbps(self.cuda_stream.context());
|
||||
let peak_tf = crate::cuda_compute_f32_tflops(self.cuda_stream.context());
|
||||
let mbu = peak_bw.map(|p| aggregate_bw / p as f64);
|
||||
let mfu = peak_tf.map(|p| aggregate_tf / p as f64);
|
||||
|
||||
let duration_str = format_duration_precise(&duration);
|
||||
let mbu_str = mbu.map_or("-".to_string(), |v| format!("{:.1}%", v * 100.0));
|
||||
let mfu_str = mfu.map_or("-".to_string(), |v| format!("{:.1}%", v * 100.0));
|
||||
let display = format!(
|
||||
"{duration_str} | MBU: {mbu_str} | MFU: {mfu_str} [KRN: {} HOST: {}]",
|
||||
"{duration_str} | [KRN: {} HOST: {}]",
|
||||
llir_graph
|
||||
.node_weights()
|
||||
.filter(|n| n.to_dialect::<dyn KernelOp>().is_some())
|
||||
@@ -2079,18 +2051,10 @@ impl CudaRuntime {
|
||||
if !self.last_kernel_stats.is_empty() {
|
||||
println!("\n=== Kernel Execution Statistics ===\n");
|
||||
println!(
|
||||
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
"Kernel",
|
||||
"Time (us)",
|
||||
"Loaded",
|
||||
"Stored",
|
||||
"Agg FLOPS",
|
||||
"BW (GB/s)",
|
||||
"TFLOPS",
|
||||
"MBU",
|
||||
"MFU"
|
||||
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12}",
|
||||
"Kernel", "Time (us)", "Loaded", "Stored", "Agg FLOPS", "BW (GB/s)", "TFLOPS"
|
||||
);
|
||||
println!("{}", "-".repeat(116));
|
||||
println!("{}", "-".repeat(92));
|
||||
for s in &self.last_kernel_stats {
|
||||
self.print_stat_row(
|
||||
s.name,
|
||||
@@ -2101,29 +2065,20 @@ impl CudaRuntime {
|
||||
s.flops,
|
||||
s.bandwidth_gbps,
|
||||
s.tflops,
|
||||
peak_bw,
|
||||
peak_tf,
|
||||
);
|
||||
}
|
||||
println!("{}", "-".repeat(116));
|
||||
println!("{}", "-".repeat(92));
|
||||
}
|
||||
|
||||
// Print aggregate stats
|
||||
println!("\n=== Aggregate Statistics ===\n");
|
||||
println!(
|
||||
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
"", "Time (us)", "Loaded", "Stored", "Agg FLOPS", "BW (GB/s)", "TFLOPS", "MBU", "MFU"
|
||||
"{:<20} {:>12} {:>12} {:>12} {:>12} {:>12} {:>12}",
|
||||
"", "Time (us)", "Loaded", "Stored", "Agg FLOPS", "BW (GB/s)", "TFLOPS"
|
||||
);
|
||||
println!("{}", "-".repeat(116));
|
||||
let (mbu, mfu) = match (peak_bw, peak_tf) {
|
||||
(Some(pb), Some(pt)) => (
|
||||
format!("{:.1}%", aggregate_bw / pb as f64 * 100.0),
|
||||
format!("{:.1}%", aggregate_tf / pt as f64 * 100.0),
|
||||
),
|
||||
_ => ("-".into(), "-".into()),
|
||||
};
|
||||
println!("{}", "-".repeat(92));
|
||||
println!(
|
||||
"{:<20} {:>12.2} {:>12} {:>12} {:>12} {:>12} {:>12} {:>8} {:>8}",
|
||||
"{:<20} {:>12.2} {:>12} {:>12} {:>12} {:>12} {:>12}",
|
||||
"Total",
|
||||
self.last_total_time_us,
|
||||
format_size(total_bytes_loaded),
|
||||
@@ -2131,8 +2086,6 @@ impl CudaRuntime {
|
||||
format_flops(total_flops),
|
||||
format!("{:.2}", aggregate_bw),
|
||||
format!("{:.4}", aggregate_tf),
|
||||
mbu,
|
||||
mfu
|
||||
);
|
||||
|
||||
if let (Some(pb), Some(pt)) = (peak_bw, peak_tf) {
|
||||
@@ -2152,8 +2105,6 @@ impl CudaRuntime {
|
||||
flops: usize,
|
||||
bw: f64,
|
||||
tf: f64,
|
||||
peak_bw: Option<usize>,
|
||||
peak_tf: Option<usize>,
|
||||
) {
|
||||
let total = loaded + stored;
|
||||
let ld = if loaded > 0 {
|
||||
@@ -2181,21 +2132,13 @@ impl CudaRuntime {
|
||||
} else {
|
||||
"-".into()
|
||||
};
|
||||
let mbu = peak_bw
|
||||
.filter(|_| total > 0)
|
||||
.map(|p| format!("{:.1}%", bw / p as f64 * 100.0))
|
||||
.unwrap_or("-".into());
|
||||
let mfu = peak_tf
|
||||
.filter(|_| flops > 0)
|
||||
.map(|p| format!("{:.1}%", tf / p as f64 * 100.0))
|
||||
.unwrap_or("-".into());
|
||||
|
||||
match count {
|
||||
Some(c) => println!(
|
||||
"{name:<20} {time_us:>12.2} {c:>8} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12} {mbu:>8} {mfu:>8}"
|
||||
"{name:<20} {time_us:>12.2} {c:>8} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12}"
|
||||
),
|
||||
None => println!(
|
||||
"{name:<20} {time_us:>12.2} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12} {mbu:>8} {mfu:>8}"
|
||||
"{name:<20} {time_us:>12.2} {ld:>12} {st:>12} {fl:>12} {bw_s:>12} {tf_s:>12}"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,11 @@ fn test_bucket_dispatch_simple() {
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Test bucket 1: s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -91,7 +95,11 @@ fn test_bucket_matmul_dynamic() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute at s=1
|
||||
cx.set_dim('s', 1);
|
||||
@@ -146,7 +154,11 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let input_data = random_f32_vec(12, seed, -1.0, 1.0);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
let mut rng1 = SmallRng::seed_from_u64(seed);
|
||||
rt1 = cx1.search_with_rng(rt1, CompileOptions::new(5), &mut rng1);
|
||||
rt1 = cx1.search_with_rng(
|
||||
rt1,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng1,
|
||||
);
|
||||
rt1.set_data(a1, input_data.clone());
|
||||
rt1.execute(&cx1.dyn_map);
|
||||
let result_unbucketed = rt1.get_f32(b1);
|
||||
@@ -158,7 +170,11 @@ fn test_bucket_results_match_unbucketed() {
|
||||
let mut rt2 = CudaRuntime::initialize(stream.clone());
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
let mut rng2 = SmallRng::seed_from_u64(seed);
|
||||
rt2 = cx2.search_with_rng(rt2, CompileOptions::new(5), &mut rng2);
|
||||
rt2 = cx2.search_with_rng(
|
||||
rt2,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng2,
|
||||
);
|
||||
rt2.set_data(a2, input_data.clone());
|
||||
rt2.execute(&cx2.dyn_map);
|
||||
let result_bucketed = rt2.get_f32(b2);
|
||||
@@ -186,7 +202,11 @@ fn test_bucket_out_of_range_panics() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// s=10 is outside all buckets — should panic
|
||||
cx.set_dim('s', 10);
|
||||
@@ -211,7 +231,11 @@ fn test_bucket_no_buckets_backward_compat() {
|
||||
let input_data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
rt.set_data(a, input_data.clone());
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -257,7 +281,11 @@ fn test_bucket_switch_preserves_weights() {
|
||||
rt.set_data(b_tensor, b_data.clone());
|
||||
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute with bucket 1 (s=1)
|
||||
cx.set_dim('s', 1);
|
||||
@@ -311,7 +339,11 @@ fn test_bucket_multiple_executions_same_bucket() {
|
||||
cx.set_dim('s', 1);
|
||||
rt.set_data(a, vec![1.0f32; 4]);
|
||||
let mut rng = SmallRng::seed_from_u64(42);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(3), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(3),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Execute at different sizes within the same bucket
|
||||
for s in [1, 2, 4, 8] {
|
||||
|
||||
@@ -307,7 +307,7 @@ fn test_scatter_kv_cache_roundtrip() {
|
||||
rt.set_data(src, vec![10.0f32]);
|
||||
rt.set_data(indexes, vec![0i32]);
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Print and verify which scatter variant was selected
|
||||
let scatter_names: Vec<_> = rt
|
||||
@@ -427,7 +427,11 @@ fn test_scatter_dual_cache() {
|
||||
|
||||
// Use seeded search for deterministic variant selection.
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(5), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(5),
|
||||
&mut rng,
|
||||
);
|
||||
|
||||
// Print and verify selected variants
|
||||
let scatter_names: Vec<_> = rt
|
||||
@@ -554,7 +558,11 @@ fn test_scatter_rows_dynamic_prefill_roundtrip() {
|
||||
rt.set_data(gather_idx, scatter);
|
||||
rt.set_data(cache, (0..SLOTS * D).map(|i| i as f32).collect::<Vec<_>>());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_eq!(rt.get_f32(gathered), expected_gather);
|
||||
@@ -763,7 +771,11 @@ fn test_tiny_gqa_attention_batched_matches_sequential_prefill() {
|
||||
rt.set_data(k_cache, zero_k.clone());
|
||||
rt.set_data(v_cache, zero_v.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let batched_attn = rt.get_f32(attn_out);
|
||||
let batched_k = rt.get_f32(k_out);
|
||||
@@ -865,7 +877,11 @@ fn test_original_gqa_attention_batched_matches_sequential_prefill() {
|
||||
rt.set_data(k_cache, zero_k.clone());
|
||||
rt.set_data(v_cache, zero_v.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let batched_attn = rt.get_f32(attn_out);
|
||||
let batched_k = rt.get_f32(k_out);
|
||||
@@ -937,7 +953,11 @@ fn test_dynamic_expanded_causal_mask_softmax() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(mask, mask_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(weights);
|
||||
|
||||
@@ -1007,7 +1027,11 @@ fn test_tiny_gqa_value_matmul_with_expanded_kv() {
|
||||
rt.set_data(v_full, v_data.clone());
|
||||
rt.set_data(mask, mask_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1073,7 +1097,11 @@ fn test_broadcast_merge_gqa_value_matmul_matches_cpu() {
|
||||
rt.set_data(v_full, v_data.clone());
|
||||
rt.set_data(weights, weights_data);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1124,7 +1152,11 @@ fn test_transpose_merge_split_roundtrip_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(x, x_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(roundtrip);
|
||||
|
||||
@@ -1171,7 +1203,11 @@ fn test_batched_moe_x_expand_matmul_matches_cpu() {
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(w, w_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1220,7 +1256,11 @@ fn test_batched_topk_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(topk);
|
||||
|
||||
@@ -1259,7 +1299,11 @@ fn test_batched_argsort_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(argsort);
|
||||
|
||||
@@ -1299,7 +1343,11 @@ fn test_dynamic_3d_sum_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
@@ -1356,7 +1404,11 @@ fn test_batched_argsort_ranks_axis1_matches_cpu() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(routing, routing_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(ranks);
|
||||
|
||||
@@ -1395,7 +1447,11 @@ fn test_dynamic_3d_flat_index_iota_rows() {
|
||||
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(idx);
|
||||
|
||||
@@ -1438,7 +1494,11 @@ fn test_dynamic_2d_to_3d_gather_rows() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(data, data_values.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_i32(out);
|
||||
|
||||
@@ -1490,7 +1550,11 @@ fn test_batched_gather_experts_matches_cpu() {
|
||||
rt.set_data(topk, topk_data.clone());
|
||||
rt.set_data(weights, weights_data.clone());
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(0);
|
||||
rt = cx.search_with_rng(rt, CompileOptions::new(10), &mut rng);
|
||||
rt = cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(10),
|
||||
&mut rng,
|
||||
);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(out);
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ fn run_reference_attention(
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
rt.set_data(v_t, v.to_vec());
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
|
||||
rt.set_data(q_t, q.to_vec());
|
||||
rt.set_data(k_t, k.to_vec());
|
||||
|
||||
@@ -61,7 +61,7 @@ fn generic_matmul_executes_noncontiguous_merged_head_projection() {
|
||||
rt.set_data(attn, attn_data.as_slice());
|
||||
rt.set_data(weight, weight_data.as_slice());
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
assert!(
|
||||
rt.kernel_names().contains(&"GenericMatmul"),
|
||||
"expected GenericMatmul to be selected, kernels: {:?}",
|
||||
|
||||
@@ -95,7 +95,7 @@ fn fuzz_mlp(seq: usize, hidden: usize, intermediate: usize, seed: u64) {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -156,7 +156,7 @@ fn fuzz_norm_proj(seq: usize, hidden: usize, proj_dim: usize, eps: f32, seed: u6
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(proj_w, proj_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -245,7 +245,7 @@ fn fuzz_layer_no_attn(
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -330,7 +330,7 @@ fn fuzz_mlp_hlir_only(seq: usize, hidden: usize, intermediate: usize, seed: u64)
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -518,7 +518,7 @@ mod gemma {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -655,7 +655,7 @@ mod qwen {
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(norm_w, norm_data.clone());
|
||||
rt.set_data(embedding, emb_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ fn run_argsort_test(rows: usize, cols: usize, seed: u64) {
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
rt.set_data(input, data);
|
||||
rt = cx.search(rt, CompileOptions::new(10));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let out_dim0 = rt.get_i32(sorted_dim0.id);
|
||||
let out_dim1 = rt.get_i32(sorted_dim1.id);
|
||||
@@ -600,7 +600,7 @@ fn run_embed_test(vocab_size: usize, embed_dim: usize, seq_len: usize, seed: u64
|
||||
|
||||
rt.set_data(token_ids, token_data.clone());
|
||||
rt.set_data(embed_table, embed_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(output);
|
||||
|
||||
@@ -31,7 +31,7 @@ pub fn kernel_add_bandwidth_test() {
|
||||
let mut rt = CudaRuntime::initialize(stream.clone());
|
||||
rt.set_data(a, data_a.clone());
|
||||
rt.set_data(b, data_b.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
// Warm up
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -233,7 +233,9 @@ fn run_qwen_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
rt.set_data(model.router, router_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, CompileOptions::new(10));
|
||||
rt = model
|
||||
.graph
|
||||
.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
rt.get_f32(model.output.id)
|
||||
@@ -278,7 +280,9 @@ fn run_gemma_moe(include_glumoe: bool) -> Vec<f32> {
|
||||
rt.set_data(model.per_expert_scale, per_expert_scale_data);
|
||||
rt.set_data(model.gate_up_weights, gate_up_data);
|
||||
rt.set_data(model.down_weights, down_data);
|
||||
rt = model.graph.search(rt, CompileOptions::new(10));
|
||||
rt = model
|
||||
.graph
|
||||
.search(rt, CompileOptions::default().search_graph_limit(10));
|
||||
rt.execute(&model.graph.dyn_map);
|
||||
|
||||
rt.get_f32(model.output.id)
|
||||
|
||||
@@ -50,7 +50,7 @@ fn rope_matches_cpu_reference() {
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
@@ -98,7 +98,7 @@ fn rope_flux2_shape() {
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt.set_data(cos, cos_data.clone());
|
||||
rt.set_data(sin, sin_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let got = rt.get_f32(y.id);
|
||||
|
||||
|
||||
@@ -280,7 +280,7 @@ fn test_mini_transformer_layer() {
|
||||
|
||||
// Use minimal search iterations to avoid excessive graph rewriting
|
||||
// which can cause float drift through softmax/RMSNorm reordering
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -316,7 +316,7 @@ fn test_mini_transformer_two_layers() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -372,7 +372,7 @@ fn test_transformer_multi_seed() {
|
||||
rt.set_data(*tensor, data.clone());
|
||||
}
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -404,7 +404,7 @@ fn test_rms_norm_cuda() {
|
||||
.collect();
|
||||
rt.set_data(input, input_data.clone());
|
||||
rt.set_data(weight, weight_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -447,7 +447,7 @@ fn test_self_attention_cuda() {
|
||||
rt.set_data(wk, wk_data.clone());
|
||||
rt.set_data(wv, wv_data.clone());
|
||||
rt.set_data(wo, wo_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -491,7 +491,7 @@ fn test_swiglu_mlp_cuda() {
|
||||
rt.set_data(w_gate, gate_data.clone());
|
||||
rt.set_data(w_up, up_data.clone());
|
||||
rt.set_data(w_down, down_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -530,7 +530,7 @@ fn test_rolled_chained_scalar_muls() {
|
||||
let mut rt = CudaRuntime::initialize(stream);
|
||||
let x_data = random_f32_vec(4 * 32, 101, -0.5, 0.5);
|
||||
rt.set_data(x, x_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(out);
|
||||
|
||||
@@ -331,7 +331,7 @@ pub fn fuzz_cuda_search_space_equivalence(
|
||||
let mut native_rng = StdRng::seed_from_u64(config.seed);
|
||||
let mut native_rt = cx.search_with_rng(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::new(1),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
&mut native_rng,
|
||||
);
|
||||
for input in inputs {
|
||||
@@ -701,7 +701,7 @@ pub fn test_unary_cuda<T: TestDType>(
|
||||
|
||||
let input_data = generator(n_elements, seed);
|
||||
rt.set_data(a, input_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, b.id);
|
||||
@@ -776,7 +776,7 @@ pub fn test_binary_cuda<T: TestDType>(
|
||||
let b_data = b_generator(b_elements, seed.wrapping_add(1));
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = T::get_from_runtime(&rt, c.id);
|
||||
@@ -844,7 +844,7 @@ pub fn test_mod(
|
||||
let b_data = random_f32_vec(b_elements, seed.wrapping_add(1), 0.1, 0.5);
|
||||
rt.set_data(a, a_data.clone());
|
||||
rt.set_data(b, b_data.clone());
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
let result = rt.get_f32(c);
|
||||
|
||||
@@ -503,7 +503,8 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
runtime.set_data(scatter_idx_t, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(gather_idx_t, (0..search_c as i32).collect::<Vec<_>>());
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, CompileOptions::new(SEARCH_GRAPHS));
|
||||
let search_options = CompileOptions::default().search_graph_limit(SEARCH_GRAPHS);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
|
||||
@@ -41,7 +41,11 @@ fn bytes_of<T: bytemuck::NoUninit>(values: &[T]) -> Vec<u8> {
|
||||
|
||||
fn search_candidates(cx: &mut Graph, rt: MetalRuntime, limit: usize) -> MetalRuntime {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
cx.search_with_rng(rt, CompileOptions::new(limit), &mut rng)
|
||||
cx.search_with_rng(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(limit),
|
||||
&mut rng,
|
||||
)
|
||||
}
|
||||
|
||||
fn egraph_has_op(cx: &Graph, op_name: &str) -> bool {
|
||||
@@ -301,7 +305,7 @@ fn dynamic_dim_sum_reduce_runs() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -322,7 +326,7 @@ fn metal_bucketed_dynamic_dim_dispatches_correct_graph() {
|
||||
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, vec![1.0f32; 4]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
|
||||
cx.set_dim('s', 1);
|
||||
let s1_input = vec![1.0, 2.0, 3.0, 4.0];
|
||||
@@ -350,7 +354,7 @@ fn metal_int_arithmetic_preserves_large_values() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(token, &[16_385i32]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -373,7 +377,7 @@ proptest! {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -395,7 +399,7 @@ proptest! {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -417,7 +421,7 @@ proptest! {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
let input_values: Vec<f32> = values.into_iter().take(len).collect();
|
||||
rt.set_data(input, &input_values);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -449,7 +453,7 @@ fn metal_simple_add() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -469,7 +473,7 @@ fn metal_simple_mul() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -487,7 +491,7 @@ fn metal_simple_exp2() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[0.0, 1.0, 2.0, 3.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -504,7 +508,7 @@ fn metal_simple_log2() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 4.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -529,7 +533,7 @@ fn metal_simple_sin() {
|
||||
3.0 * std::f32::consts::FRAC_PI_2,
|
||||
],
|
||||
);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -546,7 +550,7 @@ fn metal_simple_sqrt() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 4.0, 9.0, 16.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -563,7 +567,7 @@ fn metal_simple_recip() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, 2.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -582,7 +586,7 @@ fn metal_simple_mod() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[7.0, 10.0, 15.0, 8.5]);
|
||||
rt.set_data(b, &[3.0, 4.0, 6.0, 2.5]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -601,7 +605,7 @@ fn metal_simple_less_than() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 5.0, 3.0, 4.0]);
|
||||
rt.set_data(b, &[2.0, 3.0, 3.0, 5.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -621,7 +625,7 @@ fn metal_simple_sum_reduce() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
// [[1,2,3,4], [5,6,7,8]] -> [10, 26]
|
||||
rt.set_data(input, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -640,7 +644,7 @@ fn metal_simple_max_reduce() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
// [[1,4,2,3], [8,5,7,6]] -> [4, 8]
|
||||
rt.set_data(input, &[1.0, 4.0, 2.0, 3.0, 8.0, 5.0, 7.0, 6.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(5));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(5));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -657,7 +661,7 @@ fn metal_f16_cast_roundtrip() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(input, &[1.0, -2.5, 3.25, 4.75]);
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -678,7 +682,7 @@ fn metal_f16_intermediate_add_roundtrip() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(a, &[1.0, 2.0, -3.0, 4.5]);
|
||||
rt.set_data(b, &[0.5, -1.0, 3.0, 0.25]);
|
||||
rt = cx.search(rt, CompileOptions::new(3));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(3));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1026,7 +1030,7 @@ fn metal_rms_norm() {
|
||||
|
||||
rt.set_data(input, &input_data);
|
||||
rt.set_data(weight, &weight_data);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1066,7 +1070,7 @@ fn metal_self_attention() {
|
||||
rt.set_data(wk, &wk_data);
|
||||
rt.set_data(wv, &wv_data);
|
||||
rt.set_data(wo, &wo_data);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1125,7 +1129,7 @@ fn metal_self_attention_f16_weights() {
|
||||
rt.set_data(wk, to_f16_vec(&wk_data));
|
||||
rt.set_data(wv, to_f16_vec(&wv_data));
|
||||
rt.set_data(wo, to_f16_vec(&wo_data));
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1169,7 +1173,7 @@ fn metal_swiglu_mlp() {
|
||||
rt.set_data(w_gate, &gate_data);
|
||||
rt.set_data(w_up, &up_data);
|
||||
rt.set_data(w_down, &down_data);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1219,7 +1223,7 @@ fn metal_mini_transformer_layer() {
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1285,7 +1289,7 @@ fn metal_mini_transformer_layer_f16_intermediate() {
|
||||
for (tensor, data) in &weight_data {
|
||||
rt.set_data(*tensor, data);
|
||||
}
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1327,7 +1331,7 @@ fn test_scatter_basic() {
|
||||
rt.set_data(src, &[10.0, 20.0, 30.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0, 4.0]);
|
||||
rt.set_data(dest, &[0.0, 0.0, 0.0, 0.0, 0.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1349,7 +1353,7 @@ fn test_scatter_buffer_roundtrip() {
|
||||
rt.set_data(src, &[0.0]);
|
||||
rt.set_data(indexes, &[0.0]);
|
||||
rt.set_zeros(cache, 4 * std::mem::size_of::<f32>());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
|
||||
for (pos, value, expected) in [
|
||||
(0, 10.0, [10.0, 0.0, 0.0, 0.0]),
|
||||
@@ -1383,7 +1387,7 @@ fn test_load_safetensors_f32_survives_search_and_overrides_input_data() {
|
||||
rt.set_data(weights, &[99.0, 99.0, 99.0]);
|
||||
rt.set_data(bias, &[0.5, 1.0, -1.5]);
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1448,7 +1452,7 @@ fn test_load_safetensors_converts_supported_float_dtypes() {
|
||||
cx.build_search_space::<MetalRuntime>(CompileOptions::default());
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.load_safetensors(&cx, path.to_str().unwrap());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1475,7 +1479,7 @@ fn test_gather_noncontiguous_data_uses_data_shape() {
|
||||
&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
|
||||
);
|
||||
rt.set_data(indexes, &[0.0, 3.0, 4.0, 7.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1495,7 +1499,7 @@ fn test_scatter_into_nonzero_dest() {
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[2f32]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1522,7 +1526,7 @@ fn test_scatter_no_copy_remove_buffer_aliases_dest() {
|
||||
rt.set_data(src, &[7.0, 8.0]);
|
||||
rt.set_data(indexes, &[1.0, 3.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0, 50.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1551,7 +1555,7 @@ fn test_scatter_no_copy_handles_2d_destination() {
|
||||
rt.set_data(src, &[9.0, 8.0]);
|
||||
rt.set_data(indexes, &[2.0, 4.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1578,7 +1582,7 @@ fn test_scatter_no_copy_not_selected_when_dest_has_another_consumer() {
|
||||
rt.set_data(src, &[99.0]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt.set_data(dest, &[10.0, 20.0, 30.0, 40.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
let kernels = rt.debug_kernel_ops();
|
||||
assert!(
|
||||
!kernels.iter().any(|k| k.contains("MetalScatterNoCopy")),
|
||||
@@ -1605,7 +1609,7 @@ fn test_scatter_all_positions() {
|
||||
rt.set_data(src, &[40.0, 30.0, 20.0, 10.0]);
|
||||
rt.set_data(indexes, &[3.0, 2.0, 1.0, 0.0]);
|
||||
rt.set_data(dest, &[1.0, 2.0, 3.0, 4.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -1624,7 +1628,7 @@ fn test_gather_preserves_data_dtype() {
|
||||
let mut rt = MetalRuntime::initialize(());
|
||||
rt.set_data(data, &[1.25, 2.5]);
|
||||
rt.set_data(indexes, &[1.0]);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.allocate_intermediate_buffers(&cx.dyn_map);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
|
||||
@@ -167,7 +167,10 @@ mod tests {
|
||||
let result = gather_rows(data, indices, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// data = [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
|
||||
rt.set_data(
|
||||
@@ -193,7 +196,10 @@ mod tests {
|
||||
let result = scatter_rows(src, indices, dest, 3).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.set_data(src.id, vec![10., 20., 30., 40., 50., 60.]);
|
||||
rt.set_data(indices.id, vec![1, 3]);
|
||||
@@ -219,7 +225,10 @@ mod tests {
|
||||
let gathered = gather_rows(updated_cache, gather_idx, 4).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.set_data(kv_new.id, vec![1., 2., 3., 4., 5., 6., 7., 8.]);
|
||||
rt.set_data(scatter_idx.id, vec![1, 4]); // Write to slots 1 and 4
|
||||
@@ -272,7 +281,10 @@ mod tests {
|
||||
let v_cache_new = v_cache_new.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Q = [1, 0, 1, 0] → head0=[1,0], head1=[1,0]
|
||||
rt.set_data(q.id, vec![1., 0., 1., 0.]);
|
||||
@@ -345,7 +357,10 @@ mod tests {
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Setup: 1 cached token at slot 0, 1 new token written to slot 1
|
||||
// K cached at slot 0: [1, 0]
|
||||
@@ -417,7 +432,10 @@ mod tests {
|
||||
let attn_out = attn_out.output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Cache has 1 token at slot 0
|
||||
let mut k_cache_data = vec![0.; num_slots * kv_dim];
|
||||
|
||||
@@ -184,7 +184,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![1.0, 2.0, 3.0];
|
||||
// Router strongly favors expert 0
|
||||
@@ -239,7 +242,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![1.0, 1.0];
|
||||
// Nearly-equal routing to all experts (slight differences to avoid argsort ties)
|
||||
@@ -293,7 +299,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = vec![
|
||||
1.0, 0.0, 0.0, // batch 0: routes to expert via feature 0
|
||||
@@ -350,7 +359,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = random_vec(in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
@@ -395,7 +407,10 @@ mod tests {
|
||||
let output = moe.forward(input).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let input_data = random_vec(batch * in_dim);
|
||||
let router_data = random_vec(in_dim * n_experts);
|
||||
|
||||
@@ -53,6 +53,10 @@ fn env_usize(name: &str, default: usize) -> usize {
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
fn search_options() -> CompileOptions {
|
||||
CompileOptions::default().search_graph_limit(env_usize("SEARCH_ITERS", 5))
|
||||
}
|
||||
|
||||
fn env_f32(name: &str, default: f32) -> f32 {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
@@ -187,7 +191,7 @@ fn run_text_encoder(prompt: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>
|
||||
|
||||
println!("Compiling text encoder...");
|
||||
let t0 = Instant::now();
|
||||
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
println!("Encoding prompt...");
|
||||
@@ -345,10 +349,9 @@ fn run_full_pipeline(
|
||||
{
|
||||
use rand::SeedableRng;
|
||||
let mut rng = rand::rngs::SmallRng::seed_from_u64(seed);
|
||||
let opts = luminal::graph::CompileOptions::new(env_usize("SEARCH_ITERS", 5));
|
||||
runtime = cx.search_with_rng(runtime, opts, &mut rng);
|
||||
runtime = cx.search_with_rng(runtime, search_options(), &mut rng);
|
||||
} else {
|
||||
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
}
|
||||
println!(" compile done in {:.1}s", t0.elapsed().as_secs_f64());
|
||||
|
||||
@@ -415,7 +418,7 @@ fn run_full_pipeline(
|
||||
let mut runtime = CudaRuntime::initialize(stream);
|
||||
runtime.load_safetensors(&cx, vae_path.to_str().unwrap());
|
||||
runtime.set_data(latent_in, vae_input);
|
||||
runtime = cx.search(runtime, CompileOptions::new(env_usize("SEARCH_ITERS", 5)));
|
||||
runtime = cx.search(runtime, search_options());
|
||||
runtime.execute(&cx.dyn_map);
|
||||
let img = runtime.get_f32(out);
|
||||
// VaeDecoder output is in roughly [-1, 1] range. Diffusers'
|
||||
|
||||
@@ -720,6 +720,10 @@ mod tests {
|
||||
out
|
||||
}
|
||||
|
||||
fn one_search() -> CompileOptions {
|
||||
CompileOptions::default().search_graph_limit(1)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d_bias_matches_reference() {
|
||||
let mut cx = Graph::default();
|
||||
@@ -747,7 +751,7 @@ mod tests {
|
||||
);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
@@ -766,7 +770,7 @@ mod tests {
|
||||
let expected = reference_nearest_upsample_2x(&input, 2, 3, 4);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -790,7 +794,7 @@ mod tests {
|
||||
cx.build_search_space::<CudaRuntime>(CompileOptions::default());
|
||||
let mut rt = CudaRuntime::initialize(ctx.default_stream());
|
||||
rt.set_data(input_t, input);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, one_search());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected);
|
||||
@@ -821,7 +825,7 @@ mod tests {
|
||||
);
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), one_search());
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
@@ -864,7 +868,7 @@ mod tests {
|
||||
rt.set_data(input_t, input);
|
||||
rt.set_data(weight_t, weight);
|
||||
rt.set_data(bias_t, bias);
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, one_search());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
assert_close(&rt.get_f32(out.id), &expected);
|
||||
|
||||
@@ -84,7 +84,8 @@ fn main() {
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(token_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
|
||||
@@ -36,14 +36,16 @@ impl KVCache {
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -68,114 +70,11 @@ pub struct Gemma {
|
||||
|
||||
impl Gemma {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut w = vec![];
|
||||
for l in 0..LAYERS {
|
||||
let is_local = (l + 1) % SLIDING_WINDOW_PATTERN != 0;
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(Q_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, Q_DIM),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
w.push(GemmaLayer {
|
||||
up,
|
||||
gate,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
input_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{l}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
is_local,
|
||||
rope_theta: if is_local {
|
||||
ROPE_THETA_LOCAL
|
||||
} else {
|
||||
ROPE_THETA_GLOBAL
|
||||
},
|
||||
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
|
||||
});
|
||||
}
|
||||
let lm_norm = gemma_norm(HIDDEN, "model.norm.weight", cx);
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers: w,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| GemmaLayer::init(cx, l)).collect(),
|
||||
lm_norm: gemma_norm(HIDDEN, "model.norm.weight", cx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,11 +84,7 @@ impl Gemma {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -226,6 +121,114 @@ struct GemmaLayer {
|
||||
rope_scaling_factor: f32,
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
let is_local = !(l + 1).is_multiple_of(SLIDING_WINDOW_PATTERN);
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
|
||||
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
|
||||
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
|
||||
input_layernorm: layer_norm(cx, l, "input_layernorm"),
|
||||
post_attention_layernorm: layer_norm(cx, l, "post_attention_layernorm"),
|
||||
pre_feedforward_layernorm: layer_norm(cx, l, "pre_feedforward_layernorm"),
|
||||
post_feedforward_layernorm: layer_norm(cx, l, "post_feedforward_layernorm"),
|
||||
is_local,
|
||||
rope_theta: if is_local {
|
||||
ROPE_THETA_LOCAL
|
||||
} else {
|
||||
ROPE_THETA_GLOBAL
|
||||
},
|
||||
rope_scaling_factor: if is_local { 1.0 } else { 8.0 },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = gemma_rotary_embeddings(
|
||||
qk_norm(q, self.q_norm, N_HEADS),
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
let k_rope = gemma_rotary_embeddings(
|
||||
qk_norm(k, self.k_norm, N_KV_HEADS),
|
||||
pos_ids,
|
||||
N_KV_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache_in,
|
||||
v_cache_in,
|
||||
max_seq,
|
||||
self.is_local,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = x + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let x_ff = self.pre_feedforward_layernorm.forward(x);
|
||||
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
|
||||
.matmul(self.down.t());
|
||||
(
|
||||
x + self.post_feedforward_layernorm.forward(mlp_out),
|
||||
k_cache_out,
|
||||
v_cache_out,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
|
||||
gemma_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
/// GELU using the identity: 0.5*x*(1+tanh(a)) = x*sigmoid(2*a)
|
||||
/// This produces far fewer e-graph nodes than the tanh-based expansion.
|
||||
#[allow(clippy::excessive_precision)]
|
||||
@@ -363,59 +366,3 @@ fn hlir_attention(
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl GemmaLayer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// QK-norm + RoPE
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
|
||||
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
|
||||
let q_rope = gemma_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
let k_rope = gemma_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
N_KV_HEADS,
|
||||
self.rope_theta,
|
||||
self.rope_scaling_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache_in,
|
||||
v_cache_in,
|
||||
max_seq,
|
||||
self.is_local,
|
||||
);
|
||||
|
||||
// O projection + post-attention norm + residual
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let attn_normed = self.post_attention_layernorm.forward(attn_proj);
|
||||
let x = x + attn_normed;
|
||||
|
||||
// Pre-feedforward norm + MLP + post-feedforward norm + residual
|
||||
let x_ff = self.pre_feedforward_layernorm.forward(x);
|
||||
let mlp_out = (gemma_gelu(x_ff.matmul(self.gate.t())) * x_ff.matmul(self.up.t()))
|
||||
.matmul(self.down.t());
|
||||
let mlp_normed = self.post_feedforward_layernorm.forward(mlp_out);
|
||||
(x + mlp_normed, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,11 +81,10 @@ fn main() {
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_with_rng(
|
||||
runtime,
|
||||
CompileOptions::new(search_graphs).profile_timeout(Duration::from_secs(2)),
|
||||
&mut rng,
|
||||
);
|
||||
let search_options = CompileOptions::default()
|
||||
.search_graph_limit(search_graphs)
|
||||
.profile_timeout(Duration::from_secs(2));
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
|
||||
for layer in 0..LAYERS {
|
||||
let cache_bytes = cache_bytes_for_layer(layer, max_seq_len);
|
||||
|
||||
@@ -83,20 +83,16 @@ impl KVCache {
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let k = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
)
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{layer}.k"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{layer}.v"),
|
||||
(spec.num_kv_heads, max_seq, spec.head_dim),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -115,169 +111,13 @@ pub struct Gemma4MoE {
|
||||
|
||||
impl Gemma4MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = Vec::with_capacity(LAYERS);
|
||||
for layer in 0..LAYERS {
|
||||
let spec = layer_spec(layer);
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_proj.weight"),
|
||||
(spec.q_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = spec.has_v_proj.then(|| {
|
||||
cx.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.v_proj.weight"),
|
||||
(spec.kv_dim, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
});
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, spec.q_dim),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.q_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.self_attn.k_norm.weight"),
|
||||
spec.head_dim,
|
||||
)
|
||||
.persist();
|
||||
let layer_scalar = cx
|
||||
.named_tensor(format!("model.layers.{layer}.layer_scalar"), HIDDEN)
|
||||
.persist();
|
||||
|
||||
let router_scale = cx
|
||||
.named_tensor(format!("model.layers.{layer}.router.scale"), HIDDEN)
|
||||
.persist();
|
||||
let router_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.proj.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let per_expert_scale = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.router.per_expert_scale"),
|
||||
NUM_EXPERTS,
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.gate_up_proj"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{layer}.experts.down_proj"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist()
|
||||
.as_dtype(DType::Bf16);
|
||||
|
||||
layers.push(Gemma4Layer {
|
||||
spec,
|
||||
gate,
|
||||
up,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
layer_scalar,
|
||||
input_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.input_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_attention_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_attention_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_1: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_1.weight"),
|
||||
cx,
|
||||
),
|
||||
post_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.post_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
pre_feedforward_layernorm_2: gemma4_norm(
|
||||
HIDDEN,
|
||||
&format!("model.layers.{layer}.pre_feedforward_layernorm_2.weight"),
|
||||
cx,
|
||||
),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale,
|
||||
router_proj,
|
||||
per_expert_scale,
|
||||
gate_up_weights,
|
||||
down_weights,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_norm = gemma4_norm(HIDDEN, "model.norm.weight", cx);
|
||||
|
||||
Self {
|
||||
embedding,
|
||||
lm_head,
|
||||
layers,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS)
|
||||
.map(|layer| Gemma4Layer::init(cx, layer))
|
||||
.collect(),
|
||||
lm_norm: gemma4_norm(HIDDEN, "model.norm.weight", cx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -287,11 +127,7 @@ impl Gemma4MoE {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (layer_idx, layer) in self.layers.iter().enumerate() {
|
||||
@@ -342,6 +178,164 @@ struct Gemma4SparseMoE {
|
||||
down_weights: GraphTensor,
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
fn init(cx: &mut Graph, layer: usize) -> Self {
|
||||
let spec = layer_spec(layer);
|
||||
Self {
|
||||
spec,
|
||||
gate: layer_weight(cx, layer, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
up: layer_weight(cx, layer, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, layer, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, layer, "self_attn.q_proj", (spec.q_dim, HIDDEN)),
|
||||
k_proj: layer_weight(cx, layer, "self_attn.k_proj", (spec.kv_dim, HIDDEN)),
|
||||
v_proj: spec
|
||||
.has_v_proj
|
||||
.then(|| layer_weight(cx, layer, "self_attn.v_proj", (spec.kv_dim, HIDDEN))),
|
||||
o_proj: layer_weight(cx, layer, "self_attn.o_proj", (HIDDEN, spec.q_dim)),
|
||||
q_norm: layer_weight(cx, layer, "self_attn.q_norm", spec.head_dim),
|
||||
k_norm: layer_weight(cx, layer, "self_attn.k_norm", spec.head_dim),
|
||||
layer_scalar: layer_tensor(cx, layer, "layer_scalar", HIDDEN),
|
||||
input_layernorm: layer_norm(cx, layer, "input_layernorm"),
|
||||
post_attention_layernorm: layer_norm(cx, layer, "post_attention_layernorm"),
|
||||
pre_feedforward_layernorm: layer_norm(cx, layer, "pre_feedforward_layernorm"),
|
||||
post_feedforward_layernorm: layer_norm(cx, layer, "post_feedforward_layernorm"),
|
||||
post_feedforward_layernorm_1: layer_norm(cx, layer, "post_feedforward_layernorm_1"),
|
||||
post_feedforward_layernorm_2: layer_norm(cx, layer, "post_feedforward_layernorm_2"),
|
||||
pre_feedforward_layernorm_2: layer_norm(cx, layer, "pre_feedforward_layernorm_2"),
|
||||
moe: Gemma4SparseMoE {
|
||||
router_scale: layer_tensor(cx, layer, "router.scale", HIDDEN),
|
||||
router_proj: layer_weight(cx, layer, "router.proj", (NUM_EXPERTS, HIDDEN)),
|
||||
per_expert_scale: layer_tensor(cx, layer, "router.per_expert_scale", NUM_EXPERTS),
|
||||
gate_up_weights: layer_tensor(
|
||||
cx,
|
||||
layer,
|
||||
"experts.gate_up_proj",
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
down_weights: layer_tensor(
|
||||
cx,
|
||||
layer,
|
||||
"experts.down_proj",
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_tensor(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn layer_norm(cx: &mut Graph, layer: usize, name: &str) -> LayerNorm {
|
||||
gemma4_norm(HIDDEN, &format!("model.layers.{layer}.{name}.weight"), cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn gemma4_norm(dim: usize, weight_name: &str, cx: &mut Graph) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(weight_name), None, false, RMS_NORM_EPS, cx)
|
||||
}
|
||||
@@ -505,81 +499,6 @@ fn hlir_attention(
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl Gemma4Layer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let residual = x;
|
||||
let x_attn = self.input_layernorm.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k_base = x_attn.matmul(self.k_proj.t());
|
||||
let v_base = if let Some(v_proj) = self.v_proj {
|
||||
x_attn.matmul(v_proj.t())
|
||||
} else {
|
||||
k_base
|
||||
};
|
||||
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS, self.spec.head_dim);
|
||||
let k_normed = qk_norm(
|
||||
k_base,
|
||||
self.k_norm,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
);
|
||||
let v_normed = value_norm(v_base, self.spec.head_dim);
|
||||
|
||||
let q_rope = gemma4_rotary_embeddings(
|
||||
q_normed,
|
||||
pos_ids,
|
||||
N_HEADS,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
let k_rope = gemma4_rotary_embeddings(
|
||||
k_normed,
|
||||
pos_ids,
|
||||
self.spec.num_kv_heads,
|
||||
self.spec.head_dim,
|
||||
self.spec.rope_theta,
|
||||
self.spec.partial_rotary_factor,
|
||||
);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = hlir_attention(
|
||||
q_rope, k_rope, v_normed, k_cache_in, v_cache_in, max_seq, self.spec,
|
||||
);
|
||||
|
||||
let attn_proj = attn_out.matmul(self.o_proj.t());
|
||||
let x = residual + self.post_attention_layernorm.forward(attn_proj);
|
||||
|
||||
let dense_ff = dense_ffn(
|
||||
self.pre_feedforward_layernorm.forward(x),
|
||||
self.gate,
|
||||
self.up,
|
||||
self.down,
|
||||
);
|
||||
let dense_ff = self.post_feedforward_layernorm_1.forward(dense_ff);
|
||||
|
||||
let moe_out = self
|
||||
.moe
|
||||
.forward(x, self.pre_feedforward_layernorm_2.forward(x));
|
||||
let moe_out = self.post_feedforward_layernorm_2.forward(moe_out);
|
||||
|
||||
let ff_out = self.post_feedforward_layernorm.forward(dense_ff + moe_out);
|
||||
let x = x + ff_out;
|
||||
let x = x * self
|
||||
.layer_scalar
|
||||
.expand_lhs(&x.dims()[..x.dims().len() - 1]);
|
||||
|
||||
(x, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn dense_ffn(x: GraphTensor, gate: GraphTensor, up: GraphTensor, down: GraphTensor) -> GraphTensor {
|
||||
(gemma_gelu(x.matmul(gate.t())) * x.matmul(up.t())).matmul(down.t())
|
||||
}
|
||||
|
||||
@@ -338,13 +338,11 @@ fn main() {
|
||||
println!(" Search trials: {SEARCH_TRIALS}");
|
||||
println!(" Search keep-best: {SEARCH_KEEP_BEST}");
|
||||
let mut rng = StdRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_with_rng(
|
||||
runtime,
|
||||
CompileOptions::new(SEARCH_GRAPHS)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST),
|
||||
&mut rng,
|
||||
);
|
||||
let search_options = CompileOptions::default()
|
||||
.search_graph_limit(SEARCH_GRAPHS)
|
||||
.trials(SEARCH_TRIALS)
|
||||
.keep_best(SEARCH_KEEP_BEST);
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
println!(
|
||||
" Search/compile: {:.2} s",
|
||||
compile_start.elapsed().as_secs_f64()
|
||||
|
||||
@@ -111,125 +111,18 @@ impl Llama {
|
||||
config: LlamaConfig,
|
||||
fp8_linears: bool,
|
||||
) -> Self {
|
||||
let mut layers = Vec::with_capacity(config.layers);
|
||||
for l in 0..config.layers {
|
||||
layers.push(LlamaLayer {
|
||||
config,
|
||||
up: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
up_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.up_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
gate: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
(config.intermediate, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
gate_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.gate_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
down: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
(config.hidden, config.intermediate),
|
||||
fp8_linears,
|
||||
),
|
||||
down_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.mlp.down_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
q_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.q_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
k_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.k_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
v_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.v_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj: linear_weight(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
(config.hidden, config.hidden),
|
||||
fp8_linears,
|
||||
),
|
||||
o_proj_scales: fp8_linear_scales(
|
||||
cx,
|
||||
format!("model.layers.{l}.self_attn.o_proj"),
|
||||
fp8_linears,
|
||||
),
|
||||
attn_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
Self {
|
||||
config,
|
||||
embedding: cx
|
||||
.named_tensor(
|
||||
"model.embed_tokens.weight",
|
||||
(config.vocab_size, config.hidden),
|
||||
)
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (config.vocab_size, config.hidden))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(
|
||||
config.hidden,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
embedding: persist(
|
||||
cx,
|
||||
"model.embed_tokens.weight",
|
||||
(config.vocab_size, config.hidden),
|
||||
),
|
||||
layers: (0..config.layers)
|
||||
.map(|l| LlamaLayer::init(cx, l, config, fp8_linears))
|
||||
.collect(),
|
||||
lm_head: persist(cx, "lm_head.weight", (config.vocab_size, config.hidden)),
|
||||
lm_norm: rms_norm(cx, config.hidden, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,12 +136,7 @@ impl Llama {
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let hidden = self.config.hidden;
|
||||
let mut x = self.embedding.gather(
|
||||
(input * hidden).expand_dim(1, hidden)
|
||||
+ input.graph().arange(hidden).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, input, self.config.hidden);
|
||||
let mut cache_outputs = Vec::with_capacity(self.config.layers);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -311,6 +199,170 @@ struct Fp8LinearScales {
|
||||
weight: GraphTensor,
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
fn init(cx: &mut Graph, l: usize, config: LlamaConfig, fp8: bool) -> Self {
|
||||
Self {
|
||||
config,
|
||||
up: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.up_proj",
|
||||
(config.intermediate, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
up_scales: layer_linear_scales(cx, l, "mlp.up_proj", fp8),
|
||||
gate: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.gate_proj",
|
||||
(config.intermediate, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
gate_scales: layer_linear_scales(cx, l, "mlp.gate_proj", fp8),
|
||||
down: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"mlp.down_proj",
|
||||
(config.hidden, config.intermediate),
|
||||
fp8,
|
||||
),
|
||||
down_scales: layer_linear_scales(cx, l, "mlp.down_proj", fp8),
|
||||
q_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.q_proj",
|
||||
(config.hidden, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
q_proj_scales: layer_linear_scales(cx, l, "self_attn.q_proj", fp8),
|
||||
k_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.k_proj",
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8,
|
||||
),
|
||||
k_proj_scales: layer_linear_scales(cx, l, "self_attn.k_proj", fp8),
|
||||
v_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.v_proj",
|
||||
(config.kv_dim(), config.hidden),
|
||||
fp8,
|
||||
),
|
||||
v_proj_scales: layer_linear_scales(cx, l, "self_attn.v_proj", fp8),
|
||||
o_proj: layer_linear_weight(
|
||||
cx,
|
||||
l,
|
||||
"self_attn.o_proj",
|
||||
(config.hidden, config.hidden),
|
||||
fp8,
|
||||
),
|
||||
o_proj_scales: layer_linear_scales(cx, l, "self_attn.o_proj", fp8),
|
||||
attn_rms: rms_norm(
|
||||
cx,
|
||||
config.hidden,
|
||||
format!("model.layers.{l}.input_layernorm.weight"),
|
||||
),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
config.hidden,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
|
||||
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
|
||||
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
},
|
||||
self.config,
|
||||
);
|
||||
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
|
||||
* linear_matmul(x_mlp, self.up, self.up_scales);
|
||||
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = vec![
|
||||
self.up,
|
||||
self.gate,
|
||||
self.down,
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj,
|
||||
];
|
||||
for scales in [
|
||||
self.up_scales,
|
||||
self.gate_scales,
|
||||
self.down_scales,
|
||||
self.q_proj_scales,
|
||||
self.k_proj_scales,
|
||||
self.v_proj_scales,
|
||||
self.o_proj_scales,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
tensors.push(scales.input);
|
||||
tensors.push(scales.weight);
|
||||
}
|
||||
if let Some(weight) = self.attn_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.attn_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
if let Some(weight) = self.mlp_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.mlp_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn linear_weight(
|
||||
cx: &mut Graph,
|
||||
prefix: impl ToString,
|
||||
@@ -325,6 +377,16 @@ fn linear_weight(
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_linear_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
fp8: bool,
|
||||
) -> GraphTensor {
|
||||
linear_weight(cx, format!("model.layers.{layer}.{suffix}"), shape, fp8)
|
||||
}
|
||||
|
||||
fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option<Fp8LinearScales> {
|
||||
if !fp8 {
|
||||
return None;
|
||||
@@ -340,6 +402,27 @@ fn fp8_linear_scales(cx: &mut Graph, prefix: impl ToString, fp8: bool) -> Option
|
||||
})
|
||||
}
|
||||
|
||||
fn layer_linear_scales(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
fp8: bool,
|
||||
) -> Option<Fp8LinearScales> {
|
||||
fp8_linear_scales(cx, format!("model.layers.{layer}.{suffix}"), fp8)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, dim: usize, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(dim, Some(&weight_name.to_string()), None, false, 1e-5, cx)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor, hidden: usize) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * hidden).expand_dim(1, hidden)
|
||||
+ token_ids.graph().arange(hidden).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn expand_scalar(scale: GraphTensor, like: GraphTensor) -> GraphTensor {
|
||||
scale.expand_rhs(like.dims())
|
||||
}
|
||||
@@ -443,87 +526,3 @@ fn attention(
|
||||
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = linear_matmul(x_attn, self.q_proj, self.q_proj_scales);
|
||||
let k = linear_matmul(x_attn, self.k_proj, self.k_proj_scales);
|
||||
let v = linear_matmul(x_attn, self.v_proj, self.v_proj_scales);
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos, self.config);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos, self.config);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = attention(
|
||||
AttentionInputs {
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
},
|
||||
self.config,
|
||||
);
|
||||
x += linear_matmul(attn_out, self.o_proj, self.o_proj_scales);
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out = linear_matmul(x_mlp, self.gate, self.gate_scales).swish()
|
||||
* linear_matmul(x_mlp, self.up, self.up_scales);
|
||||
let mlp_out = linear_matmul(mlp_out, self.down, self.down_scales);
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parameter_tensors(&self) -> Vec<GraphTensor> {
|
||||
let mut tensors = vec![
|
||||
self.up,
|
||||
self.gate,
|
||||
self.down,
|
||||
self.q_proj,
|
||||
self.k_proj,
|
||||
self.v_proj,
|
||||
self.o_proj,
|
||||
];
|
||||
for scales in [
|
||||
self.up_scales,
|
||||
self.gate_scales,
|
||||
self.down_scales,
|
||||
self.q_proj_scales,
|
||||
self.k_proj_scales,
|
||||
self.v_proj_scales,
|
||||
self.o_proj_scales,
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
{
|
||||
tensors.push(scales.input);
|
||||
tensors.push(scales.weight);
|
||||
}
|
||||
if let Some(weight) = self.attn_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.attn_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
if let Some(weight) = self.mlp_rms.weight {
|
||||
tensors.push(weight);
|
||||
}
|
||||
if let Some(bias) = self.mlp_rms.bias {
|
||||
tensors.push(bias);
|
||||
}
|
||||
tensors
|
||||
}
|
||||
}
|
||||
|
||||
@@ -241,7 +241,8 @@ fn main() {
|
||||
runtime.set_data(scatter_idx_t, vec![0i32; search_s]);
|
||||
runtime.set_data(gather_idx_t, vec![0i32; search_c]);
|
||||
runtime.set_data(attn_mask_t, vec![0.0f32; search_s * search_c]);
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
// Re-initialize KV cache after search (search consumes buffers)
|
||||
let cache_bytes = num_slots * KV_DIM * std::mem::size_of::<f32>();
|
||||
|
||||
@@ -25,8 +25,8 @@ pub struct PagedKVCache {
|
||||
|
||||
impl PagedKVCache {
|
||||
pub fn new(cx: &mut Graph, num_slots: usize) -> Self {
|
||||
let mut k_caches = vec![];
|
||||
let mut v_caches = vec![];
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
k_caches.push(cx.named_tensor(format!("kv_cache.{l}.k"), (num_slots, KV_DIM)));
|
||||
v_caches.push(cx.named_tensor(format!("kv_cache.{l}.v"), (num_slots, KV_DIM)));
|
||||
@@ -44,78 +44,11 @@ pub struct Llama {
|
||||
|
||||
impl Llama {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = vec![];
|
||||
for l in 0..LAYERS {
|
||||
layers.push(LlamaLayer {
|
||||
up: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
gate: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
down: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist(),
|
||||
q_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
o_proj: cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, HIDDEN),
|
||||
)
|
||||
.persist(),
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
Self {
|
||||
embedding: cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
layers,
|
||||
lm_head: cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist(),
|
||||
lm_norm: LayerNorm::new(HIDDEN, Some("model.norm.weight"), None, false, 1e-5, cx),
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| LlamaLayer::init(cx, l)).collect(),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
lm_norm: rms_norm(cx, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,12 +74,8 @@ impl Llama {
|
||||
attn_mask: GraphTensor,
|
||||
kv_cache: &PagedKVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = input.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(input * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ input.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut cache_outputs = vec![];
|
||||
let mut x = token_embedding(self.embedding, input);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
x,
|
||||
@@ -177,6 +106,99 @@ struct LlamaLayer {
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (HIDDEN, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, HIDDEN)),
|
||||
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&weight_name.to_string()),
|
||||
None,
|
||||
false,
|
||||
1e-5,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
fn llama_rotary_embeddings(mut input: GraphTensor, pos_ids: GraphTensor) -> GraphTensor {
|
||||
input = input.split_dims(1, HEAD_DIM).transpose(0, 1);
|
||||
|
||||
@@ -264,44 +286,3 @@ fn paged_attention(
|
||||
let attn_out = out.transpose(0, 1).merge_dims(1, 2);
|
||||
(attn_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl LlamaLayer {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
q_pos: GraphTensor,
|
||||
scatter_idx: GraphTensor,
|
||||
gather_idx: GraphTensor,
|
||||
attn_mask: GraphTensor,
|
||||
k_cache: GraphTensor,
|
||||
v_cache: GraphTensor,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// Apply RoPE before scattering into cache
|
||||
let q_rope = llama_rotary_embeddings(q, q_pos);
|
||||
let k_rope = llama_rotary_embeddings(k, q_pos);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) = paged_attention(
|
||||
q_rope,
|
||||
k_rope,
|
||||
v,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scatter_idx,
|
||||
gather_idx,
|
||||
attn_mask,
|
||||
);
|
||||
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,8 @@ where
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_i32_data(input.id, vec![1; search_s]);
|
||||
runtime.set_i32_data(token_ids.id, (0..search_s as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, CompileOptions::new(config.search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(config.search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
for i in 0..config.layers {
|
||||
runtime.set_zeros(kv_cache.k_caches[i].id, cache_bytes);
|
||||
|
||||
@@ -34,14 +34,16 @@ impl KVCache {
|
||||
let mut k_caches = Vec::with_capacity(layers);
|
||||
let mut v_caches = Vec::with_capacity(layers);
|
||||
for l in 0..layers {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -63,105 +65,10 @@ impl Qwen {
|
||||
layers <= LAYERS,
|
||||
"requested {layers} layers, but model has {LAYERS}"
|
||||
);
|
||||
let mut w = vec![];
|
||||
for l in 0..layers {
|
||||
let up = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.up_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_proj.weight"),
|
||||
(INTERMEDIATE, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_proj.weight"),
|
||||
(HIDDEN, INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(Q_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, Q_DIM),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
w.push(QwenLayer {
|
||||
up,
|
||||
gate,
|
||||
down,
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
});
|
||||
}
|
||||
let lm_norm = LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
);
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
embedding,
|
||||
layers: w,
|
||||
lm_norm,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..layers).map(|l| QwenLayer::init(cx, l)).collect(),
|
||||
lm_norm: rms_norm(cx, "model.norm.weight"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,11 +79,7 @@ impl Qwen {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
let mut cache_outputs = Vec::with_capacity(self.layers.len());
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -209,6 +112,90 @@ struct QwenLayer {
|
||||
mlp_rms: LayerNorm,
|
||||
}
|
||||
|
||||
impl QwenLayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
Self {
|
||||
up: layer_weight(cx, l, "mlp.up_proj", (INTERMEDIATE, HIDDEN)),
|
||||
gate: layer_weight(cx, l, "mlp.gate_proj", (INTERMEDIATE, HIDDEN)),
|
||||
down: layer_weight(cx, l, "mlp.down_proj", (HIDDEN, INTERMEDIATE)),
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
|
||||
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
|
||||
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
|
||||
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
let q_rope = qwen_rotary_embeddings(qk_norm(q, self.q_norm, N_HEADS), pos_ids, N_HEADS);
|
||||
let k_rope =
|
||||
qwen_rotary_embeddings(qk_norm(k, self.k_norm, N_KV_HEADS), pos_ids, N_KV_HEADS);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) =
|
||||
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&weight_name.to_string()),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
/// Per-head RMS normalization for QK-norm.
|
||||
/// Input: [seq, dim] where dim = n_heads * HEAD_DIM
|
||||
/// split_dims to [seq, n_heads, HEAD_DIM], RMS norm over last axis, multiply by weight, merge back.
|
||||
@@ -331,36 +318,3 @@ fn hlir_attention(
|
||||
|
||||
(out, k_cache_out, v_cache_out)
|
||||
}
|
||||
|
||||
impl QwenLayer {
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
pos_ids: GraphTensor,
|
||||
k_cache_in: GraphTensor,
|
||||
v_cache_in: GraphTensor,
|
||||
max_seq: usize,
|
||||
) -> (GraphTensor, GraphTensor, GraphTensor) {
|
||||
let x_attn = self.attn_rms.forward(x);
|
||||
let q = x_attn.matmul(self.q_proj.t());
|
||||
let k = x_attn.matmul(self.k_proj.t());
|
||||
let v = x_attn.matmul(self.v_proj.t());
|
||||
|
||||
// QK-norm: per-head RMS normalization
|
||||
let q_normed = qk_norm(q, self.q_norm, N_HEADS);
|
||||
let k_normed = qk_norm(k, self.k_norm, N_KV_HEADS);
|
||||
|
||||
// RoPE
|
||||
let q_rope = qwen_rotary_embeddings(q_normed, pos_ids, N_HEADS);
|
||||
let k_rope = qwen_rotary_embeddings(k_normed, pos_ids, N_KV_HEADS);
|
||||
|
||||
let (attn_out, k_cache_out, v_cache_out) =
|
||||
hlir_attention(q_rope, k_rope, v, k_cache_in, v_cache_in, max_seq);
|
||||
x += attn_out.matmul(self.o_proj.t());
|
||||
|
||||
let x_mlp = self.mlp_rms.forward(x);
|
||||
let mlp_out =
|
||||
(x_mlp.matmul(self.gate.t()).swish() * x_mlp.matmul(self.up.t())).matmul(self.down.t());
|
||||
(x + mlp_out, k_cache_out, v_cache_out)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,7 +82,8 @@ fn main() {
|
||||
runtime.set_data(input, vec![1; search_s]);
|
||||
runtime.set_data(pos_ids, (0..search_s as i32).collect::<Vec<_>>());
|
||||
let mut rng = SmallRng::seed_from_u64(SEARCH_SEED);
|
||||
runtime = cx.search_with_rng(runtime, CompileOptions::new(search_graphs), &mut rng);
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search_with_rng(runtime, search_options, &mut rng);
|
||||
|
||||
for i in 0..LAYERS {
|
||||
runtime.set_zeros(kv_cache.k_caches[i], cache_bytes);
|
||||
|
||||
@@ -29,17 +29,19 @@ pub struct KVCache {
|
||||
|
||||
impl KVCache {
|
||||
pub fn new(cx: &mut Graph, max_seq: usize) -> Self {
|
||||
let mut k_caches = vec![];
|
||||
let mut v_caches = vec![];
|
||||
let mut k_caches = Vec::with_capacity(LAYERS);
|
||||
let mut v_caches = Vec::with_capacity(LAYERS);
|
||||
for l in 0..LAYERS {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_KV_HEADS, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_KV_HEADS, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -58,111 +60,11 @@ pub struct Qwen3MoE {
|
||||
|
||||
impl Qwen3MoE {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
let mut layers = vec![];
|
||||
for l in 0..LAYERS {
|
||||
let q_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_proj.weight"),
|
||||
(Q_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let k_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let v_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.v_proj.weight"),
|
||||
(KV_DIM, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let o_proj = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.o_proj.weight"),
|
||||
(HIDDEN, Q_DIM),
|
||||
)
|
||||
.persist();
|
||||
let q_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.q_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let k_norm = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.self_attn.k_norm.weight"),
|
||||
HEAD_DIM,
|
||||
)
|
||||
.persist();
|
||||
let router = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate.weight"),
|
||||
(NUM_EXPERTS, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let gate_up_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.gate_up_weights"),
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.persist();
|
||||
let down_weights = cx
|
||||
.named_tensor(
|
||||
format!("model.layers.{l}.mlp.down_weights"),
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.persist();
|
||||
layers.push(Qwen3MoELayer {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
q_norm,
|
||||
k_norm,
|
||||
attn_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.input_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
mlp_rms: LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&format!("model.layers.{l}.post_attention_layernorm.weight")),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
),
|
||||
moe: QwenMoE {
|
||||
router,
|
||||
gate_up_weights: gate_up_weights.as_dtype(DType::Bf16),
|
||||
down_weights: down_weights.as_dtype(DType::Bf16),
|
||||
},
|
||||
});
|
||||
}
|
||||
let lm_norm = LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some("model.norm.weight"),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
);
|
||||
let embedding = cx
|
||||
.named_tensor("model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
let lm_head = cx
|
||||
.named_tensor("lm_head.weight", (VOCAB_SIZE, HIDDEN))
|
||||
.persist();
|
||||
Self {
|
||||
embedding,
|
||||
layers,
|
||||
lm_norm,
|
||||
lm_head,
|
||||
embedding: persist(cx, "model.embed_tokens.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
layers: (0..LAYERS).map(|l| Qwen3MoELayer::init(cx, l)).collect(),
|
||||
lm_norm: rms_norm(cx, "model.norm.weight"),
|
||||
lm_head: persist(cx, "lm_head.weight", (VOCAB_SIZE, HIDDEN)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,11 +74,7 @@ impl Qwen3MoE {
|
||||
pos_ids: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
let mut x = self.embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
);
|
||||
let mut x = token_embedding(self.embedding, token_ids);
|
||||
let mut cache_outputs = Vec::with_capacity(LAYERS);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let (x_new, k_out, v_out) = layer.forward(
|
||||
@@ -214,6 +112,39 @@ struct QwenMoE {
|
||||
}
|
||||
|
||||
impl Qwen3MoELayer {
|
||||
fn init(cx: &mut Graph, l: usize) -> Self {
|
||||
Self {
|
||||
q_proj: layer_weight(cx, l, "self_attn.q_proj", (Q_DIM, HIDDEN)),
|
||||
k_proj: layer_weight(cx, l, "self_attn.k_proj", (KV_DIM, HIDDEN)),
|
||||
v_proj: layer_weight(cx, l, "self_attn.v_proj", (KV_DIM, HIDDEN)),
|
||||
o_proj: layer_weight(cx, l, "self_attn.o_proj", (HIDDEN, Q_DIM)),
|
||||
q_norm: layer_weight(cx, l, "self_attn.q_norm", HEAD_DIM),
|
||||
k_norm: layer_weight(cx, l, "self_attn.k_norm", HEAD_DIM),
|
||||
attn_rms: rms_norm(cx, format!("model.layers.{l}.input_layernorm.weight")),
|
||||
mlp_rms: rms_norm(
|
||||
cx,
|
||||
format!("model.layers.{l}.post_attention_layernorm.weight"),
|
||||
),
|
||||
moe: QwenMoE {
|
||||
router: layer_weight(cx, l, "mlp.gate", (NUM_EXPERTS, HIDDEN)),
|
||||
gate_up_weights: layer_tensor(
|
||||
cx,
|
||||
l,
|
||||
"mlp.gate_up_weights",
|
||||
(NUM_EXPERTS, MOE_INTERMEDIATE * 2, HIDDEN),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
down_weights: layer_tensor(
|
||||
cx,
|
||||
l,
|
||||
"mlp.down_weights",
|
||||
(NUM_EXPERTS, HIDDEN, MOE_INTERMEDIATE),
|
||||
)
|
||||
.as_dtype(DType::Bf16),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&self,
|
||||
mut x: GraphTensor,
|
||||
@@ -247,6 +178,51 @@ impl Qwen3MoELayer {
|
||||
}
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
fn layer_tensor(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
persist(cx, format!("model.layers.{layer}.{suffix}"), shape)
|
||||
}
|
||||
|
||||
fn layer_weight(
|
||||
cx: &mut Graph,
|
||||
layer: usize,
|
||||
suffix: &str,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
layer_tensor(cx, layer, &format!("{suffix}.weight"), shape)
|
||||
}
|
||||
|
||||
fn rms_norm(cx: &mut Graph, weight_name: impl ToString) -> LayerNorm {
|
||||
LayerNorm::new(
|
||||
HIDDEN,
|
||||
Some(&weight_name.to_string()),
|
||||
None,
|
||||
false,
|
||||
RMS_NORM_EPS,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
||||
fn token_embedding(embedding: GraphTensor, token_ids: GraphTensor) -> GraphTensor {
|
||||
let seq = token_ids.dims1();
|
||||
embedding.gather(
|
||||
(token_ids * HIDDEN).expand_dim(1, HIDDEN)
|
||||
+ token_ids.graph().arange(HIDDEN).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
impl QwenMoE {
|
||||
fn forward(&self, x: GraphTensor) -> GraphTensor {
|
||||
let n = x.dims().len(); // 2 for [s, H]
|
||||
|
||||
@@ -12,7 +12,7 @@ fn main() {
|
||||
|
||||
// Compile
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default());
|
||||
|
||||
// Set input tensors
|
||||
rt.set_data(a, vec![1.0, 2.0, 3.0]);
|
||||
|
||||
@@ -96,7 +96,8 @@ fn main() {
|
||||
cx.set_dim('p', 0);
|
||||
runtime.set_data(input, vec![1i32; max_prefill]);
|
||||
runtime.set_data(pos_ids, (0..max_prefill as i32).collect::<Vec<_>>());
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
|
||||
// Reset the KV caches and re-set the mel after search (which executes test runs).
|
||||
for i in 0..N_TEXT_LAYER {
|
||||
|
||||
@@ -33,6 +33,14 @@ fn linear_no_bias(x: GraphTensor, w: GraphTensor) -> GraphTensor {
|
||||
x.matmul(w.t())
|
||||
}
|
||||
|
||||
fn persist(
|
||||
cx: &mut Graph,
|
||||
name: impl ToString,
|
||||
shape: impl luminal::prelude::ToShape,
|
||||
) -> GraphTensor {
|
||||
cx.named_tensor(name, shape).persist()
|
||||
}
|
||||
|
||||
/// 1D convolution with bias. Input: (ch_in, length). Weight: (ch_out, ch_in*kernel)
|
||||
/// (HF stores it as (ch_out, ch_in, kernel) which flat-loads identically). Output: (ch_out, out_length).
|
||||
fn conv1d_bias(
|
||||
@@ -90,27 +98,13 @@ struct AttentionWeights {
|
||||
impl AttentionWeights {
|
||||
fn new(prefix: &str, dim: usize, cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
q_proj: cx
|
||||
.named_tensor(format!("{prefix}.q_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
q_bias: cx
|
||||
.named_tensor(format!("{prefix}.q_proj.bias"), dim)
|
||||
.persist(),
|
||||
k_proj: cx
|
||||
.named_tensor(format!("{prefix}.k_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
v_proj: cx
|
||||
.named_tensor(format!("{prefix}.v_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
v_bias: cx
|
||||
.named_tensor(format!("{prefix}.v_proj.bias"), dim)
|
||||
.persist(),
|
||||
out_proj: cx
|
||||
.named_tensor(format!("{prefix}.out_proj.weight"), (dim, dim))
|
||||
.persist(),
|
||||
out_bias: cx
|
||||
.named_tensor(format!("{prefix}.out_proj.bias"), dim)
|
||||
.persist(),
|
||||
q_proj: persist(cx, format!("{prefix}.q_proj.weight"), (dim, dim)),
|
||||
q_bias: persist(cx, format!("{prefix}.q_proj.bias"), dim),
|
||||
k_proj: persist(cx, format!("{prefix}.k_proj.weight"), (dim, dim)),
|
||||
v_proj: persist(cx, format!("{prefix}.v_proj.weight"), (dim, dim)),
|
||||
v_bias: persist(cx, format!("{prefix}.v_proj.bias"), dim),
|
||||
out_proj: persist(cx, format!("{prefix}.out_proj.weight"), (dim, dim)),
|
||||
out_bias: persist(cx, format!("{prefix}.out_proj.bias"), dim),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -125,6 +119,14 @@ fn merge_heads(x: GraphTensor) -> GraphTensor {
|
||||
x.transpose(0, 1).merge_dims(1, 2)
|
||||
}
|
||||
|
||||
fn embedding_lookup(embedding: GraphTensor, ids: GraphTensor) -> GraphTensor {
|
||||
let seq = ids.dims1();
|
||||
embedding.gather(
|
||||
(ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
)
|
||||
}
|
||||
|
||||
/// Encoder self-attention (full, non-causal). Input/output shape (seq, dim).
|
||||
fn encoder_self_attention(x: GraphTensor, w: &AttentionWeights) -> GraphTensor {
|
||||
let q = linear_with_bias(x, w.q_proj, w.q_bias);
|
||||
@@ -239,18 +241,10 @@ impl EncoderLayer {
|
||||
N_AUDIO_STATE,
|
||||
cx,
|
||||
),
|
||||
fc1: cx
|
||||
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE))
|
||||
.persist(),
|
||||
fc1_b: cx
|
||||
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
|
||||
.persist(),
|
||||
fc2: cx
|
||||
.named_tensor(format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM))
|
||||
.persist(),
|
||||
fc2_b: cx
|
||||
.named_tensor(format!("{prefix}.fc2.bias"), N_AUDIO_STATE)
|
||||
.persist(),
|
||||
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_AUDIO_STATE)),
|
||||
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
|
||||
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_AUDIO_STATE, FF_DIM)),
|
||||
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_AUDIO_STATE),
|
||||
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_AUDIO_STATE, cx),
|
||||
}
|
||||
}
|
||||
@@ -295,18 +289,10 @@ impl DecoderLayer {
|
||||
N_TEXT_STATE,
|
||||
cx,
|
||||
),
|
||||
fc1: cx
|
||||
.named_tensor(format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE))
|
||||
.persist(),
|
||||
fc1_b: cx
|
||||
.named_tensor(format!("{prefix}.fc1.bias"), FF_DIM)
|
||||
.persist(),
|
||||
fc2: cx
|
||||
.named_tensor(format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM))
|
||||
.persist(),
|
||||
fc2_b: cx
|
||||
.named_tensor(format!("{prefix}.fc2.bias"), N_TEXT_STATE)
|
||||
.persist(),
|
||||
fc1: persist(cx, format!("{prefix}.fc1.weight"), (FF_DIM, N_TEXT_STATE)),
|
||||
fc1_b: persist(cx, format!("{prefix}.fc1.bias"), FF_DIM),
|
||||
fc2: persist(cx, format!("{prefix}.fc2.weight"), (N_TEXT_STATE, FF_DIM)),
|
||||
fc2_b: persist(cx, format!("{prefix}.fc2.bias"), N_TEXT_STATE),
|
||||
final_ln: standard_layernorm(&format!("{prefix}.final_layer_norm"), N_TEXT_STATE, cx),
|
||||
}
|
||||
}
|
||||
@@ -346,14 +332,16 @@ impl KVCache {
|
||||
let mut k_caches = Vec::with_capacity(N_TEXT_LAYER);
|
||||
let mut v_caches = Vec::with_capacity(N_TEXT_LAYER);
|
||||
for l in 0..N_TEXT_LAYER {
|
||||
let k = cx
|
||||
.named_tensor(format!("kv_cache.{l}.k"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
let v = cx
|
||||
.named_tensor(format!("kv_cache.{l}.v"), (N_TEXT_HEAD, max_seq, HEAD_DIM))
|
||||
.persist();
|
||||
k_caches.push(k);
|
||||
v_caches.push(v);
|
||||
k_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.k"),
|
||||
(N_TEXT_HEAD, max_seq, HEAD_DIM),
|
||||
));
|
||||
v_caches.push(persist(
|
||||
cx,
|
||||
format!("kv_cache.{l}.v"),
|
||||
(N_TEXT_HEAD, max_seq, HEAD_DIM),
|
||||
));
|
||||
}
|
||||
Self {
|
||||
k_caches,
|
||||
@@ -376,27 +364,23 @@ pub struct WhisperEncoder {
|
||||
impl WhisperEncoder {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
conv1_w: cx
|
||||
.named_tensor("model.encoder.conv1.weight", (N_AUDIO_STATE, N_MELS * 3))
|
||||
.persist(),
|
||||
conv1_b: cx
|
||||
.named_tensor("model.encoder.conv1.bias", N_AUDIO_STATE)
|
||||
.persist(),
|
||||
conv2_w: cx
|
||||
.named_tensor(
|
||||
"model.encoder.conv2.weight",
|
||||
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
|
||||
)
|
||||
.persist(),
|
||||
conv2_b: cx
|
||||
.named_tensor("model.encoder.conv2.bias", N_AUDIO_STATE)
|
||||
.persist(),
|
||||
positional_embedding: cx
|
||||
.named_tensor(
|
||||
"model.encoder.embed_positions.weight",
|
||||
(N_AUDIO_CTX, N_AUDIO_STATE),
|
||||
)
|
||||
.persist(),
|
||||
conv1_w: persist(
|
||||
cx,
|
||||
"model.encoder.conv1.weight",
|
||||
(N_AUDIO_STATE, N_MELS * 3),
|
||||
),
|
||||
conv1_b: persist(cx, "model.encoder.conv1.bias", N_AUDIO_STATE),
|
||||
conv2_w: persist(
|
||||
cx,
|
||||
"model.encoder.conv2.weight",
|
||||
(N_AUDIO_STATE, N_AUDIO_STATE * 3),
|
||||
),
|
||||
conv2_b: persist(cx, "model.encoder.conv2.bias", N_AUDIO_STATE),
|
||||
positional_embedding: persist(
|
||||
cx,
|
||||
"model.encoder.embed_positions.weight",
|
||||
(N_AUDIO_CTX, N_AUDIO_STATE),
|
||||
),
|
||||
layers: (0..N_AUDIO_LAYER)
|
||||
.map(|i| EncoderLayer::new(i, cx))
|
||||
.collect(),
|
||||
@@ -427,15 +411,16 @@ pub struct WhisperDecoder {
|
||||
impl WhisperDecoder {
|
||||
pub fn init(cx: &mut Graph) -> Self {
|
||||
Self {
|
||||
embed_tokens: cx
|
||||
.named_tensor("model.decoder.embed_tokens.weight", (N_VOCAB, N_TEXT_STATE))
|
||||
.persist(),
|
||||
embed_positions: cx
|
||||
.named_tensor(
|
||||
"model.decoder.embed_positions.weight",
|
||||
(N_TEXT_CTX, N_TEXT_STATE),
|
||||
)
|
||||
.persist(),
|
||||
embed_tokens: persist(
|
||||
cx,
|
||||
"model.decoder.embed_tokens.weight",
|
||||
(N_VOCAB, N_TEXT_STATE),
|
||||
),
|
||||
embed_positions: persist(
|
||||
cx,
|
||||
"model.decoder.embed_positions.weight",
|
||||
(N_TEXT_CTX, N_TEXT_STATE),
|
||||
),
|
||||
layers: (0..N_TEXT_LAYER)
|
||||
.map(|i| DecoderLayer::new(i, cx))
|
||||
.collect(),
|
||||
@@ -450,18 +435,8 @@ impl WhisperDecoder {
|
||||
xa: GraphTensor,
|
||||
kv_cache: &KVCache,
|
||||
) -> (GraphTensor, Vec<(GraphTensor, GraphTensor)>) {
|
||||
let seq = token_ids.dims1();
|
||||
// Token embedding gather
|
||||
let mut x = self.embed_tokens.gather(
|
||||
(token_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ token_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
);
|
||||
// Positional embedding gather (using pos_ids)
|
||||
let pos_emb = self.embed_positions.gather(
|
||||
(pos_ids * N_TEXT_STATE).expand_dim(1, N_TEXT_STATE)
|
||||
+ pos_ids.graph().arange(N_TEXT_STATE).expand_dim(0, seq),
|
||||
);
|
||||
x += pos_emb;
|
||||
let mut x = embedding_lookup(self.embed_tokens, token_ids);
|
||||
x += embedding_lookup(self.embed_positions, pos_ids);
|
||||
|
||||
let mut cache_outputs = Vec::with_capacity(N_TEXT_LAYER);
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
|
||||
@@ -335,7 +335,8 @@ fn main() {
|
||||
|
||||
println!("Compiling (search_graphs={search_graphs})...");
|
||||
let t0 = Instant::now();
|
||||
runtime = cx.search(runtime, CompileOptions::new(search_graphs));
|
||||
let search_options = CompileOptions::default().search_graph_limit(search_graphs);
|
||||
runtime = cx.search(runtime, search_options);
|
||||
println!(" search took {:?}", t0.elapsed());
|
||||
|
||||
// Re-set anchors/strides/dfl/img after search (search may consume the inputs)
|
||||
|
||||
@@ -174,7 +174,10 @@ pub fn compile_backend<Rt: Runtime + 'static>(
|
||||
}
|
||||
|
||||
// Search
|
||||
let mut rt = graph.search(rt, CompileOptions::new(args.search_iters));
|
||||
let mut rt = graph.search(
|
||||
rt,
|
||||
CompileOptions::default().search_graph_limit(args.search_iters),
|
||||
);
|
||||
|
||||
// Rebuild label map after search (graph may have changed)
|
||||
let label_map = build_label_map(graph);
|
||||
|
||||
@@ -463,7 +463,10 @@ pub(super) mod tests {
|
||||
let c = func(a, b).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let lhs_values = lhs_transform(random_vec(a_shape.iter().copied().product()));
|
||||
let rhs_values = rhs_transform(random_vec(b_shape.iter().copied().product()));
|
||||
|
||||
@@ -968,7 +968,10 @@ mod tests {
|
||||
let zeros = cx.iota(Expression::from(0usize), 6);
|
||||
let inv = values.scatter(perm, zeros).cast(DType::F32).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(data.id, vec![0., 1., 2., 3., 4., 5.]);
|
||||
rt.set_data(indexes.id, vec![5, 0, 3, 2]);
|
||||
rt.set_data(perm.id, vec![3, 2, 4, 1, 5, 0]);
|
||||
@@ -985,7 +988,10 @@ mod tests {
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(src.id, vec![10., 20., 30.]);
|
||||
rt.set_data(indexes.id, vec![1, 3, 4]);
|
||||
rt.set_data(dest.id, vec![0., 0., 0., 0., 0.]);
|
||||
@@ -1001,7 +1007,10 @@ mod tests {
|
||||
let dest = cx.tensor(5);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(src.id, vec![99.]);
|
||||
rt.set_data(indexes.id, vec![2]);
|
||||
rt.set_data(dest.id, vec![1., 2., 3., 4., 5.]);
|
||||
@@ -1017,7 +1026,10 @@ mod tests {
|
||||
let dest = cx.tensor(4);
|
||||
let result = src.scatter(indexes, dest).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(src.id, vec![40., 30., 20., 10.]);
|
||||
rt.set_data(indexes.id, vec![3, 2, 1, 0]);
|
||||
rt.set_data(dest.id, vec![1., 2., 3., 4.]);
|
||||
@@ -1045,7 +1057,10 @@ mod tests {
|
||||
let repeated = (a.repeat((2, 2)) * 1.0).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(a.id, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
|
||||
@@ -126,7 +126,10 @@ mod tests {
|
||||
let b = func(&mut cx).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -211,7 +214,10 @@ mod tests {
|
||||
let stacked = cx.stack(&[a, b, c], 0).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let a_data = random_vec(6);
|
||||
let b_data = random_vec(6);
|
||||
|
||||
@@ -493,7 +493,10 @@ pub(super) mod tests {
|
||||
let b = func(a).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
let v = random_vec(shape.iter().copied().product());
|
||||
rt.set_data(a.id, v.clone());
|
||||
|
||||
165
src/graph.rs
165
src/graph.rs
@@ -137,7 +137,9 @@ impl DimBucket {
|
||||
/// Use the builder pattern to configure search parameters:
|
||||
/// ```
|
||||
/// use luminal::prelude::CompileOptions;
|
||||
/// let opts = CompileOptions::new(5)
|
||||
/// let opts = CompileOptions::default()
|
||||
/// .search_graph_limit(5)
|
||||
/// .search_time_limit(std::time::Duration::from_secs(30))
|
||||
/// .generation_size(50)
|
||||
/// .mutations(40)
|
||||
/// .trials(15);
|
||||
@@ -146,6 +148,8 @@ impl DimBucket {
|
||||
pub struct CompileOptions {
|
||||
/// Maximum number of graphs to evaluate during search.
|
||||
pub limit: usize,
|
||||
/// Maximum wall-clock time to spend searching.
|
||||
pub search_time_limit: std::time::Duration,
|
||||
/// Optional maximum runtime-specific intermediate memory, in bytes.
|
||||
///
|
||||
/// When this is `None`, search does not apply a memory cap. Runtimes that
|
||||
@@ -163,8 +167,6 @@ pub struct CompileOptions {
|
||||
/// Per-candidate profiling timeout. If a profile call reaches this budget,
|
||||
/// that candidate is discarded and search continues.
|
||||
pub profile_timeout: Option<std::time::Duration>,
|
||||
/// Optional per-group search timeout.
|
||||
pub group_timeout: Option<std::time::Duration>,
|
||||
/// Optional profiling dimension overrides.
|
||||
pub profile_dims: FxHashMap<char, usize>,
|
||||
/// Bucket definitions per dynamic dimension. Dimensions without buckets use
|
||||
@@ -173,20 +175,16 @@ pub struct CompileOptions {
|
||||
}
|
||||
|
||||
impl CompileOptions {
|
||||
/// Create compile options with the given search limit. Other fields use defaults.
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
limit,
|
||||
max_memory_bytes: None,
|
||||
generation_size: 10,
|
||||
mutations: 10,
|
||||
trials: 3,
|
||||
keep_best: 1,
|
||||
profile_timeout: Some(std::time::Duration::from_secs(1)),
|
||||
group_timeout: None,
|
||||
profile_dims: FxHashMap::default(),
|
||||
dim_buckets: FxHashMap::default(),
|
||||
}
|
||||
/// Set the maximum number of graphs to evaluate during search.
|
||||
pub fn search_graph_limit(mut self, limit: usize) -> Self {
|
||||
self.limit = limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum wall-clock time to spend searching.
|
||||
pub fn search_time_limit(mut self, search_time_limit: std::time::Duration) -> Self {
|
||||
self.search_time_limit = search_time_limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a maximum intermediate memory budget in bytes.
|
||||
@@ -235,12 +233,6 @@ impl CompileOptions {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set an optional per-group search timeout.
|
||||
pub fn group_timeout(mut self, group_timeout: std::time::Duration) -> Self {
|
||||
self.group_timeout = Some(group_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Override a dynamic dimension value used during search profiling.
|
||||
pub fn profile_dim(mut self, dim: char, value: usize) -> Self {
|
||||
self.profile_dims.insert(dim, value);
|
||||
@@ -261,7 +253,18 @@ impl CompileOptions {
|
||||
|
||||
impl Default for CompileOptions {
|
||||
fn default() -> Self {
|
||||
Self::new(1)
|
||||
Self {
|
||||
limit: 100,
|
||||
search_time_limit: std::time::Duration::MAX,
|
||||
max_memory_bytes: None,
|
||||
generation_size: 10,
|
||||
mutations: 10,
|
||||
trials: 3,
|
||||
keep_best: 1,
|
||||
profile_timeout: Some(std::time::Duration::from_secs(1)),
|
||||
profile_dims: FxHashMap::default(),
|
||||
dim_buckets: FxHashMap::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1302,10 +1305,18 @@ impl Graph {
|
||||
"dim buckets must be configured in CompileOptions before build_search_space; search cannot change buckets after build",
|
||||
);
|
||||
|
||||
let search_started_at = std::time::Instant::now();
|
||||
if self.search_space_dim_buckets.is_empty() {
|
||||
// No buckets: existing single-search path
|
||||
let stitched =
|
||||
self.search_single(&mut runtime, &options, rng, &self.dyn_map.clone(), None, 0);
|
||||
let stitched = self.search_single(
|
||||
&mut runtime,
|
||||
&options,
|
||||
rng,
|
||||
&self.dyn_map.clone(),
|
||||
None,
|
||||
0,
|
||||
search_started_at,
|
||||
);
|
||||
|
||||
runtime.clear_intermediate_buffers();
|
||||
runtime.load_llir(&stitched);
|
||||
@@ -1339,6 +1350,7 @@ impl Graph {
|
||||
&context.representative_dyn_map,
|
||||
Some((combo_idx, n_combos)),
|
||||
combo_idx,
|
||||
search_started_at,
|
||||
);
|
||||
bucket_llirs.push((
|
||||
context.bucket_indices,
|
||||
@@ -1425,6 +1437,7 @@ impl Graph {
|
||||
/// Run the genetic search and return the unrolled LLIR for the winning
|
||||
/// genome. `bucket_progress`: if `Some((current_bucket_idx, total_buckets))`
|
||||
/// adds a second "Bucket" progress bar.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn search_single<R: Runtime + 'static, G: rand::Rng>(
|
||||
&mut self,
|
||||
runtime: &mut R,
|
||||
@@ -1433,6 +1446,7 @@ impl Graph {
|
||||
dyn_map: &FxHashMap<char, usize>,
|
||||
bucket_progress: Option<(usize, usize)>,
|
||||
egraph_index: usize,
|
||||
search_started_at: std::time::Instant,
|
||||
) -> LLIRGraph {
|
||||
let mut profile_dyn_map = dyn_map.clone();
|
||||
for (&dim, &value) in &options.profile_dims {
|
||||
@@ -1482,11 +1496,11 @@ impl Graph {
|
||||
}
|
||||
};
|
||||
|
||||
let group_start = std::time::Instant::now();
|
||||
let mut prev_selected: FxHashSet<u64> = FxHashSet::default();
|
||||
let mut list_cache = FxHashMap::default();
|
||||
let mut expr_cache = FxHashMap::default();
|
||||
runtime.clear_intermediate_buffers();
|
||||
let search_time_limit_reached = || search_started_at.elapsed() >= options.search_time_limit;
|
||||
let profile_timed_out = |elapsed: std::time::Duration| {
|
||||
options
|
||||
.profile_timeout
|
||||
@@ -1561,11 +1575,8 @@ impl Graph {
|
||||
break;
|
||||
}
|
||||
Ok(_) | Err(_) => {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
panic!("Failed to find a viable initial genome before timeout");
|
||||
if search_time_limit_reached() {
|
||||
panic!("Failed to find a viable initial genome before search time limit");
|
||||
}
|
||||
list_cache.clear();
|
||||
expr_cache.clear();
|
||||
@@ -1586,10 +1597,7 @@ impl Graph {
|
||||
let mut resample_generation = false;
|
||||
|
||||
while n_graphs < search_limit {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
if search_time_limit_reached() {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1623,10 +1631,7 @@ impl Graph {
|
||||
let mut generation_found_non_timeout = false;
|
||||
|
||||
for genome in all_offspring {
|
||||
if options
|
||||
.group_timeout
|
||||
.is_some_and(|timeout| group_start.elapsed() >= timeout)
|
||||
{
|
||||
if search_time_limit_reached() {
|
||||
break;
|
||||
}
|
||||
n_graphs += 1;
|
||||
@@ -2808,6 +2813,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::egglog_utils::hash_egglog_normalized;
|
||||
use crate::tests::{assert_close, random_vec};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestMemoryRuntime;
|
||||
@@ -2845,6 +2851,69 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
static PROFILE_CALLS: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[derive(Default)]
|
||||
struct CountingRuntime;
|
||||
|
||||
impl Runtime for CountingRuntime {
|
||||
type Ops = ();
|
||||
type CompileArg = ();
|
||||
type ExecReturn = ();
|
||||
type ProfileMetric = usize;
|
||||
|
||||
fn initialize(_: Self::CompileArg) -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
fn load_llir(&mut self, _: &LLIRGraph) {}
|
||||
|
||||
fn execute(&mut self, _: &FxHashMap<char, usize>) -> Self::ExecReturn {}
|
||||
|
||||
fn profile(
|
||||
&mut self,
|
||||
_: &LLIRGraph,
|
||||
_: &FxHashMap<char, usize>,
|
||||
_: usize,
|
||||
_: Option<std::time::Duration>,
|
||||
) -> (Self::ProfileMetric, String) {
|
||||
let count = PROFILE_CALLS.fetch_add(1, Ordering::SeqCst);
|
||||
(count, format!("{count} ms"))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_options_defaults_and_search_time_limit_builder() {
|
||||
let opts = CompileOptions::default();
|
||||
assert_eq!(opts.limit, 100);
|
||||
assert_eq!(opts.search_time_limit, std::time::Duration::MAX);
|
||||
|
||||
let time_limit = std::time::Duration::from_millis(25);
|
||||
let opts = CompileOptions::default()
|
||||
.search_graph_limit(7)
|
||||
.search_time_limit(time_limit);
|
||||
assert_eq!(opts.limit, 7);
|
||||
assert_eq!(opts.search_time_limit, time_limit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn search_time_limit_stops_after_initial_viable_candidate() {
|
||||
let mut cx = Graph::new();
|
||||
let a = cx.tensor((4, 8));
|
||||
let b = cx.tensor((8, 4));
|
||||
let c = cx.tensor((4, 4));
|
||||
let _ = (a.matmul(b) + c).relu().softmax(1).output();
|
||||
|
||||
cx.build_search_space::<CountingRuntime>(CompileOptions::default());
|
||||
|
||||
PROFILE_CALLS.store(0, Ordering::SeqCst);
|
||||
let _ = cx.search(
|
||||
CountingRuntime,
|
||||
CompileOptions::default().search_time_limit(std::time::Duration::ZERO),
|
||||
);
|
||||
assert_eq!(PROFILE_CALLS.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_search_space_without_explicit_max_memory_has_no_cap() {
|
||||
let mut cx = Graph::new();
|
||||
@@ -2862,7 +2931,10 @@ mod tests {
|
||||
let _ = cx.tensor(1).output();
|
||||
cx.build_search_space::<TestMemoryRuntime>(CompileOptions::default().max_memory_bytes(0));
|
||||
|
||||
let _ = cx.search(TestMemoryRuntime, CompileOptions::new(1));
|
||||
let _ = cx.search(
|
||||
TestMemoryRuntime,
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2870,7 +2942,10 @@ mod tests {
|
||||
let mut cx = Graph::new();
|
||||
let _ = cx.tensor(1).output();
|
||||
|
||||
let _ = cx.compile(TestMemoryRuntime, CompileOptions::new(1));
|
||||
let _ = cx.compile(
|
||||
TestMemoryRuntime,
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
assert_eq!(cx.egraphs.len(), 1);
|
||||
assert_eq!(cx.search_space_max_memory_bytes, None);
|
||||
@@ -2908,7 +2983,7 @@ mod tests {
|
||||
let mut rng = rand::rng();
|
||||
let _ = cx.search_with_rng(
|
||||
TestMemoryRuntime,
|
||||
CompileOptions::new(1).dim_buckets(
|
||||
CompileOptions::default().search_graph_limit(1).dim_buckets(
|
||||
's',
|
||||
&[DimBucket::new(1, 1), DimBucket::new(2, 4).representative(4)],
|
||||
),
|
||||
@@ -3019,7 +3094,7 @@ mod tests {
|
||||
let vals = random_vec(8);
|
||||
let mut rt = NativeRuntime::default();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.set_data(x.id, vals.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -3056,7 +3131,7 @@ mod tests {
|
||||
let vals = random_vec(8);
|
||||
let mut rt = NativeRuntime::default();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
rt = cx.search(rt, CompileOptions::new(1));
|
||||
rt = cx.search(rt, CompileOptions::default().search_graph_limit(1));
|
||||
rt.set_data(x.id, vals.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ proptest! {
|
||||
let d = (b * c / e).sin().output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
|
||||
rt.set_data(b.id, vals.clone());
|
||||
rt.set_data(c.id, vals.clone());
|
||||
@@ -58,7 +58,7 @@ proptest! {
|
||||
let a = b.matmul(c).output();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
let lhs = lhs.into_iter().take(m * k).collect::<Vec<f32>>();
|
||||
let rhs = rhs.into_iter().take(k * n).collect::<Vec<f32>>();
|
||||
rt.set_data(b.id, lhs.clone());
|
||||
@@ -82,7 +82,7 @@ proptest! {
|
||||
let a = cx.tensor((2, 2));
|
||||
let b = (a.permute((1, 0)) * 1.0).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
rt.set_data(a.id, values.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
|
||||
@@ -99,7 +99,7 @@ proptest! {
|
||||
let mask = a.ge(kth_largest.expand_dim(1, cols)).cast(crate::dtype::DType::F32);
|
||||
let filtered = (a * mask).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::default().search_graph_limit(1));
|
||||
let values = values.into_iter().take(rows * cols).collect::<Vec<f32>>();
|
||||
rt.set_data(a.id, values.clone());
|
||||
rt.execute(&cx.dyn_map);
|
||||
@@ -465,7 +465,10 @@ fn test_inputs_consumed_after_execute() {
|
||||
let a = cx.tensor(3);
|
||||
let _b = (a * 2.0).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
// Second execute should panic — input 'a' was consumed
|
||||
@@ -481,7 +484,10 @@ fn test_passthrough_preserves_weights() {
|
||||
w.persist();
|
||||
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
|
||||
// Iteration 1
|
||||
rt.set_data(w.id, vec![1.0, 2.0, 3.0]);
|
||||
@@ -503,7 +509,10 @@ fn test_only_outputs_remain() {
|
||||
let a = cx.tensor(3);
|
||||
let _b = (a * 2.0).output();
|
||||
cx.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = cx.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = cx.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(a.id, vec![1.0, 2.0, 3.0]);
|
||||
rt.execute(&cx.dyn_map);
|
||||
let output_count = rt
|
||||
@@ -556,7 +565,10 @@ fn integration_auto_loop_rolling_matches_reference_native_runtime() {
|
||||
|
||||
let (mut graph, input_id, weight_ids, output_id) = build_repeated_block_graph(layers, width);
|
||||
graph.build_search_space::<NativeRuntime>(CompileOptions::default());
|
||||
let mut rt = graph.search(NativeRuntime::default(), CompileOptions::new(1));
|
||||
let mut rt = graph.search(
|
||||
NativeRuntime::default(),
|
||||
CompileOptions::default().search_graph_limit(1),
|
||||
);
|
||||
rt.set_data(input_id, input);
|
||||
for (node, data) in weight_ids.iter().zip(weights.iter()) {
|
||||
rt.set_data(*node, data.clone());
|
||||
|
||||
Reference in New Issue
Block a user