Wire GPU trace postprocessing into trace session

This commit is contained in:
Joe Fioti
2026-01-01 12:33:26 -05:00
parent 0082fedd3c
commit d3fbc58173
7 changed files with 186 additions and 63 deletions

View File

@@ -31,6 +31,10 @@ egglog = "1.0.0"
egglog-ast = "1.0.0"
egraph-serialize = { version = "0.3.0", default-features = false, features = ["graphviz", "serde"]}
tracing = "0.1.43"
tracing-appender = "0.2.4"
tracing-perfetto-sdk-layer = "0.13.0"
tracing-perfetto-sdk-schema = "0.13.0"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
paste = "1.0.15"
pretty-duration = "0.1.1"
anyhow = "1.0"

View File

@@ -798,9 +798,9 @@ pub fn allocate_input_buffers(
pub fn record_exec_timings_to_file(
timings: &Vec<(Vec<SMEvent>, u64)>,
ops: &Vec<Arc<Box<dyn BlockOp>>>,
file_path: &str,
file_path: impl AsRef<std::path::Path>,
) {
let data = std::fs::read(file_path).unwrap();
let data = std::fs::read(&file_path).unwrap();
let mut trace = tracing_perfetto_sdk_schema::Trace::decode(data.as_slice()).unwrap();
let host_start_times: Vec<(u64, u32)> = trace

View File

@@ -5,35 +5,23 @@ use luminal::{
graph::{Graph, Runtime},
op::DType,
prelude::FxHashMap,
trace::{self, TraceOptions},
};
use luminal_cuda::{
block::IntoBlockOp,
runtime::{record_exec_timings_to_file, CudaRuntime, CustomState},
};
use model::*;
use std::{fs::File, io::Write, time::Duration};
use std::io::Write;
use tokenizers::Tokenizer;
use tracing::{span, Level};
use tracing_appender::non_blocking;
use tracing_perfetto_sdk_layer::NativeLayer;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
fn main() {
// Set up tracing
let file = File::create("trace.pftrace").unwrap();
let (writer, _guard) = non_blocking(file);
let layer = NativeLayer::from_config(trace_config(), writer)
.build()
.unwrap();
let filter = EnvFilter::builder()
.parse(format!("{}=trace,luminal=trace", env!("CARGO_PKG_NAME")))
.unwrap();
let layer_handle = layer.clone();
tracing_subscriber::registry()
.with(filter)
.with(layer)
.init();
let trace_session = trace::init(TraceOptions {
sink: trace::trace_file_path("trace.pftrace"),
env_filter: format!("{}=trace,luminal=trace", env!("CARGO_PKG_NAME")),
});
let max_seq_len = 4096;
let gen_tokens = 5;
@@ -129,16 +117,15 @@ fn main() {
}
println!();
layer_handle
.flush(Duration::from_secs(5), Duration::from_secs(5))
.unwrap();
layer_handle.stop().unwrap();
drop(_guard);
record_exec_timings_to_file(
&timings,
&<luminal_cuda::block::Ops as IntoBlockOp>::into_vec(),
"trace.pftrace",
);
trace_session.flush();
trace_session.stop();
if let Some(path) = trace_session.perfetto_path() {
record_exec_timings_to_file(
&timings,
&<luminal_cuda::block::Ops as IntoBlockOp>::into_vec(),
path,
);
}
}
#[tracing::instrument(skip_all)]
@@ -157,20 +144,3 @@ fn sample(logits: &[f32], vocab_size: usize) -> Vec<u32> {
})
.collect()
}
fn trace_config() -> tracing_perfetto_sdk_schema::TraceConfig {
tracing_perfetto_sdk_schema::TraceConfig {
buffers: vec![tracing_perfetto_sdk_schema::trace_config::BufferConfig {
size_kb: Some(4096),
..Default::default()
}],
data_sources: vec![tracing_perfetto_sdk_schema::trace_config::DataSource {
config: Some(tracing_perfetto_sdk_schema::DataSourceConfig {
name: Some("rust_tracing".into()),
..Default::default()
}),
..Default::default()
}],
..Default::default()
}
}

View File

