mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
* Move ONNX import into burn-onnx crate * Update publish * Update burn-import -> burn-onnx * Fix clippy warnings that are no longer allowed * Allow unused * Update contributor book references to burn-onnx Update all burn-import path references in the ONNX development guide to point to the new burn-onnx crate location. * Remove ONNX integration steps from burn op guide Deleted the section detailing how to add a new operation to burn-onnx, including ONNX IR and code generation mapping steps. This streamlines the guide and removes outdated or redundant ONNX-specific instructions. * Update onnx-ir references from burn-import to burn-onnx Update documentation and code comments to reference the new burn-onnx crate instead of burn-import. * Update ONNX test producer name to burn-onnx-test Update producer_name metadata in Python test scripts from "burn-import-test" to "burn-onnx-test" for consistency. * Undo ONNX file changes
400 lines
14 KiB
Python
Executable File
400 lines
14 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
Download and prepare the Silero VAD model for testing.

This script downloads the Silero VAD ONNX model (opset 18, if-less version) and prepares
it for use with burn-onnx. This version has only 1 If node (for sample rate selection)
making it compatible with static type inference.

See: https://github.com/snakers4/silero-vad/issues/728 for compatibility discussion.
"""
|
|
|
|
import json
|
|
import struct
|
|
import urllib.request
|
|
import wave
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
|
|
# The ONNX tooling is third-party: fail fast with an install hint when it is
# missing instead of surfacing a bare ImportError traceback.
try:
    import onnx
except ImportError:
    print("Error: onnx package not found. Please install it with:")
    print(" pip install onnx")
    # SystemExit is always available, unlike the `exit` helper that the
    # `site` module only injects for interactive convenience.
    raise SystemExit(1) from None

try:
    import onnxruntime as ort
except ImportError:
    print("Error: onnxruntime package not found. Please install it with:")
    print(" pip install onnxruntime")
    raise SystemExit(1) from None
|
|
|
|
|
|
def extract_node_info(model_path, artifacts_dir):
    """Extract node types and configurations from the ONNX model.

    Walks the main graph and every nested subgraph, writes the collected
    summary to ``artifacts_dir / "node_info.json"`` and prints a short report.

    Args:
        model_path: Location of the ONNX model file to inspect.
        artifacts_dir: ``pathlib.Path`` directory receiving ``node_info.json``.

    Returns:
        The summary dict that was serialised to disk.
    """
    print("Extracting node information from ONNX model...")

    # Load the ONNX model (without external data since we only need structure)
    model = onnx.load(str(model_path), load_external_data=False)

    # Collect the file names of any externally stored initializer data.
    external_files = {
        entry.value
        for init in model.graph.initializer
        if init.data_location == onnx.TensorProto.EXTERNAL
        for entry in init.external_data
        if entry.key == 'location'
    }

    if external_files:
        print(f"⚠️ Model requires external data files: {external_files}")
        print(" These files are missing from the repository!")

    node_types = defaultdict(int)
    node_details = []

    # Attribute kinds whose repeated field may legitimately be empty.
    list_attribute_types = (
        onnx.AttributeProto.FLOATS,
        onnx.AttributeProto.INTS,
        onnx.AttributeProto.STRINGS,
    )

    def decode_text(raw):
        """Decode attribute bytes without aborting on invalid UTF-8."""
        return raw.decode('utf-8', errors='replace')

    def process_graph(graph, graph_name="main"):
        """Recursively process a graph and its subgraphs."""
        for idx, node in enumerate(graph.node):
            node_types[node.op_type] += 1

            attributes = {}
            for attr in node.attribute:
                attr_name = attr.name
                if attr.HasField('f'):
                    value = float(attr.f)
                elif attr.HasField('i'):
                    value = int(attr.i)
                elif attr.HasField('s'):
                    value = decode_text(attr.s)
                elif attr.HasField('t'):
                    # Tensor payloads are not dumped, only flagged.
                    value = "<tensor>"
                elif attr.floats:
                    value = list(attr.floats)
                elif attr.ints:
                    value = list(attr.ints)
                elif attr.strings:
                    value = [decode_text(s) for s in attr.strings]
                elif attr.HasField('g'):
                    # Subgraph - its nodes are recorded before the owning node.
                    subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}"
                    process_graph(attr.g, subgraph_name)
                    value = f"<subgraph: {subgraph_name}>"
                elif attr.graphs:
                    subgraph_names = []
                    for g_idx, subgraph in enumerate(attr.graphs):
                        subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}_{g_idx}"
                        subgraph_names.append(subgraph_name)
                        process_graph(subgraph, subgraph_name)
                    value = f"<subgraphs: {', '.join(subgraph_names)}>"
                elif attr.type in list_attribute_types:
                    # A declared list attribute with no elements is a valid
                    # empty list, not an unknown attribute.
                    value = []
                else:
                    value = "<unknown>"
                attributes[attr_name] = value

            node_details.append({
                "graph": graph_name,
                "index": idx,
                "op_type": node.op_type,
                "name": node.name if node.name else f"{node.op_type}_{idx}",
                "inputs": list(node.input),
                "outputs": list(node.output),
                "attributes": attributes,
            })

    process_graph(model.graph, "main")

    # Prefer the default ONNX operator-set entry; the first entry may belong
    # to a custom domain. Fall back to the first entry, then to "unknown".
    opset_version = "unknown"
    if model.opset_import:
        default_domain = [
            opset for opset in model.opset_import
            if opset.domain in ("", "ai.onnx")
        ]
        chosen = default_domain[0] if default_domain else model.opset_import[0]
        opset_version = chosen.version

    summary = {
        "model_name": model.graph.name,
        "opset_version": opset_version,
        "total_nodes": len(node_details),
        "node_type_counts": dict(sorted(node_types.items())),
        # Sorted so the JSON does not change between runs.
        "external_data_files": sorted(external_files),
        "nodes": node_details,
    }

    output_path = artifacts_dir / "node_info.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    print(f"✓ Node information extracted to {output_path}")
    print(f" Opset version: {summary['opset_version']}")
    print(f" Total nodes: {summary['total_nodes']}")
    print(f" Unique node types: {len(node_types)}")
    print(" Node type distribution:")
    for op_type, count in sorted(node_types.items(), key=lambda x: x[1], reverse=True):
        print(f" - {op_type}: {count}")

    return summary
|
|
|
|
|
|
def download_test_data():
    """Download test audio file from silero-vad repository.

    The file is stored as ``artifacts/test.wav`` relative to the current
    working directory; an existing file is reused without touching the network.

    Returns:
        ``pathlib.Path`` of the local test audio file.

    Raises:
        Exception: Whatever the download raised, re-raised after reporting it.
    """
    artifacts_dir = Path("artifacts")
    artifacts_dir.mkdir(exist_ok=True)

    test_wav_path = artifacts_dir / "test.wav"

    if test_wav_path.exists():
        print(f"✓ Test audio already exists at {test_wav_path}")
        return test_wav_path

    # Download the test.wav file from silero-vad tests
    test_url = "https://github.com/snakers4/silero-vad/raw/refs/heads/master/tests/data/test.wav"

    print("Downloading test audio from:")
    print(f" {test_url}")
    print(f"Saving to: {test_wav_path}")

    # Download next to the target and rename on success, so an interrupted
    # transfer never leaves a truncated file that later runs would reuse.
    partial_path = test_wav_path.with_name(test_wav_path.name + ".part")
    try:
        urllib.request.urlretrieve(test_url, partial_path)
        partial_path.replace(test_wav_path)
        file_size = test_wav_path.stat().st_size / 1024
        print(f"✓ Download complete! File size: {file_size:.1f} KB")
    except Exception as e:
        if partial_path.exists():
            partial_path.unlink()
        print(f"✗ Error downloading test audio: {e}")
        raise

    return test_wav_path
|
|
|
|
|
|
def load_wav(wav_path):
    """Load a WAV file and return audio samples as float32 array normalized to [-1, 1].

    Args:
        wav_path: Path of the PCM WAV file to read.

    Returns:
        A ``(samples, frame_rate)`` tuple. ``samples`` is a 1-D float32 array;
        multi-channel audio is reduced to mono by averaging the channels.

    Raises:
        ValueError: If the sample width is not 1, 2 or 4 bytes.
    """
    with wave.open(str(wav_path), 'rb') as wav_file:
        n_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        frame_rate = wav_file.getframerate()
        raw_data = wav_file.readframes(wav_file.getnframes())

    # Native-endian dtypes are intentional: the wave module already converts
    # multi-byte frames to the host byte order.
    if sample_width == 2:  # 16-bit signed
        samples = np.frombuffer(raw_data, dtype=np.int16).astype(np.float32) / 32768.0
    elif sample_width == 4:  # 32-bit signed
        samples = np.frombuffer(raw_data, dtype=np.int32).astype(np.float32) / 2147483648.0
    elif sample_width == 1:  # 8-bit PCM is unsigned, centred on 128
        samples = (np.frombuffer(raw_data, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    else:
        raise ValueError(f"Unsupported sample width: {sample_width}")

    # Frames are interleaved per channel; average every channel group so any
    # channel count (not only stereo) collapses to mono.
    if n_channels > 1:
        samples = samples.reshape(-1, n_channels).mean(axis=1)

    return samples, frame_rate
|
|
|
|
|
|
def generate_reference_outputs(model_path, test_wav_path, artifacts_dir):
    """Generate reference outputs using ONNX Runtime for testing.

    Runs the model on the leading chunks of the test audio (carrying the
    recurrent state from chunk to chunk), on one seeded random chunk and on
    one all-zero chunk, then writes every input/output pair to
    ``artifacts_dir / "reference_outputs.json"``.

    Args:
        model_path: Path of the ONNX model to execute.
        test_wav_path: Path of the WAV file providing the audio chunks.
        artifacts_dir: ``pathlib.Path`` directory receiving the JSON file.
    """
    print(f" Loading test audio from {test_wav_path}...")
    audio_samples, sample_rate = load_wav(test_wav_path)
    print(f" Sample rate: {sample_rate} Hz")
    print(f" Audio length: {len(audio_samples)} samples ({len(audio_samples)/sample_rate:.2f} seconds)")

    print(" Creating ONNX Runtime session...")
    session = ort.InferenceSession(str(model_path), providers=['CPUExecutionProvider'])

    # Silero VAD parameters
    # For 16kHz: chunk_size = 512 (32ms)
    # For 8kHz: chunk_size = 256 (32ms)
    chunk_size = 512 if sample_rate == 16000 else 256
    batch_size = 1

    # The sample-rate input never changes, so build it once.
    sr_tensor = np.array(sample_rate, dtype=np.int64)

    def initial_state():
        """Return a zeroed recurrent state of shape (2, batch_size, 128)."""
        return np.zeros((2, batch_size, 128), dtype=np.float32)

    def run_step(samples, state):
        """Run one inference step; return (first output as float, next state)."""
        outputs = session.run(None, {
            'input': samples.reshape(1, -1).astype(np.float32),
            'sr': sr_tensor,
            'state': state,
        })
        return float(outputs[0].flatten()[0]), outputs[1]

    results = {
        "sample_rate": sample_rate,
        "chunk_size": chunk_size,
        "audio_length_samples": len(audio_samples),
        "test_cases": []
    }

    # Test Case 1: First few chunks from the beginning of the audio
    print(" Running inference on test chunks...")
    # Only whole chunks are counted, so every slice below is exactly
    # chunk_size long and no zero padding is ever required.
    num_test_chunks = min(10, len(audio_samples) // chunk_size)

    state = initial_state()
    for i in range(num_test_chunks):
        start_idx = i * chunk_size
        chunk = audio_samples[start_idx:start_idx + chunk_size]

        # The state produced by this chunk feeds the next one.
        output_prob, state = run_step(chunk, state)

        results["test_cases"].append({
            "test_name": f"chunk_{i}",
            "chunk_index": i,
            "start_sample": start_idx,
            "input_samples": chunk.tolist(),
            "expected_output": output_prob,
            "state_after": state.tolist()
        })

    # Test Case 2: Random input (for reproducibility test)
    print(" Generating random input test case...")
    np.random.seed(42)
    random_input = np.random.randn(chunk_size).astype(np.float32) * 0.1
    random_prob, random_state = run_step(random_input, initial_state())

    results["test_cases"].append({
        "test_name": "random_seed_42",
        "chunk_index": -1,
        "start_sample": -1,
        "input_samples": random_input.tolist(),
        "expected_output": random_prob,
        "state_after": random_state.tolist()
    })

    # Test Case 3: Zero input (silence)
    print(" Generating silence test case...")
    zero_input = np.zeros(chunk_size, dtype=np.float32)
    silence_prob, silence_state = run_step(zero_input, initial_state())

    results["test_cases"].append({
        "test_name": "silence",
        "chunk_index": -1,
        "start_sample": -1,
        "input_samples": zero_input.tolist(),
        "expected_output": silence_prob,
        "state_after": silence_state.tolist()
    })

    output_path = artifacts_dir / "reference_outputs.json"
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f" ✓ Reference outputs saved to {output_path}")
    print(f" Total test cases: {len(results['test_cases'])}")
|
|
|
|
|
|
def download_model():
    """Download the Silero VAD ONNX model (opset 18, if-less version).

    Fetches ``artifacts/silero_vad.onnx`` unless it already exists, then
    extracts the node information, downloads the test audio, generates the
    reference outputs and prints a summary of the generated files.

    Raises:
        Exception: Any failure of a step is reported and re-raised.
    """
    artifacts_dir = Path("artifacts")
    artifacts_dir.mkdir(exist_ok=True)

    model_path = artifacts_dir / "silero_vad.onnx"

    if model_path.exists():
        # Skip download if model already exists
        print(f"✓ Model already exists at {model_path}")
        print(f" File size: {model_path.stat().st_size / 1024:.1f} KB")
        print()
    else:
        # Download the opset 18 if-less model
        # Note: This model has external data that is missing from the repo
        model_url = "https://github.com/snakers4/silero-vad/raw/refs/heads/master/src/silero_vad/data/silero_vad_op18_ifless.onnx"

        print("Downloading Silero VAD model (opset 18, if-less) from:")
        print(f" {model_url}")
        print(f"Saving to: {model_path}")
        print()

        # Download to a sibling file and rename on success, so a failed
        # transfer cannot leave a truncated model that the exists() check
        # above would accept on the next run.
        partial_path = model_path.with_name(model_path.name + ".part")
        try:
            urllib.request.urlretrieve(model_url, partial_path)
            partial_path.replace(model_path)
            file_size = model_path.stat().st_size / 1024
            print(f"✓ Download complete! File size: {file_size:.1f} KB")
            print()
        except Exception as e:
            if partial_path.exists():
                partial_path.unlink()
            print(f"✗ Error downloading model: {e}")
            raise

    # Extract node information from the ONNX model
    try:
        extract_node_info(model_path, artifacts_dir)
    except Exception as e:
        print(f"✗ Error extracting node information: {e}")
        raise

    print()

    # Download test data
    print("Downloading test data...")
    try:
        test_wav_path = download_test_data()
    except Exception as e:
        print(f"✗ Error downloading test data: {e}")
        raise

    print()

    # Generate reference outputs
    print("Generating reference outputs...")
    try:
        generate_reference_outputs(model_path, test_wav_path, artifacts_dir)
    except Exception as e:
        print(f"✗ Error generating reference outputs: {e}")
        raise

    print()
    print("="*80)
    print("Model preparation complete!")
    print("="*80)
    print()
    print("Silero VAD (opset 18, if-less) has only 1 If node for sample rate selection.")
    print("This makes it compatible with burn-onnx's static type inference.")
    print()
    print("See: https://github.com/snakers4/silero-vad/issues/728")
    print()
    print("Generated files:")
    print(f" - {model_path} (ONNX model)")
    print(f" - {artifacts_dir / 'node_info.json'} (node analysis)")
    print(f" - {test_wav_path} (test audio)")
    print(f" - {artifacts_dir / 'reference_outputs.json'} (reference test outputs)")
    print()
    print("Next steps:")
    print(" 1. Build the model: cargo build")
    print(" 2. Run the test: cargo run")
    print()
|
|
|
|
|
|
# Run the full preparation pipeline only when executed as a script.
if __name__ == "__main__":
    download_model()
|