@@ -17,6 +17,7 @@ use egraph_serialize::{ClassId, NodeId};
use itertools::Itertools;
use petgraph::{Direction, stable_graph::StableGraph, visit::EdgeRef};
use rustc_hash::{FxHashMap, FxHashSet};
use tracing::info;
pub type LLIRGraph = StableGraph<LLIROp, (), petgraph::Directed>;
pub type HLIRGraph = StableGraph<Box<dyn HLIROp>, Dependency>;
@@ -198,6 +199,7 @@ impl Graph {
let print = std::env::var("SEARCH")
.map(|s| s == "1")
.unwrap_or_default();
let limit_reached = llir_graphs.len() == limit;
let start = std::time::Instant::now();
if print {
println!(
@@ -205,17 +207,21 @@ impl Graph {
format!(
"---- Searching through {}{} graphs ----",
llir_graphs.len().to_string().bold(),
if llir_graphs.len() == limit {
"[limit]"
} else {
""
}
if limit_reached { "[limit]" } else { "" }
)
.cyan()
);
}
runtime.compile(llir_graphs.last().unwrap());
if print {
info!(
target: "luminal::search",
graphs = llir_graphs.len(),
limit,
limit_reached,
duration_ms = start.elapsed().as_millis() as u64,
"search completed"
);
println!(
"{}",
format!(
@@ -392,17 +398,22 @@ fn run_egglog(
.unwrap_or_default()
{
println!("{}", "---- Egglog Rule Matches ----".green());
println!(
"{}",
egraph
.get_overall_run_report()
.num_matches_per_rule
.iter()
.filter(|(k, _)| !k.contains("("))
.map(|(k, v)| format!("{k}: {v}"))
.join("\n")
.green()
);
let mut rule_lines = Vec::new();
for (rule, matches) in egraph
.get_overall_run_report()
.num_matches_per_rule
.iter()
.filter(|(k, _)| !k.contains("("))
{
info!(
target: "luminal::egglog",
rule = %rule,
matches = *matches,
"rule matches"
);
rule_lines.push(format!("{rule}: {matches}"));
}
println!("{}", rule_lines.join("\n").green());
println!(
"{}",
format!(
@@ -411,6 +422,11 @@ fn run_egglog(
)
.green()
);
info!(
target: "luminal::egglog",
duration_ms = start.elapsed().as_millis() as u64,
"egglog run completed"
);
}
let (sort, value) = egraph.eval_expr(&var!(root)).unwrap();

View File

@@ -5,6 +5,7 @@ pub mod hl_ops;
pub mod op;
pub mod serialized_egraph;
pub mod shape;
pub mod trace;
pub mod utils;
pub mod visualization;

View File

@@ -17,6 +17,7 @@ use itertools::Itertools;
use num_traits::Float;
use petgraph::{Direction, algo::toposort, prelude::StableGraph, visit::EdgeRef};
use rustc_hash::FxHashMap;
use tracing::info_span;
pub type Ops = (
Input,
@@ -1510,6 +1511,8 @@ impl Runtime for NativeRuntime {
continue;
}
let span = info_span!("native_op", op = %format!("{:?}", self.graph[node]));
let _entered = span.enter();
let inputs = self
.graph
.edges_directed(node, Direction::Incoming)

129
src/trace.rs Normal file
View File

@@ -0,0 +1,129 @@
use std::{
fs::File,
path::{Path, PathBuf},
time::Duration,
};
use tracing_appender::non_blocking::{self, WorkerGuard};
use tracing_perfetto_sdk_layer::NativeLayer;
use tracing_perfetto_sdk_schema::{
DataSourceConfig, TraceConfig,
trace_config::{BufferConfig, DataSource},
};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
pub enum TraceSink {
PerfettoFile { path: PathBuf },
Stdout,
Disabled,
}
pub struct TraceOptions {
pub sink: TraceSink,
pub env_filter: String,
}
impl Default for TraceOptions {
fn default() -> Self {
Self {
sink: TraceSink::Stdout,
env_filter: "luminal=trace".to_string(),
}
}
}
pub struct TraceSession {
perfetto_layer: Option<tracing_perfetto_sdk_layer::LayerHandle>,
guard: Option<WorkerGuard>,
perfetto_path: Option<PathBuf>,
}
impl TraceSession {
pub fn flush(&self) {
if let Some(layer) = &self.perfetto_layer {
let _ = layer.flush(Duration::from_secs(5), Duration::from_secs(5));
}
}
pub fn stop(&self) {
if let Some(layer) = &self.perfetto_layer {
let _ = layer.stop();
}
}
pub fn perfetto_path(&self) -> Option<&Path> {
self.perfetto_path.as_deref()
}
}
pub fn init(options: TraceOptions) -> TraceSession {
let filter = EnvFilter::builder()
.parse(options.env_filter)
.expect("Invalid tracing env filter");
match options.sink {
TraceSink::PerfettoFile { path } => init_perfetto_file(&filter, path),
TraceSink::Stdout => init_stdout(&filter),
TraceSink::Disabled => {
tracing_subscriber::registry().with(filter).init();
TraceSession {
perfetto_layer: None,
guard: None,
perfetto_path: None,
}
}
}
}
fn init_perfetto_file(filter: &EnvFilter, path: PathBuf) -> TraceSession {
let file = File::create(&path).expect("Failed to create trace file");
let (writer, guard) = non_blocking(file);
let layer = NativeLayer::from_config(default_perfetto_config(), writer)
.build()
.expect("Failed to build perfetto layer");
let handle = layer.clone();
tracing_subscriber::registry()
.with(filter.clone())
.with(layer)
.init();
TraceSession {
perfetto_layer: Some(handle),
guard: Some(guard),
perfetto_path: Some(path),
}
}
fn init_stdout(filter: &EnvFilter) -> TraceSession {
tracing_subscriber::registry()
.with(filter.clone())
.with(tracing_subscriber::fmt::layer())
.init();
TraceSession {
perfetto_layer: None,
guard: None,
perfetto_path: None,
}
}
fn default_perfetto_config() -> TraceConfig {
TraceConfig {
buffers: vec![BufferConfig {
size_kb: Some(4096),
..Default::default()
}],
data_sources: vec![DataSource {
config: Some(DataSourceConfig {
name: Some("rust_tracing".into()),
..Default::default()
}),
..Default::default()
}],
..Default::default()
}
}
pub fn trace_file_path(path: impl AsRef<Path>) -> TraceSink {
TraceSink::PerfettoFile {
path: path.as_ref().to_path_buf(),
}
}