mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
Move ONNX crates to burn-onnx repository (#4393)
* Move ONNX inference MNIST example to burn-onnx Moving to https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference * Move Raspberry Pi Pico example project to burn-onnx Moving to https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico * Move image-classification-web example to burn-onnx Moving to https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web * Refactor import-model-weights example to use burn-store Replace burn-import with burn-store for loading PyTorch and Safetensors model weights. Convert from NamedMpk (.mpk) format to Burnpack (.bpk) format for native Burn model storage. - Use PytorchStore and SafetensorsStore for weight loading - Use BurnpackStore for saving/loading converted models - Rename namedmpk binary to burnpack - Add PyTorchToBurnAdapter for Safetensors files exported from PyTorch - Update inference to accept Model directly instead of ModelRecord * Remove ONNX to Burn development guide from contributor-book The guide has been moved to the burn-onnx repository: https://github.com/tracel-ai/burn-onnx/blob/main/DEVELOPMENT-GUIDE.md * Update Cargo.lock * Move pytorch-tests and safetensors-tests from burn-import to burn-store Migrate test directories from crates/burn-import/ to crates/burn-store/ and update all tests to use the new burn-store API. Changes: - Update Cargo.toml dependencies from burn-import to burn-store - Replace PyTorchFileRecorder/SafetensorsFileRecorder with PytorchStore/SafetensorsStore - Convert record-based loading to direct model loading via ModuleSnapshot::load_from() - Convert LoadArgs options to fluent builder pattern (.with_key_remapping(), etc.) - Fix model configurations to match actual PyTorch weight file dimensions - Add init() methods for models that previously only had new_with(record) All 37 pytorch tests and 1 safetensors test pass. * Fix test file paths after moving test directories to burn-store Update paths in burn-store/src tests to reference the new locations of pytorch-tests and safetensors-tests directories. * Add migration guide for burn-import to burn-store Document migration path from deprecated PyTorchFileRecorder and SafetensorsFileRecorder to the new PytorchStore and SafetensorsStore APIs. Cover API mapping, code examples, and common migration issues. Include details on printing LoadResult for debugging and helpful suggestions. * Remove burn-import crate (moved to burn-onnx repo) The burn-import crate has been moved to a separate repository: https://github.com/tracel-ai/burn-onnx/tree/main/crates/burn-import For loading PyTorch and SafeTensors model weights, use burn-store instead with PytorchStore and SafetensorsStore. * Replace pytorch/safetensors docs with unified model-weights page Consolidate PyTorch and SafeTensors model import documentation into a single comprehensive page covering burn-store usage. The new page covers: - All supported formats (Burnpack, SafeTensors, PyTorch) - Loading and saving workflows - Advanced features (filtering, remapping, partial loading, zero-copy) - API reference and troubleshooting * Simplify burn-store README and point to Burn Book * Consolidate model weights docs into saving-and-loading page Merge the model-weights documentation into the main saving-and-loading page, providing a single comprehensive guide for all model persistence operations using burn-store. - Remove separate import/model-weights.md page - Update saving-and-loading.md with full burn-store documentation - Remove unused SVG images from import folder - Update navigation in SUMMARY.md and import/README.md * Move ONNX import to standalone section, point to burn-onnx repo - Create new onnx-import.md as standalone top-level section - Update links to point to burn-onnx repo (github.com/tracel-ai/burn-onnx) - Remove import/ folder (no longer needed) - Streamline documentation with quick start focus * Restore full ONNX import documentation content Restored detailed content including: - Understanding ONNX section - Burn's ONNX Support advantages - ONNX compatibility and opset version guidance - Step-by-step guide with code examples - Advanced configuration options - Troubleshooting section - Examples and resources with links to burn-onnx repo * Update ONNX examples links to burn-onnx repo * Update ONNX-related example links in examples.md to burn-onnx repo * Fix Burnpack extension: .burnpack -> .bpk * Remove ONNX Tests README reference from docs * Improve table formatting in ONNX import and saving docs Reformats markdown tables in onnx-import.md and saving-and-loading.md for better readability and consistency. No content changes, only improved alignment and line breaks. * Remove onnx-ir and onnx-ir-derive crates Moved to https://github.com/tracel-ai/burn-onnx * Remove burn-onnx crate Moved to https://github.com/tracel-ai/burn-onnx * Remove ONNX references from workspace config and README - Remove burn-onnx entries from Cargo.toml workspace members/exclude - Remove RUSTSEC-2024-0437 (protobuf) from audit.toml - Update README ONNX section to point to burn-onnx repo - Update README model loading section to point to new docs - Remove ONNX example from README (moved to burn-onnx) * Clean up remaining ONNX references - Remove ONNX publish jobs from publish.yml - Update semver-checks exclude list - Update burn-book overview.md links - Update burn-book no-std.md links to burn-onnx repo - Remove ONNX proto license from NOTICES.md - Update .gitignore comment * Remove unused protobuf and rust-format dependencies * Remove onnx-tests example from contributor-book * Update Cargo.lock * Format long line in test for readability Split a long line in the should_fail_if_struct_field_is_missing test to improve code readability. No functional changes were made. * Restore Record-based saving/loading sections per PR feedback * Explain burn-store motivation in saving-and-loading docs * Remove em dash from burn-store intro * Format long line in test for readability
This commit is contained in:
committed by
GitHub
parent
1cda8e14f0
commit
933fdf4f69
@@ -8,7 +8,6 @@
|
||||
|
||||
[advisories]
|
||||
ignore = [
|
||||
"RUSTSEC-2024-0437", # Protobuf used in ONNX graph parsing.
|
||||
"RUSTSEC-2024-0436", # Paste used to generate macro, should be removed at some point.
|
||||
"RUSTSEC-2025-0119", # `number_prefix` used by `tokenizers`, only in the examples.
|
||||
"RUSTSEC-2025-0141", # `bincode` is no longer maintained.
|
||||
|
||||
43
.github/workflows/publish.yml
vendored
43
.github/workflows/publish.yml
vendored
@@ -386,49 +386,6 @@ jobs:
|
||||
secrets:
|
||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||
|
||||
publish-burn-onnx:
|
||||
needs:
|
||||
- publish-burn
|
||||
- publish-onnx-ir
|
||||
- publish-burn-store
|
||||
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
|
||||
with:
|
||||
crate: burn-onnx
|
||||
dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
|
||||
secrets:
|
||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||
|
||||
publish-burn-import:
|
||||
needs:
|
||||
- publish-burn
|
||||
- publish-burn-store
|
||||
- publish-burn-onnx
|
||||
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
|
||||
with:
|
||||
crate: burn-import
|
||||
dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
|
||||
secrets:
|
||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||
|
||||
publish-onnx-ir-derive:
|
||||
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
|
||||
with:
|
||||
crate: onnx-ir-derive
|
||||
dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
|
||||
secrets:
|
||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||
|
||||
publish-onnx-ir:
|
||||
needs:
|
||||
- publish-burn-tensor
|
||||
- publish-onnx-ir-derive
|
||||
uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
|
||||
with:
|
||||
crate: onnx-ir
|
||||
dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
|
||||
secrets:
|
||||
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
|
||||
|
||||
publish-burn-store:
|
||||
needs:
|
||||
- publish-burn-core
|
||||
|
||||
2
.github/workflows/semver-checks.yml
vendored
2
.github/workflows/semver-checks.yml
vendored
@@ -26,4 +26,4 @@ jobs:
|
||||
# publishes on crates.io with `default-features`
|
||||
feature-group: default-features
|
||||
# Exclude crates which are not published on crates.io
|
||||
exclude: burn-no-std-tests,onnx-tests,pytorch-tests
|
||||
exclude: burn-no-std-tests,pytorch-tests,safetensors-tests
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -10,7 +10,7 @@ target
|
||||
.fleet
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Generated IR and Burn Graph from ONNX
|
||||
# Build output directory
|
||||
out
|
||||
|
||||
# Virtual Environment of Python
|
||||
|
||||
300
Cargo.lock
generated
300
Cargo.lock
generated
@@ -463,7 +463,7 @@ dependencies = [
|
||||
"clang-sys",
|
||||
"itertools 0.13.0",
|
||||
"log",
|
||||
"prettyplease 0.2.37",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
@@ -610,7 +610,7 @@ checksum = "89ec27229c38ed0eb3c0feee3d2c1d6a4379ae44f418a29a658890e062d8f365"
|
||||
dependencies = [
|
||||
"darling 0.23.0",
|
||||
"ident_case",
|
||||
"prettyplease 0.2.37",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
@@ -952,23 +952,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-import"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-ndarray",
|
||||
"burn-onnx",
|
||||
"burn-store",
|
||||
"candle-core",
|
||||
"derive-new",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
"zip 7.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-ir"
|
||||
version = "0.21.0"
|
||||
@@ -1033,25 +1016,6 @@ dependencies = [
|
||||
"burn-store",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-onnx"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-ndarray",
|
||||
"burn-store",
|
||||
"derive-new",
|
||||
"insta",
|
||||
"log",
|
||||
"onnx-ir",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rust-format",
|
||||
"syn 2.0.114",
|
||||
"tracing-core",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "burn-optim"
|
||||
version = "0.21.0"
|
||||
@@ -1974,7 +1938,7 @@ dependencies = [
|
||||
"document-features",
|
||||
"mio",
|
||||
"parking_lot",
|
||||
"rustix 1.1.3",
|
||||
"rustix",
|
||||
"signal-hook 0.3.18",
|
||||
"signal-hook-mio",
|
||||
"winapi",
|
||||
@@ -2240,7 +2204,7 @@ dependencies = [
|
||||
"darling 0.21.3",
|
||||
"derive-new",
|
||||
"ident_case",
|
||||
"prettyplease 0.2.37",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.114",
|
||||
@@ -2811,12 +2775,6 @@ version = "1.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "abd57806937c9cc163efc8ea3910e00a62e2aeb0b8119f1793a978088f8f6b04"
|
||||
|
||||
[[package]]
|
||||
name = "diff"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@@ -3475,7 +3433,7 @@ version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4"
|
||||
dependencies = [
|
||||
"rustix 1.1.3",
|
||||
"rustix",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -4513,27 +4471,6 @@ dependencies = [
|
||||
"zune-jpeg 0.5.8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "image-classification-web"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-candle",
|
||||
"burn-onnx",
|
||||
"burn-store",
|
||||
"console_error_panic_hook",
|
||||
"getrandom 0.3.4",
|
||||
"js-sys",
|
||||
"log",
|
||||
"serde",
|
||||
"serde-wasm-bindgen",
|
||||
"serde_json",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-logger",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "image-webp"
|
||||
version = "0.2.4"
|
||||
@@ -4555,7 +4492,7 @@ name = "import-model-weights"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-import",
|
||||
"burn-store",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4614,18 +4551,6 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "insta"
|
||||
version = "1.46.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b66886d14d18d420ab5052cbff544fc5d34d0b2cdd35eb5976aaa10a4a472e5"
|
||||
dependencies = [
|
||||
"console 0.15.11",
|
||||
"once_cell",
|
||||
"similar",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instability"
|
||||
version = "0.3.11"
|
||||
@@ -4895,12 +4820,6 @@ dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
@@ -5803,59 +5722,6 @@ dependencies = [
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnx-inference"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-onnx",
|
||||
"burn-store",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnx-ir"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn-tensor",
|
||||
"bytemuck",
|
||||
"bytes",
|
||||
"derive-new",
|
||||
"divan",
|
||||
"half",
|
||||
"log",
|
||||
"memmap2",
|
||||
"onnx-ir-derive",
|
||||
"pretty_assertions",
|
||||
"protobuf",
|
||||
"protobuf-codegen",
|
||||
"regex",
|
||||
"rstest",
|
||||
"serde",
|
||||
"strum",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnx-ir-derive"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnx-tests"
|
||||
version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-onnx",
|
||||
"burn-store",
|
||||
"float-cmp",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas-build"
|
||||
version = "0.10.14"
|
||||
@@ -6884,26 +6750,6 @@ version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa"
|
||||
|
||||
[[package]]
|
||||
name = "pretty_assertions"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
|
||||
dependencies = [
|
||||
"diff",
|
||||
"yansi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.1.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.37"
|
||||
@@ -6974,58 +6820,6 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "3.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d65a1d4ddae7d8b5de68153b48f6aa3bba8cb002b243dbdbc55a5afbc98f99f4"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"once_cell",
|
||||
"protobuf-support",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf-codegen"
|
||||
version = "3.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d3976825c0014bbd2f3b34f0001876604fe87e0c86cd8fa54251530f1544ace"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"once_cell",
|
||||
"protobuf",
|
||||
"protobuf-parse",
|
||||
"regex",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf-parse"
|
||||
version = "3.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4aeaa1f2460f1d348eeaeed86aea999ce98c1bded6f089ff8514c9d9dbdc973"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"indexmap",
|
||||
"log",
|
||||
"protobuf",
|
||||
"protobuf-support",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf-support"
|
||||
version = "3.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e36c2f31e0a47f9280fb347ef5e461ffcd2c52dd520d8e216b52f93b0b0d7d6"
|
||||
dependencies = [
|
||||
"thiserror 1.0.69",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psm"
|
||||
version = "0.1.28"
|
||||
@@ -7077,8 +6871,8 @@ version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-autodiff",
|
||||
"burn-import",
|
||||
"burn-ndarray",
|
||||
"burn-store",
|
||||
"float-cmp",
|
||||
"serde",
|
||||
]
|
||||
@@ -7708,17 +7502,6 @@ dependencies = [
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-format"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60e7c00b6c3bf5e38a880eec01d7e829d12ca682079f8238a464def3c4b31627"
|
||||
dependencies = [
|
||||
"prettyplease 0.1.25",
|
||||
"proc-macro2",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.26"
|
||||
@@ -7746,19 +7529,6 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.3"
|
||||
@@ -7768,7 +7538,7 @@ dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.11.0",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
@@ -7869,8 +7639,8 @@ version = "0.21.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-autodiff",
|
||||
"burn-import",
|
||||
"burn-ndarray",
|
||||
"burn-store",
|
||||
"float-cmp",
|
||||
"serde",
|
||||
]
|
||||
@@ -7990,17 +7760,6 @@ dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde-wasm-bindgen"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_bytes"
|
||||
version = "0.11.19"
|
||||
@@ -8231,12 +7990,6 @@ version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
|
||||
|
||||
[[package]]
|
||||
name = "similar"
|
||||
version = "2.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
|
||||
|
||||
[[package]]
|
||||
name = "simple-regression"
|
||||
version = "0.21.0"
|
||||
@@ -8579,7 +8332,7 @@ dependencies = [
|
||||
"fastrand",
|
||||
"getrandom 0.3.4",
|
||||
"once_cell",
|
||||
"rustix 1.1.3",
|
||||
"rustix",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
@@ -8598,7 +8351,7 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60b8cb979cb11c32ce1603f8137b22262a9d131aaa5c37b5678025f22b8becd0"
|
||||
dependencies = [
|
||||
"rustix 1.1.3",
|
||||
"rustix",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
@@ -9818,17 +9571,6 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-logger"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "074649a66bb306c8f2068c9016395fa65d8e08d2affcbf95acf3c24c3ab19718"
|
||||
dependencies = [
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-streams"
|
||||
version = "0.4.2"
|
||||
@@ -10125,18 +9867,6 @@ dependencies = [
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "4.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
|
||||
dependencies = [
|
||||
"either",
|
||||
"home",
|
||||
"once_cell",
|
||||
"rustix 0.38.44",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "widestring"
|
||||
version = "1.2.1"
|
||||
@@ -10489,7 +10219,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rustix 1.1.3",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -10520,12 +10250,6 @@ version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a5a4b21e1a62b67a2970e6831bc091d7b87e119e7f9791aef9702e3bef04448"
|
||||
|
||||
[[package]]
|
||||
name = "yansi"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.7.5"
|
||||
|
||||
@@ -6,9 +6,8 @@ resolver = "2"
|
||||
|
||||
members = [
|
||||
"crates/*",
|
||||
"crates/burn-import/pytorch-tests",
|
||||
"crates/burn-import/safetensors-tests",
|
||||
"crates/burn-onnx/onnx-tests",
|
||||
"crates/burn-store/pytorch-tests",
|
||||
"crates/burn-store/safetensors-tests",
|
||||
"crates/burn-collective/multinode-tests",
|
||||
"examples/*",
|
||||
"xtask",
|
||||
@@ -17,7 +16,6 @@ members = [
|
||||
exclude = [
|
||||
"examples/notebook",
|
||||
"examples/raspberry-pi-pico",
|
||||
"crates/burn-onnx/model-checks/*", # exclude model checking crates from workspace
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -75,8 +73,6 @@ planus = { version = "=1.1" }
|
||||
polars = { version = "0.51.0", features = ["lazy"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
proc-macro2 = "1.0.106"
|
||||
protobuf = "3.7.2"
|
||||
protobuf-codegen = "3.7.2"
|
||||
quote = "1.0.42"
|
||||
r2d2 = "0.8.10"
|
||||
r2d2_sqlite = "0.31.0"
|
||||
@@ -91,7 +87,6 @@ reqwest = { version = "0.12.23", default-features = false, features = [
|
||||
rmp-serde = { version = "1.3.1", default-features = false }
|
||||
rstest = "0.26.1"
|
||||
rusqlite = "0.37.0"
|
||||
rust-format = "0.3.4"
|
||||
sanitize-filename = "0.6.0"
|
||||
serde_bytes = { version = "0.11.18", default-features = false, features = [
|
||||
"alloc",
|
||||
|
||||
208
NOTICES.md
208
NOTICES.md
@@ -38,214 +38,6 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
## ONNX
|
||||
|
||||
**Source**: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
|
||||
|
||||
License: Apache License 2.0
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
## wgpu
|
||||
|
||||
**Source:** https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
|
||||
|
||||
21
README.md
21
README.md
@@ -251,16 +251,17 @@ ONNX Support 🐫
|
||||
</summary>
|
||||
<br />
|
||||
|
||||
Burn supports importing ONNX (Open Neural Network Exchange) models, allowing you to easily port
|
||||
models from TensorFlow or PyTorch to Burn. The ONNX model is converted into Rust code that uses
|
||||
Burn's native APIs, enabling the imported model to run on any Burn backend (CPU, GPU, WebAssembly)
|
||||
and benefit from all of Burn's optimizations like automatic kernel fusion.
|
||||
Burn supports importing ONNX (Open Neural Network Exchange) models through the
|
||||
[burn-onnx](https://github.com/tracel-ai/burn-onnx) crate, allowing you to easily port models from
|
||||
TensorFlow or PyTorch to Burn. The ONNX model is converted into Rust code that uses Burn's native
|
||||
APIs, enabling the imported model to run on any Burn backend (CPU, GPU, WebAssembly) and benefit
|
||||
from all of Burn's optimizations like automatic kernel fusion.
|
||||
|
||||
Our ONNX support is further described in
|
||||
[this section of the Burn Book 🔥](https://burn.dev/books/burn/import/onnx-model.html).
|
||||
[this section of the Burn Book 🔥](https://burn.dev/books/burn/onnx-import.html).
|
||||
|
||||
> **Note**: This crate is in active development and currently supports a
|
||||
> [limited set of ONNX operators](./crates/burn-onnx/SUPPORTED-ONNX-OPS.md).
|
||||
> [limited set of ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).
|
||||
|
||||
</details>
|
||||
|
||||
@@ -274,10 +275,8 @@ You can load weights from PyTorch or Safetensors formats directly into your Burn
|
||||
This makes it easy to reuse existing models while benefiting from Burn's performance and deployment
|
||||
features.
|
||||
|
||||
Learn more:
|
||||
|
||||
- [Import pre-trained PyTorch models into Burn](https://burn.dev/books/burn/import/pytorch-model.html)
|
||||
- [Load models from Safetensors format](https://burn.dev/books/burn/import/safetensors-model.html)
|
||||
Learn more in the [Saving & Loading Models](https://burn.dev/books/burn/saving-and-loading.html)
|
||||
section of the Burn Book.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -423,8 +422,6 @@ Additional examples:
|
||||
`Learner` configured to log metrics and keep training checkpoints.
|
||||
- [Named Tensor](./examples/named-tensor) : Performs operations with the experimental `NamedTensor`
|
||||
feature.
|
||||
- [ONNX Import Inference](./examples/onnx-inference) : Imports an ONNX model pre-trained on MNIST to
|
||||
perform inference on a sample image with Burn.
|
||||
- [PyTorch Import Inference](./examples/import-model-weights) : Imports a PyTorch model pre-trained
|
||||
on MNIST to perform inference on a sample image with Burn.
|
||||
- [Text Classification](./examples/text-classification) : Trains a text classification transformer
|
||||
|
||||
@@ -27,10 +27,7 @@
|
||||
- [Distributed Computing](./performance/distributed-computing.md)
|
||||
- [Custom Training Loop](./custom-training-loop.md)
|
||||
- [Saving & Loading Models](./saving-and-loading.md)
|
||||
- [Importing Models](./import/README.md)
|
||||
- [ONNX Model](./import/onnx-model.md)
|
||||
- [PyTorch Model](./import/pytorch-model.md)
|
||||
- [Safetensors Model](./import/safetensors-model.md)
|
||||
- [ONNX Import](./onnx-import.md)
|
||||
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
|
||||
- [Advanced](./advanced/README.md)
|
||||
- [Backend Extension](./advanced/backend-extension/README.md)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# No Standard Library
|
||||
|
||||
In this section, you will learn how to run an onnx inference model on an embedded system, with no standard library support on a Raspberry Pi Pico 2. This should be universally applicable to other platforms. All the code can be found under the
|
||||
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/raspberry-pi-pico).
|
||||
In this section, you will learn how to run an ONNX inference model on an embedded system, with no
|
||||
standard library support on a Raspberry Pi Pico 2. This should be universally applicable to other
|
||||
platforms. All the code can be found in the
|
||||
[burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico).
|
||||
|
||||
## Step-by-Step Guide
|
||||
|
||||
@@ -31,7 +33,7 @@ burn-onnx = { version = "0.21" } # Used to auto generate the rust code to import
|
||||
```
|
||||
|
||||
### Import the Model
|
||||
Follow the directions to [import models](../import/README.md).
|
||||
Follow the directions in [ONNX Import](../onnx-import.md).
|
||||
|
||||
Use the following ModelGen config
|
||||
```rs
|
||||
|
||||
@@ -77,11 +77,11 @@ The following additional examples are currently available if you want to check t
|
||||
| [Regression](https://github.com/tracel-ai/burn/tree/main/examples/simple-regression) | Trains a simple MLP on the California Housing dataset to predict the median house value for a district. |
|
||||
| [Custom Image Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-image-dataset) | Trains a simple CNN on custom image dataset following a simple folder structure. |
|
||||
| [Custom Renderer](https://github.com/tracel-ai/burn/tree/main/examples/custom-renderer) | Implements a custom renderer to display the [`Learner`](./building-blocks/learner.md) progress. |
|
||||
| [Image Classification Web](https://github.com/tracel-ai/burn/tree/main/examples/image-classification-web) | Image classification web browser demo using Burn, WGPU and WebAssembly. |
|
||||
| [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) | Image classification web browser demo using Burn, WGPU and WebAssembly. |
|
||||
| [MNIST Inference on Web](https://github.com/tracel-ai/burn/tree/main/examples/mnist-inference-web) | An interactive MNIST inference demo in the browser. The demo is available [online](https://burn.dev/demo/). |
|
||||
| [MNIST Training](https://github.com/tracel-ai/burn/tree/main/examples/mnist) | Demonstrates how to train a custom [`Module`](./building-blocks/module.md) (MLP) with the [`Learner`](./building-blocks/learner.md) configured to log metrics and keep training checkpoints. |
|
||||
| [Named Tensor](https://github.com/tracel-ai/burn/tree/main/examples/named-tensor) | Performs operations with the experimental `NamedTensor` feature. |
|
||||
| [ONNX Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference) | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn. |
|
||||
| [ONNX Import Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference) | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn. |
|
||||
| [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/import-model-weights) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. |
|
||||
| [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
|
||||
| [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. |
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Importing Models
|
||||
|
||||
Burn supports importing models from other frameworks and file formats, enabling you to use pre-trained weights in your Burn applications.
|
||||
|
||||
## Supported Formats
|
||||
|
||||
Burn currently supports three primary model import formats:
|
||||
|
||||
| Format | Description | Use Case |
|
||||
|--------|-------------|----------|
|
||||
| [**ONNX**](./onnx-model.md) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
|
||||
| [**PyTorch**](./pytorch-model.md) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
|
||||
| [**Safetensors**](./safetensors-model.md) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |
|
||||
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" id="export" class="canvas" preserveAspectRatio="xMidYMid meet" style="" width="112" height="124"><rect id="background" fill="#fff" pointer-events="all" width="112" height="124"/><g id="origin" transform="translate(5.077343750000001, 5.077343750000001) scale(1)"><g id="clusters" class="clusters"/><g id="edge-paths" class="edge-paths"><defs><marker id="arrowhead" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(0, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-select" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(238, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-hover" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker></defs></g><g id="edge-labels" class="edge-labels"/><g id="nodes" class="nodes"><g id="node-name-conv1" class="node graph-node" transform="translate(0,60)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv1</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v27a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-27a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: rgb(255, 255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><g class="node-attribute"><text xml:space="preserve" x="6" y="26" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2]</title><tspan style="font-weight: bold;">bias</tspan><tspan>〈2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v43a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-43a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><g id="node-name-conv2" class="node graph-node" transform="translate(0,0)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv2</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v14a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-14a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: rgb(255, 255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v30a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-30a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g></g></g></svg>
|
||||
|
Before Width: | Height: | Size: 4.8 KiB |
@@ -1 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" id="export" class="canvas" preserveAspectRatio="xMidYMid meet" style="" width="112" height="124"><rect id="background" fill="#fff" pointer-events="all" width="112" height="124"/><g id="origin" transform="translate(5.077343750000001, 5.077343750000001) scale(1)"><g id="clusters" class="clusters"/><g id="edge-paths" class="edge-paths"><defs><marker id="arrowhead" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(0, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-select" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(238, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-hover" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker></defs></g><g id="edge-labels" class="edge-labels"/><g id="nodes" class="nodes"><g id="node-name-conv.conv1" class="node graph-node" transform="translate(0,60)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv.conv1</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v27a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-27a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: rgb(255, 255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><g class="node-attribute"><text xml:space="preserve" x="6" y="26" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2]</title><tspan style="font-weight: bold;">bias</tspan><tspan>〈2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v43a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-43a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><g id="node-name-conv.conv2" class="node graph-node" transform="translate(0,0)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv.conv2</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v14a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-14a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: rgb(255, 255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", Ubuntu, "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v30a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-30a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g></g></g></svg>
|
||||
|
Before Width: | Height: | Size: 4.8 KiB |
@@ -1,344 +0,0 @@
|
||||
# PyTorch Model
|
||||
|
||||
## Introduction
|
||||
|
||||
Burn supports importing model weights from PyTorch, whether you've trained your model in PyTorch or
|
||||
want to use a pre-trained model. Burn supports importing PyTorch model weights with `.pt` and
|
||||
`.safetensors` file extensions. Compared to ONNX models, these files only contain the weights of the
|
||||
model, so you will need to reconstruct the model architecture in Burn.
|
||||
|
||||
This guide demonstrates the complete workflow for exporting models from PyTorch and importing them
|
||||
into Burn. You can also refer to this
|
||||
[Transitioning From PyTorch to Burn](https://dev.to/laggui/transitioning-from-pytorch-to-burn-45m)
|
||||
tutorial for importing a more complex model.
|
||||
|
||||
## Exporting Models to PyTorch Format
|
||||
|
||||
To export a PyTorch model correctly, you need to save only the model weights (state_dict) using the
|
||||
`torch.save` function, not the entire model.
|
||||
|
||||
### Example: Exporting a PyTorch Model
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(2, 2, (2,2))
|
||||
self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Initialize model and ensure it's on CPU
|
||||
model = Net().to(torch.device("cpu"))
|
||||
|
||||
# Extract model weights dictionary
|
||||
model_weights = model.state_dict()
|
||||
|
||||
# Save only the weights, not the entire model
|
||||
torch.save(model_weights, "conv2d.pt")
|
||||
```
|
||||
|
||||
If you accidentally save the entire model instead of just the weights, you may encounter errors
|
||||
during import like:
|
||||
|
||||
```
|
||||
Failed to decode foobar: DeserializeError("Serde error: other error:
|
||||
Missing source values for the 'foo1' field of type 'BarRecordItem'.
|
||||
Please verify the source data and ensure the field name is correct")
|
||||
```
|
||||
|
||||
### Verifying the Export
|
||||
|
||||
You can verify your exported model by viewing the `.pt` file in
|
||||
[Netron](https://github.com/lutzroeder/netron), a neural network visualization tool. A properly
|
||||
exported weights file will show a flat structure of tensors, while an incorrectly exported file will
|
||||
display nested blocks representing the entire model architecture.
|
||||
|
||||
When viewing the exported model in Netron, you should see something like this:
|
||||
|
||||

|
||||
|
||||
## Importing PyTorch Models into Burn
|
||||
|
||||
Importing a PyTorch model into Burn involves two main steps:
|
||||
|
||||
1. Defining the model architecture in Burn
|
||||
2. Loading the weights from the exported PyTorch model
|
||||
|
||||
### Step 1: Define the Model in Burn
|
||||
|
||||
First, you need to create a Burn model that matches the architecture of the model you exported:
|
||||
|
||||
```rust
|
||||
use burn::{
|
||||
nn::conv::{Conv2d, Conv2dConfig},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
conv2: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = Conv2dConfig::new([2, 2], [2, 2])
|
||||
.init(device);
|
||||
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
Self { conv1, conv2 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv1.forward(x);
|
||||
self.conv2.forward(x)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Load the Weights
|
||||
|
||||
You have two options for loading the weights:
|
||||
|
||||
#### Option A: Load Dynamically at Runtime
|
||||
|
||||
This approach loads the PyTorch file directly at runtime, requiring the `burn-import` dependency:
|
||||
|
||||
```rust
|
||||
use crate::model;
|
||||
use burn::record::{FullPrecisionSettings, Recorder};
|
||||
use burn_import::pytorch::PyTorchFileRecorder;
|
||||
|
||||
type Backend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
fn main() {
|
||||
let device = Default::default();
|
||||
|
||||
// Load weights from PyTorch file
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("./conv2d.pt".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
// Initialize model and load weights
|
||||
let model = model::Net::<Backend>::init(&device).load_record(record);
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Pre-convert to Burn's Binary Format
|
||||
|
||||
This approach converts the PyTorch file to Burn's optimized binary format during build time,
|
||||
removing the runtime dependency on `burn-import`:
|
||||
|
||||
```rust
|
||||
// This code would go in build.rs or a separate tool
|
||||
|
||||
use crate::model;
|
||||
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
|
||||
use burn_import::pytorch::PyTorchFileRecorder;
|
||||
|
||||
type Backend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
fn convert_model() {
|
||||
let device = Default::default();
|
||||
|
||||
// Load from PyTorch
|
||||
let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
|
||||
let record = recorder
|
||||
.load("./conv2d.pt".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
// Save to Burn's binary format
|
||||
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
|
||||
recorder
|
||||
.record(record, "model.mpk".into())
|
||||
.expect("Failed to save model record");
|
||||
}
|
||||
|
||||
// In your application code
|
||||
fn load_model() -> Net<Backend> {
|
||||
let device = Default::default();
|
||||
|
||||
// Load from Burn's binary format
|
||||
let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("./model.mpk".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
Net::<Backend>::init(&device).load_record(record)
|
||||
}
|
||||
```
|
||||
|
||||
> **Note**: For examples of pre-converting models, see the `examples/import-model-weights` directory
|
||||
> in the Burn repository.
|
||||
|
||||
## Extract Configuration
|
||||
|
||||
In some cases, models may require additional configuration settings, which are often included in a
|
||||
`.pt` file during export. The `config_from_file` function from the `burn-import` cargo package
|
||||
allows for the extraction of these configurations directly from the `.pt` file.
|
||||
|
||||
```rust
|
||||
use std::collections::HashMap;
|
||||
|
||||
use burn::config::Config;
|
||||
use burn_import::pytorch::config_from_file;
|
||||
|
||||
#[derive(Debug, Config)]
|
||||
struct NetConfig {
|
||||
n_head: usize,
|
||||
n_layer: usize,
|
||||
d_model: usize,
|
||||
some_float: f64,
|
||||
some_int: i32,
|
||||
some_bool: bool,
|
||||
some_str: String,
|
||||
some_list_int: Vec<i32>,
|
||||
some_list_str: Vec<String>,
|
||||
some_list_float: Vec<f64>,
|
||||
some_dict: HashMap<String, String>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let path = "weights_with_config.pt";
|
||||
let top_level_key = Some("my_config");
|
||||
let config: NetConfig = config_from_file(path, top_level_key).unwrap();
|
||||
println!("{:#?}", config);
|
||||
|
||||
// After extracting, it's recommended you save it as a json file.
|
||||
config.save("my_config.json").unwrap();
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting and Advanced Features
|
||||
|
||||
### Key Remapping for Different Model Architectures
|
||||
|
||||
If your Burn model structure doesn't match the parameter names in the PyTorch file, you can remap
|
||||
keys using regular expressions:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
||||
// Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||
.with_key_remap("conv\\.(.*)", "$1");
|
||||
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
let model = Net::<Backend>::init(&device).load_record(record);
|
||||
```
|
||||
|
||||
### Debugging with Key Inspection
|
||||
|
||||
To help with troubleshooting import issues, you can enable debugging to print the original and
|
||||
remapped keys:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
||||
// Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||
.with_key_remap("conv\\.(.*)", "$1")
|
||||
.with_debug_print(); // Print the keys and remapped keys
|
||||
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
let model = Net::<Backend>::init(&device).load_record(record);
|
||||
```
|
||||
|
||||
Here is an example of the output:
|
||||
|
||||
```text
|
||||
Debug information of keys and tensor shapes:
|
||||
---
|
||||
Original Key: conv.conv1.bias
|
||||
Remapped Key: conv1.bias
|
||||
Shape: [2]
|
||||
Dtype: F32
|
||||
---
|
||||
Original Key: conv.conv1.weight
|
||||
Remapped Key: conv1.weight
|
||||
Shape: [2, 2, 2, 2]
|
||||
Dtype: F32
|
||||
---
|
||||
Original Key: conv.conv2.weight
|
||||
Remapped Key: conv2.weight
|
||||
Shape: [2, 2, 2, 2]
|
||||
Dtype: F32
|
||||
---
|
||||
```
|
||||
|
||||
### Automatic Handling of Non-Contiguous Indices
|
||||
|
||||
The PyTorchFileRecorder automatically handles non-contiguous indices in model layer names. For
|
||||
example, if the source model contains indices with gaps:
|
||||
|
||||
```
|
||||
"model.layers.0.weight"
|
||||
"model.layers.0.bias"
|
||||
"model.layers.2.weight" // Note the gap (no index 1)
|
||||
"model.layers.2.bias"
|
||||
"model.layers.4.weight"
|
||||
"model.layers.4.bias"
|
||||
```
|
||||
|
||||
The recorder will automatically reindex these to be contiguous while preserving their order:
|
||||
|
||||
```
|
||||
"model.layers.0.weight"
|
||||
"model.layers.0.bias"
|
||||
"model.layers.1.weight" // Reindexed from 2
|
||||
"model.layers.1.bias"
|
||||
"model.layers.2.weight" // Reindexed from 4
|
||||
"model.layers.2.bias"
|
||||
```
|
||||
|
||||
### Partial Model Loading
|
||||
|
||||
You can selectively load weights into a partial model, which is useful for:
|
||||
|
||||
- Loading only the encoder from an encoder-decoder architecture
|
||||
- Fine-tuning specific layers while initializing others randomly
|
||||
- Creating hybrid models combining parts from different sources
|
||||
|
||||
### Specifying the Top-Level Key for state_dict
|
||||
|
||||
Sometimes the
|
||||
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
|
||||
is nested under a top-level key along with other metadata. In this case, you can specify the
|
||||
top-level key in `LoadArgs`:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
let load_args = LoadArgs::new("tiny.en.pt".into())
|
||||
.with_top_level_key("my_state_dict");
|
||||
|
||||
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully")
|
||||
```
|
||||
|
||||
### Support for Enum Modules
|
||||
|
||||
The PyTorchFileRecorder supports models containing enum modules with new-type variants. The enum
|
||||
variant is automatically selected based on the enum variant type, allowing for flexible model
|
||||
architectures.
|
||||
|
||||
## Current Known Issues
|
||||
|
||||
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).
|
||||
@@ -1,278 +0,0 @@
|
||||
# Safetensors Model
|
||||
|
||||
## Introduction
|
||||
|
||||
Burn supports importing model weights from the Safetensors format, a secure and efficient
|
||||
alternative to pickle-based formats. Whether you've trained your model in PyTorch or you want to use
|
||||
a pre-trained model that provides weights in Safetensors format, you can easily import them into
|
||||
Burn.
|
||||
|
||||
This guide demonstrates the complete workflow for exporting models to Safetensors format and
|
||||
importing them into Burn.
|
||||
|
||||
## Exporting Models to Safetensors Format
|
||||
|
||||
To export a PyTorch model to Safetensors format, you'll need the `safetensors` Python library. This
|
||||
library provides a simple API for saving model weights in the Safetensors format.
|
||||
|
||||
### Example: Exporting a PyTorch Model
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import save_file
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(2, 2, (2,2))
|
||||
self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Initialize model and ensure it's on CPU
|
||||
model = Net().to(torch.device("cpu"))
|
||||
|
||||
# Extract model weights dictionary
|
||||
model_weights = model.state_dict()
|
||||
|
||||
# Save to Safetensors format
|
||||
save_file(model_weights, "conv2d.safetensors")
|
||||
```
|
||||
|
||||
### Verifying the Export
|
||||
|
||||
You can verify your exported model by viewing the `.safetensors` file in
|
||||
[Netron](https://github.com/lutzroeder/netron), a neural network visualization tool. A correctly
|
||||
exported file will display a flat structure of tensors, similar to a PyTorch `.pt` weights file.
|
||||
|
||||
## Importing Safetensors Models into Burn
|
||||
|
||||
Importing a Safetensors model into Burn involves two main steps:
|
||||
|
||||
1. Defining the model architecture in Burn
|
||||
2. Loading the weights from the Safetensors file
|
||||
|
||||
### Step 1: Define the Model in Burn
|
||||
|
||||
First, you need to create a Burn model that matches the architecture of the model you exported:
|
||||
|
||||
```rust
|
||||
use burn::{
|
||||
nn::conv::{Conv2d, Conv2dConfig},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
conv2: Conv2d<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Net<B> {
|
||||
/// Create a new model.
|
||||
pub fn init(device: &B::Device) -> Self {
|
||||
let conv1 = Conv2dConfig::new([2, 2], [2, 2])
|
||||
.init(device);
|
||||
let conv2 = Conv2dConfig::new([2, 2], [2, 2])
|
||||
.with_bias(false)
|
||||
.init(device);
|
||||
Self { conv1, conv2 }
|
||||
}
|
||||
|
||||
/// Forward pass of the model.
|
||||
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let x = self.conv1.forward(x);
|
||||
self.conv2.forward(x)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Load the Weights
|
||||
|
||||
You have two options for loading the weights:
|
||||
|
||||
#### Option A: Load Dynamically at Runtime
|
||||
|
||||
This approach loads the Safetensors file directly at runtime, requiring the `burn-import`
|
||||
dependency:
|
||||
|
||||
```rust
|
||||
use crate::model;
|
||||
use burn::record::{FullPrecisionSettings, Recorder};
|
||||
use burn_import::safetensors::SafetensorsFileRecorder;
|
||||
|
||||
type Backend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
fn main() {
|
||||
let device = Default::default();
|
||||
|
||||
// Load weights from Safetensors file
|
||||
let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("./conv2d.safetensors".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
// Initialize model and load weights
|
||||
let model = model::Net::<Backend>::init(&device).load_record(record);
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Pre-convert to Burn's Binary Format
|
||||
|
||||
This approach converts the Safetensors file to Burn's optimized binary format during build time,
|
||||
removing the runtime dependency on `burn-import`:
|
||||
|
||||
```rust
|
||||
// This code would go in build.rs or a separate tool
|
||||
|
||||
use crate::model;
|
||||
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
|
||||
use burn_import::safetensors::SafetensorsFileRecorder;
|
||||
|
||||
type Backend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
fn convert_model() {
|
||||
let device = Default::default();
|
||||
|
||||
// Load from Safetensors
|
||||
let recorder = SafetensorsFileRecorder::<FullPrecisionSettings>::default();
|
||||
let record = recorder
|
||||
.load("./conv2d.safetensors".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
// Save to Burn's binary format
|
||||
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
|
||||
recorder
|
||||
.record(record, "model.mpk".into())
|
||||
.expect("Failed to save model record");
|
||||
}
|
||||
|
||||
// In your application code
|
||||
fn load_model() -> Net<Backend> {
|
||||
let device = Default::default();
|
||||
|
||||
// Load from Burn's binary format
|
||||
let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("./model.mpk".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
Net::<Backend>::init(&device).load_record(record)
|
||||
}
|
||||
```
|
||||
|
||||
> **Note**: For examples of pre-converting models, see the `examples/import-model-weights` directory
|
||||
> in the Burn repository.
|
||||
|
||||
## Advanced Configuration Options
|
||||
|
||||
### Framework-Specific Adapters
|
||||
|
||||
When importing Safetensors models, you can specify an adapter type to handle framework-specific
|
||||
tensor transformations. This is crucial when importing models from different ML frameworks, as
|
||||
tensor layouts and naming conventions can vary:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
|
||||
// Create load arguments with framework-specific adapter
|
||||
let load_args = LoadArgs::new("model.safetensors".into())
|
||||
.with_adapter_type(AdapterType::PyTorch); // Default adapter
|
||||
|
||||
// Load with the specified adapter
|
||||
let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
```
|
||||
|
||||
#### Available Adapter Types
|
||||
|
||||
| Adapter Type | Description |
|
||||
| --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **PyTorch** (default) | Automatically applies PyTorch-specific transformations:<br>- Transposes weights for linear layers<br>- Renames normalization parameters (weight→gamma, bias→beta) |
|
||||
| **NoAdapter** | Loads tensors directly without any transformations<br>- Useful when importing from frameworks that already match Burn's tensor layout |
|
||||
|
||||
## Troubleshooting and Advanced Features
|
||||
|
||||
### Key Remapping for Different Model Architectures
|
||||
|
||||
If your Burn model structure doesn't match the parameter names in the Safetensors file, you can
|
||||
remap keys using regular expressions:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
|
||||
// Create load arguments with key remapping
|
||||
let load_args = LoadArgs::new("model.safetensors".into())
|
||||
// Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
|
||||
.with_key_remap("conv\\.(.*)", "$1");
|
||||
|
||||
let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
|
||||
let model = Net::<Backend>::init(&device).load_record(record);
|
||||
```
|
||||
|
||||
### Debugging with Key Inspection
|
||||
|
||||
To help with troubleshooting import issues, you can enable debugging to print the original and
|
||||
remapped keys:
|
||||
|
||||
```rust
|
||||
let device = Default::default();
|
||||
|
||||
// Enable debug printing of keys
|
||||
let load_args = LoadArgs::new("model.safetensors".into())
|
||||
.with_key_remap("conv\\.(.*)", "$1")
|
||||
.with_debug_print(); // Print original and remapped keys
|
||||
|
||||
let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
```
|
||||
|
||||
### Automatic Handling of Non-Contiguous Indices
|
||||
|
||||
The SafetensorsFileRecorder automatically handles non-contiguous indices in model layer names. For
|
||||
example, if the source model contains indices with gaps:
|
||||
|
||||
```
|
||||
"model.layers.0.weight"
|
||||
"model.layers.0.bias"
|
||||
"model.layers.2.weight" // Note the gap (no index 1)
|
||||
"model.layers.2.bias"
|
||||
"model.layers.4.weight"
|
||||
"model.layers.4.bias"
|
||||
```
|
||||
|
||||
The recorder will automatically reindex these to be contiguous while preserving their order:
|
||||
|
||||
```
|
||||
"model.layers.0.weight"
|
||||
"model.layers.0.bias"
|
||||
"model.layers.1.weight" // Reindexed from 2
|
||||
"model.layers.1.bias"
|
||||
"model.layers.2.weight" // Reindexed from 4
|
||||
"model.layers.2.bias"
|
||||
```
|
||||
|
||||
### Partial Model Loading
|
||||
|
||||
You can selectively load weights into a partial model, which is useful for:
|
||||
|
||||
- Loading only the encoder from an encoder-decoder architecture
|
||||
- Fine-tuning specific layers while initializing others randomly
|
||||
- Creating hybrid models combining parts from different sources
|
||||
|
||||
### Support for Enum Modules
|
||||
|
||||
The SafetensorsFileRecorder supports models containing enum modules with new-type variants. The enum
|
||||
variant is automatically selected based on the enum variant type, allowing for flexible model
|
||||
architectures.
|
||||
@@ -1,12 +1,11 @@
|
||||
# Importing ONNX Models in Burn
|
||||
# ONNX Import
|
||||
|
||||
## Introduction
|
||||
|
||||
As deep learning evolves, interoperability between frameworks becomes crucial. Burn, a modern deep
|
||||
learning framework in Rust, provides robust support for importing models from other popular
|
||||
frameworks. This section focuses on importing
|
||||
[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
|
||||
enabling you to leverage pre-trained models in your Rust-based deep learning projects.
|
||||
As deep learning evolves, interoperability between frameworks becomes crucial. Burn provides robust
|
||||
support for importing [ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html)
|
||||
models through the [`burn-onnx`](https://github.com/tracel-ai/burn-onnx) crate, enabling you to
|
||||
leverage pre-trained models in your Rust-based deep learning projects.
|
||||
|
||||
## Why Import Models?
|
||||
|
||||
@@ -61,7 +60,7 @@ There are two simple ways to upgrade your ONNX models to the recommended opset v
|
||||
Option 1: Use the provided utility script:
|
||||
|
||||
```
|
||||
uv run --script https://raw.githubusercontent.com/tracel-ai/burn/refs/heads/main/crates/burn-onnx/onnx_opset_upgrade.py
|
||||
uv run --script https://raw.githubusercontent.com/tracel-ai/burn-onnx/refs/heads/main/onnx_opset_upgrade.py
|
||||
```
|
||||
|
||||
Option 2: Use a custom Python script:
|
||||
@@ -94,19 +93,16 @@ First, add the required dependencies to your `Cargo.toml`:
|
||||
```toml
|
||||
[dependencies]
|
||||
burn = { version = "~0.21", features = ["ndarray"] }
|
||||
burn-store = { version = "~0.21", features = ["burnpack"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = "~0.21"
|
||||
```
|
||||
|
||||
The `burn-store` crate with the `burnpack` feature is required to load model weights at runtime.
|
||||
|
||||
### Step 2: Update `build.rs`
|
||||
|
||||
In your `build.rs` file:
|
||||
|
||||
```rust
|
||||
```rust, ignore
|
||||
use burn_onnx::ModelGen;
|
||||
|
||||
fn main() {
|
||||
@@ -117,13 +113,13 @@ fn main() {
|
||||
}
|
||||
```
|
||||
|
||||
This generates Rust code and a `.burnpack` weights file from your ONNX model during the build process.
|
||||
This generates Rust code and a `.bpk` weights file from your ONNX model during the build process.
|
||||
|
||||
### Step 3: Modify `mod.rs`
|
||||
|
||||
In your `src/model/mod.rs` file, include the generated code:
|
||||
|
||||
```rust
|
||||
```rust, ignore
|
||||
pub mod my_model {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
|
||||
}
|
||||
@@ -133,7 +129,7 @@ pub mod my_model {
|
||||
|
||||
Now you can use the imported model in your code:
|
||||
|
||||
```rust
|
||||
```rust, ignore
|
||||
use burn::tensor;
|
||||
use burn_ndarray::{NdArray, NdArrayDevice};
|
||||
use model::my_model::Model;
|
||||
@@ -158,7 +154,7 @@ fn main() {
|
||||
|
||||
The `ModelGen` struct provides configuration options:
|
||||
|
||||
```rust
|
||||
```rust, ignore
|
||||
ModelGen::new()
|
||||
.input("path/to/model.onnx")
|
||||
.out_dir("model/")
|
||||
@@ -170,26 +166,27 @@ ModelGen::new()
|
||||
- `input`: Path to the ONNX model file
|
||||
- `out_dir`: Output directory for generated code and weights
|
||||
- `development`: When enabled, generates additional debug files (`.onnx.txt`, `.graph.txt`)
|
||||
- `embed_states`: When enabled, embeds model weights in the binary using `include_bytes!`.
|
||||
Useful for WebAssembly or single-binary deployments. Not recommended for large models.
|
||||
- `embed_states`: When enabled, embeds model weights in the binary using `include_bytes!`. Useful
|
||||
for WebAssembly or single-binary deployments. Not recommended for large models.
|
||||
|
||||
Model weights are stored in `.burnpack` format, which provides efficient serialization and loading.
|
||||
Model weights are stored in Burnpack format (`.bpk`), which provides efficient serialization and
|
||||
loading.
|
||||
|
||||
## Loading and Using Models
|
||||
|
||||
You can load models in several ways:
|
||||
|
||||
```rust
|
||||
```rust, ignore
|
||||
// Load from the output directory with default device (recommended for most use cases)
|
||||
// This automatically loads weights from the .burnpack file
|
||||
// This automatically loads weights from the .bpk file
|
||||
let model = Model::<Backend>::default();
|
||||
|
||||
// Create a new model instance with a specific device
|
||||
// (initializes weights randomly; load weights via `load_from` afterward)
|
||||
let model = Model::<Backend>::new(&device);
|
||||
|
||||
// Load from a specific .burnpack file
|
||||
let model = Model::<Backend>::from_file("path/to/weights.burnpack", &device);
|
||||
// Load from a specific .bpk file
|
||||
let model = Model::<Backend>::from_file("path/to/weights.bpk", &device);
|
||||
|
||||
// Load from embedded weights (if embed_states was true)
|
||||
let model = Model::<Backend>::from_embedded(&device);
|
||||
@@ -200,7 +197,7 @@ let model = Model::<Backend>::from_embedded(&device);
|
||||
Common issues and solutions:
|
||||
|
||||
1. **Unsupported ONNX operator**: Check the
|
||||
[list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/SUPPORTED-ONNX-OPS.md).
|
||||
[list of supported ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).
|
||||
You may need to simplify your model or wait for support.
|
||||
|
||||
2. **Build errors**: Ensure your `burn-onnx` version matches your Burn version and verify the ONNX
|
||||
@@ -217,13 +214,23 @@ Common issues and solutions:
|
||||
|
||||
## Examples and Resources
|
||||
|
||||
For practical examples, check out:
|
||||
For practical examples, check out the
|
||||
[burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples):
|
||||
|
||||
1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
|
||||
2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)
|
||||
1. [ONNX Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference) -
|
||||
MNIST inference example
|
||||
2. [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) -
|
||||
SqueezeNet running in the browser via WebAssembly
|
||||
3. [Raspberry Pi Pico](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico) -
|
||||
Embedded deployment example
|
||||
|
||||
These demonstrate real-world usage of ONNX import in Burn projects.
|
||||
|
||||
For contributors looking to add support for new ONNX operators:
|
||||
|
||||
- [Development Guide](https://github.com/tracel-ai/burn-onnx/blob/main/DEVELOPMENT-GUIDE.md) -
|
||||
Step-by-step guide for implementing new operators
|
||||
|
||||
## Conclusion
|
||||
|
||||
Importing ONNX models into Burn combines the vast ecosystem of pre-trained models with Burn's
|
||||
@@ -231,10 +238,5 @@ performance and Rust's safety features. Following this guide, you can seamlessly
|
||||
models into your Burn projects for inference, fine-tuning, or further development.
|
||||
|
||||
The `burn-onnx` crate is actively developed, with ongoing work to support more ONNX operators and
|
||||
improve performance. Stay tuned to the Burn repository for updates!
|
||||
|
||||
---
|
||||
|
||||
> 🚨**Note**: The `burn-onnx` crate is in active development. For the most up-to-date information
|
||||
> on supported ONNX operators, please refer to the
|
||||
> [official documentation](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/SUPPORTED-ONNX-OPS.md).
|
||||
improve performance. Visit the [burn-onnx repository](https://github.com/tracel-ai/burn-onnx) for
|
||||
updates and to contribute!
|
||||
@@ -20,11 +20,11 @@ advanced user or a beginner. We have crafted some sections for you:
|
||||
loops, fine-tuning your models to meet your specific requirements. This section empowers you to
|
||||
harness Burn's flexibility to its fullest.
|
||||
|
||||
- [Saving & Loading Models](./saving-and-loading.md): Learn how to easily save and load your trained
|
||||
models.
|
||||
- [Saving & Loading Models](./saving-and-loading.md): Learn how to save and load your trained
|
||||
models, including importing weights from PyTorch and SafeTensors formats.
|
||||
|
||||
- [Importing Models](./import): Learn how to import ONNX and PyTorch models, expanding your
|
||||
compatibility with other deep learning ecosystems.
|
||||
- [ONNX Import](./onnx-import.md): Learn how to import ONNX models using the
|
||||
[burn-onnx](https://github.com/tracel-ai/burn-onnx) crate.
|
||||
|
||||
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md): Get started quickly with
|
||||
ready-to-use models and pre-trained weights.
|
||||
|
||||
@@ -56,51 +56,359 @@ Afterwards, the model can just as easily be loaded from the record saved on disk
|
||||
|
||||
```rust, ignore
|
||||
// Load model record on the backend's default device
|
||||
let record: ModelRecord<MyBackend> = NamedMpkFileRecorder::<FullPrecisionSettings>::new()
|
||||
let record: ModelRecord<MyBackend> =
|
||||
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
|
||||
.load(model_path.into(), &device)
|
||||
.expect("Should be able to load the model weights from the provided file");
|
||||
.expect("Could not load model weights");
|
||||
|
||||
// Initialize a new model with the loaded record/weights
|
||||
let model = Model::init(&device).load_record(record);
|
||||
```
|
||||
|
||||
## No Storage, No Problem!
|
||||
## Model Weight Store
|
||||
|
||||
For applications where file storage may not be available (or desired) at runtime, you can use the
|
||||
`BinBytesRecorder`.
|
||||
While the Recorder API works well for basic saving and loading, `burn-store` was introduced to
|
||||
address its limitations around memory efficiency and flexibility. It provides zero-copy
|
||||
memory-mapped loading, cross-framework interoperability (PyTorch and SafeTensors), key remapping,
|
||||
partial loading, and filtering. The `burn-store`
|
||||
crate is intended to eventually replace the Recorder API, but since it was recently released, both
|
||||
APIs are supported.
|
||||
|
||||
In the previous examples we used a `FileRecorder` based on the MessagePack format, which could be
|
||||
replaced with [another file recorder](./building-blocks/record.md#recorder) of your choice. To embed
|
||||
a model as part of your runtime application, first save the model to a binary file with
|
||||
`BinFileRecorder`.
|
||||
### Supported Formats
|
||||
|
||||
| Format | Extension | Description |
|
||||
| --------------- | -------------- | ----------------------------------------------------------------------------------------- |
|
||||
| **Burnpack** | `.bpk` | Burn's native format with fast loading, zero-copy support, and training state persistence |
|
||||
| **SafeTensors** | `.safetensors` | Industry-standard format from Hugging Face for secure tensor serialization |
|
||||
| **PyTorch** | `.pt`, `.pth` | Direct loading of PyTorch model weights (read-only) |
|
||||
|
||||
### Saving a Model
|
||||
|
||||
```rust, ignore
|
||||
// Save model in binary format with full precision
|
||||
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
|
||||
model
|
||||
.save_file(model_path, &recorder)
|
||||
.expect("Should be able to save the model");
|
||||
use burn_store::{ModuleSnapshot, BurnpackStore};
|
||||
|
||||
// Save to Burnpack (recommended)
|
||||
let mut store = BurnpackStore::from_file("model.bpk");
|
||||
model.save_into(&mut store)?;
|
||||
|
||||
// Or save to SafeTensors
|
||||
use burn_store::SafetensorsStore;
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors");
|
||||
model.save_into(&mut store)?;
|
||||
```
|
||||
|
||||
Then, in your final application, include the model and use the `BinBytesRecorder` to load it.
|
||||
|
||||
Embedding the model as part of your application is especially useful for smaller models but not
|
||||
recommended for very large models as it would significantly increase the binary size as well as
|
||||
consume a lot more memory at runtime.
|
||||
### Loading a Model
|
||||
|
||||
```rust, ignore
|
||||
// Include the model file as a reference to a byte array
|
||||
static MODEL_BYTES: &[u8] = include_bytes!("path/to/model.bin");
|
||||
use burn_store::{ModuleSnapshot, BurnpackStore};
|
||||
|
||||
// Load model binary record in full precision
|
||||
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
|
||||
.load(MODEL_BYTES.to_vec(), device)
|
||||
.expect("Should be able to load model the model weights from bytes");
|
||||
let device = Default::default();
|
||||
let mut model = MyModel::init(&device);
|
||||
|
||||
// Load that record with the model
|
||||
model.load_record(record);
|
||||
// Load from Burnpack
|
||||
let mut store = BurnpackStore::from_file("model.bpk");
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
This example assumes that the model was already created before loading the model record. If instead
|
||||
you want to skip the random initialization and directly initialize the weights with the provided
|
||||
record, you could adapt this like the [previous example](#initialization-from-recorded-weights).
|
||||
### Loading from PyTorch
|
||||
|
||||
You can load weights directly from PyTorch `.pt` files:
|
||||
|
||||
```rust, ignore
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
|
||||
let mut model = MyModel::init(&device);
|
||||
let mut store = PytorchStore::from_file("pytorch_model.pt");
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
#### Exporting from PyTorch
|
||||
|
||||
Save only the model weights (state_dict), not the entire model:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(2, 2, (2, 2))
|
||||
self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.conv1(x))
|
||||
|
||||
model = Net()
|
||||
torch.save(model.state_dict(), "model.pt") # Correct: save state_dict
|
||||
# torch.save(model, "model.pt") # Wrong: saves entire model
|
||||
```
|
||||
|
||||
#### Accessing Nested State Dicts
|
||||
|
||||
Some PyTorch checkpoints nest the state_dict under a key:
|
||||
|
||||
```rust, ignore
|
||||
let mut store = PytorchStore::from_file("checkpoint.pt")
|
||||
.with_top_level_key("state_dict");
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
### Loading from SafeTensors
|
||||
|
||||
For SafeTensors files exported from PyTorch, use the adapter for proper weight transformation:
|
||||
|
||||
```rust, ignore
|
||||
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
|
||||
|
||||
let mut model = MyModel::init(&device);
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors")
|
||||
.with_from_adapter(PyTorchToBurnAdapter);
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
For SafeTensors files created by Burn, no adapter is needed:
|
||||
|
||||
```rust, ignore
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors");
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
#### Exporting from PyTorch to SafeTensors
|
||||
|
||||
```python
|
||||
from safetensors.torch import save_file
|
||||
|
||||
model = Net()
|
||||
save_file(model.state_dict(), "model.safetensors")
|
||||
```
|
||||
|
||||
### Saving for PyTorch Compatibility
|
||||
|
||||
Use the adapter when saving for PyTorch consumption:
|
||||
|
||||
```rust, ignore
|
||||
use burn_store::{BurnToPyTorchAdapter, SafetensorsStore};
|
||||
|
||||
let mut store = SafetensorsStore::from_file("for_pytorch.safetensors")
|
||||
.with_to_adapter(BurnToPyTorchAdapter)
|
||||
.skip_enum_variants(true);
|
||||
model.save_into(&mut store)?;
|
||||
```
|
||||
|
||||
### Handling Load Results
|
||||
|
||||
The `load_from` method returns detailed information about the loading process:
|
||||
|
||||
```rust, ignore
|
||||
let result = model.load_from(&mut store)?;
|
||||
|
||||
// Print a formatted summary with suggestions
|
||||
println!("{}", result);
|
||||
|
||||
// Or inspect individual fields
|
||||
println!("Applied: {} tensors", result.applied.len());
|
||||
println!("Missing: {:?}", result.missing);
|
||||
println!("Errors: {:?}", result.errors);
|
||||
|
||||
if result.is_success() {
|
||||
println!("All tensors loaded successfully");
|
||||
}
|
||||
```
|
||||
|
||||
### Adding Metadata
|
||||
|
||||
Burnpack and SafeTensors support custom metadata:
|
||||
|
||||
```rust, ignore
|
||||
let mut store = BurnpackStore::from_file("model.bpk")
|
||||
.metadata("version", "1.0")
|
||||
.metadata("description", "My trained model")
|
||||
.metadata("epochs", "100");
|
||||
model.save_into(&mut store)?;
|
||||
```
|
||||
|
||||
### Advanced Features
|
||||
|
||||
#### Key Remapping
|
||||
|
||||
Remap parameter names using regex patterns when model structures don't match:
|
||||
|
||||
```rust, ignore
|
||||
let mut store = PytorchStore::from_file("model.pt")
|
||||
// Remove prefix: "model.conv1.weight" -> "conv1.weight"
|
||||
.with_key_remapping(r"^model\.", "")
|
||||
// Rename: "layer1" -> "encoder.layer1"
|
||||
.with_key_remapping(r"^layer", "encoder.layer");
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
For complex remapping:
|
||||
|
||||
```rust, ignore
|
||||
use burn_store::KeyRemapper;
|
||||
|
||||
let remapper = KeyRemapper::new()
|
||||
.add_pattern(r"^transformer\.h\.(\d+)\.", "transformer.layer$1.")?
|
||||
.add_pattern(r"\.attn\.", ".attention.")?;
|
||||
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors")
|
||||
.remap(remapper);
|
||||
```
|
||||
|
||||
#### Partial Loading
|
||||
|
||||
Load weights even when some tensors are missing:
|
||||
|
||||
```rust, ignore
|
||||
let mut store = PytorchStore::from_file("pretrained.pt")
|
||||
.allow_partial(true);
|
||||
|
||||
let result = model.load_from(&mut store)?;
|
||||
println!("Missing (initialized randomly): {:?}", result.missing);
|
||||
```
|
||||
|
||||
#### Filtering Tensors
|
||||
|
||||
Load or save only specific layers:
|
||||
|
||||
```rust, ignore
|
||||
// Load only encoder layers
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors")
|
||||
.with_regex(r"^encoder\..*")
|
||||
.allow_partial(true);
|
||||
|
||||
// Save only encoder layers
|
||||
let mut store = SafetensorsStore::from_file("encoder.safetensors")
|
||||
.with_regex(r"^encoder\..*");
|
||||
model.save_into(&mut store)?;
|
||||
|
||||
// Multiple patterns (OR logic)
|
||||
let mut store = SafetensorsStore::from_file("model.safetensors")
|
||||
.with_regex(r"^encoder\..*") // encoder tensors
|
||||
.with_regex(r".*\.bias$") // OR any bias tensors
|
||||
.with_full_path("decoder.scale"); // OR specific tensor
|
||||
```
|
||||
|
||||
#### Non-Contiguous Layer Indices
|
||||
|
||||
PyTorch `nn.Sequential` with mixed layers creates non-contiguous indices. `PytorchStore`
|
||||
automatically remaps these:
|
||||
|
||||
```
|
||||
PyTorch: fc.0.weight, fc.2.weight, fc.4.weight (gaps from ReLU layers)
|
||||
Burn: fc.0.weight, fc.1.weight, fc.2.weight (contiguous)
|
||||
```
|
||||
|
||||
This is enabled by default. Disable if needed:
|
||||
|
||||
```rust, ignore
|
||||
let mut store = PytorchStore::from_file("model.pt")
|
||||
.map_indices_contiguous(false);
|
||||
```
|
||||
|
||||
#### Zero-Copy Loading
|
||||
|
||||
For embedded models or large files, use zero-copy loading to avoid memory copies:
|
||||
|
||||
```rust, ignore
|
||||
// Embedded model (compile-time)
|
||||
static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
|
||||
let mut store = BurnpackStore::from_static(MODEL_DATA);
|
||||
model.load_from(&mut store)?;
|
||||
|
||||
// Large file (memory-mapped)
|
||||
let mut store = BurnpackStore::from_file("large_model.bpk")
|
||||
.zero_copy(true);
|
||||
model.load_from(&mut store)?;
|
||||
```
|
||||
|
||||
#### Direct Tensor Access
|
||||
|
||||
Inspect tensors without loading into a model:
|
||||
|
||||
```rust, ignore
|
||||
use burn_store::ModuleStore;
|
||||
|
||||
let mut store = PytorchStore::from_file("model.pt");
|
||||
|
||||
// List all tensor names
|
||||
let names = store.keys()?;
|
||||
|
||||
// Get specific tensor
|
||||
if let Some(snapshot) = store.get_snapshot("encoder.layer0.weight")? {
|
||||
println!("Shape: {:?}, DType: {:?}", snapshot.shape, snapshot.dtype);
|
||||
}
|
||||
```
|
||||
|
||||
#### Model Surgery
|
||||
|
||||
Transfer weights between models:
|
||||
|
||||
```rust, ignore
|
||||
use burn_store::{ModuleSnapshot, PathFilter};
|
||||
|
||||
// Transfer all weights
|
||||
let snapshots = model1.collect(None, None, false);
|
||||
model2.apply(snapshots, None, None, false);
|
||||
|
||||
// Transfer only encoder weights
|
||||
let filter = PathFilter::new().with_regex(r"^encoder\..*");
|
||||
let snapshots = model1.collect(Some(filter.clone()), None, false);
|
||||
model2.apply(snapshots, Some(filter), None, false);
|
||||
```
|
||||
|
||||
### API Reference
|
||||
|
||||
#### Builder Methods
|
||||
|
||||
| Category | Method | Description |
|
||||
| ------------- | ------------------------------ | ---------------------------- |
|
||||
| **Filtering** | `with_regex(pattern)` | Filter by regex pattern |
|
||||
| | `with_full_path(path)` | Include specific tensor |
|
||||
| | `with_predicate(fn)` | Custom filter logic |
|
||||
| **Remapping** | `with_key_remapping(from, to)` | Regex-based renaming |
|
||||
| | `remap(KeyRemapper)` | Complex remapping rules |
|
||||
| **Adapters** | `with_from_adapter(adapter)` | Loading transformations |
|
||||
| | `with_to_adapter(adapter)` | Saving transformations |
|
||||
| **Config** | `allow_partial(bool)` | Continue on missing tensors |
|
||||
| | `with_top_level_key(key)` | Access nested dict (PyTorch) |
|
||||
| | `skip_enum_variants(bool)` | Skip enum variants in paths |
|
||||
| | `map_indices_contiguous(bool)` | Remap non-contiguous indices |
|
||||
| | `metadata(key, value)` | Add custom metadata |
|
||||
| | `zero_copy(bool)` | Enable zero-copy loading |
|
||||
|
||||
#### Direct Access Methods
|
||||
|
||||
| Method | Description |
|
||||
| --------------------- | -------------------------------- |
|
||||
| `keys()` | Get ordered list of tensor names |
|
||||
| `get_all_snapshots()` | Get all tensors as BTreeMap |
|
||||
| `get_snapshot(name)` | Get specific tensor by name |
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
#### Common Issues
|
||||
|
||||
1. **"Missing source values" error**: You saved the entire PyTorch model instead of the state_dict.
|
||||
Re-export with `torch.save(model.state_dict(), "model.pt")`.
|
||||
|
||||
2. **Shape mismatch**: Your Burn model doesn't match the source architecture. Verify layer
|
||||
configurations (channels, kernel sizes, bias settings).
|
||||
|
||||
3. **Key not found**: Parameter names don't match. Use `with_key_remapping()` or inspect keys:
|
||||
|
||||
```rust, ignore
|
||||
let store = PytorchStore::from_file("model.pt");
|
||||
println!("Available keys: {:?}", store.keys()?);
|
||||
```
|
||||
|
||||
#### Inspecting Files
|
||||
|
||||
Use [Netron](https://github.com/lutzroeder/netron) to visualize `.pt` and `.safetensors` files.
|
||||
|
||||
For Burnpack files:
|
||||
|
||||
```bash
|
||||
cargo run --example burnpack_inspect model.bpk
|
||||
```
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
- [Tensor](./project-architecture/tensor.md)
|
||||
- [Backend](./project-architecture/backend.md)
|
||||
- [Guides for Contributors](./guides/README.md)
|
||||
- [ONNX to Burn: Development Guide](./guides/onnx-to-burn-conversion-tool.md)
|
||||
- [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md)
|
||||
- [Submitting Examples to Burn](./guides/submitting-examples.md)
|
||||
- [Frequently Encountered Issues](./frequently-encountered-issues/README.md)
|
||||
|
||||
@@ -17,27 +17,3 @@ If you encounter this, swap out the `assert_eq!` in the failing test for
|
||||
`tensor1.to_data().assert_approx_eq` with `3` as the second argument. The second arguments specifies
|
||||
the level of precision: `3` is equivalent to a less than 10<sup>-3</sup> (0.001) difference between
|
||||
the elements of the two tensors.
|
||||
|
||||
## Mismatched types and missing functions
|
||||
|
||||
```sh
|
||||
error[E0308]: mismatched types --> {burn_dir}/target/debug/build/onnx-tests-fed12aaf3671687f/out/model/pow.rs:48:45 | 48 | let pow1_out1 = input1.clone().powf(input1); | ---- ^^^^^^ expected `f32`, found `Tensor<B, 4>` | | | arguments to this method are incorrect | = note: expected type `f32` found struct `Tensor<B, 4>`
|
||||
|
||||
note: method defined here --> {burn_dir}/burn-tensor/src/tensor/api/float.rs:65:12 | 65 | pub fn powf(self, value: f32) -> Self { | ^^^^
|
||||
|
||||
error[E0599]: no method named `powf_scalar` found for struct `Tensor` in the current scope --> {burn_dir}/target/debug/build/onnx-tests-fed12aaf3671687f/out/model/pow.rs:50:35 | 50 | let pow2_out1 = pow1_out1.powf_scalar(cast1_out1); | ^^^^^^^^^^^ method not found in `Tensor<B, 4>`
|
||||
|
||||
error[E0599]: no method named `powi` found for struct `Tensor` in the current scope --> {burn_dir}/target/debug/build/onnx-tests-fed12aaf3671687f/out/model/pow_int.rs:49:40 | 49 | let pow1_out1 = input1.clone().powi(input1); | ^^^^ method not found in `Tensor<B, 4, Int>` Some errors have detailed explanations: E0308, E0599.
|
||||
For more information about an error, try `rustc --explain E0308`. error: could not compile `onnx-tests` (test "onnx_tests") due to 3 previous errors
|
||||
```
|
||||
|
||||
If you are getting this error, you probably didn't implement your operator for the actual Tensor
|
||||
struct. This issue was encountered when adding the Pow operator. The operation was added to the
|
||||
`FloatTensorOps` and `IntTensorOp` traits, but not for the numeric trait (under
|
||||
`burn-tensor/src/tensor/api/numeric.rs`). This, coupled with `powf` existing prior to the PR though
|
||||
only for scalar values (which had been renamed, just not in the right place), led to this confusing
|
||||
issue where it looked like the function was found, but the type was wrong. If that's the case, make
|
||||
sure that it's implemented for the appropriate type, in this case `Float` under
|
||||
[crates/burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/1235b06e25e39a6ee5a4ac59f7f1d3da2ddb9bc3/crates/burn-tensor/src/tensor/api/numeric.rs),
|
||||
and calling the `TensorOp.foo_op` defined under
|
||||
[crates/burn-tensor/src/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/1235b06e25e39a6ee5a4ac59f7f1d3da2ddb9bc3/crates/burn-tensor/src/tensor/ops/tensor.rs)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
# Guides for Contributors
|
||||
|
||||
The following guides are meant to help contributors accomplish specific tasks, such as adding new operations to Burn or generating test models for `burn-import`.
|
||||
The following guides are meant to help contributors accomplish specific tasks, such as adding new operations to Burn.
|
||||
@@ -1,480 +0,0 @@
|
||||
# ONNX to Burn: Development Guide
|
||||
|
||||
This guide offers in-depth design insights and step-by-step procedures for developers working on the
|
||||
ONNX to Burn conversion tool. This tool allows the importation of ONNX models into the Burn deep
|
||||
learning framework written in Rust. It converts ONNX models to Rust source code and model
|
||||
weights to `.burnpack` files.
|
||||
|
||||
For an introduction to ONNX import in Burn, see
|
||||
[this section of the Burn book](https://burn.dev/books/burn/import/onnx-model.html).
|
||||
|
||||
## Design Overview
|
||||
|
||||
### Design Goals
|
||||
|
||||
- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
|
||||
- Convert ONNX model weights to Burn state files.
|
||||
- Support ONNX models generated by PyTorch (ONNX Opset 16+ recommended for best compatibility).
|
||||
- Produce easy-to-understand and modifiable models.
|
||||
- Ensure the generated models are trainable using Burn APIs.
|
||||
|
||||
### Design Decisions
|
||||
|
||||
**Core Principles:**
|
||||
|
||||
- **Op/Node-Centric Design**: Built around individual operations and nodes for better scalability as
|
||||
more operators are added
|
||||
- **Opset-Aware Processing**: Processors accept opset parameters for flexible behavior across
|
||||
different ONNX versions
|
||||
- **Constants-First Approach**: All ONNX initializers are treated as constant nodes initially,
|
||||
providing a uniform starting point
|
||||
- **Native Type Integration**: Direct use of `burn_tensor::TensorData` and `Dtype` for efficiency,
|
||||
consistency, and future mmap support
|
||||
- **Multi-Phase Pipeline**: Explicit transformation phases (initialization → conversion → type
|
||||
inference → post-processing → finalization) for better visibility and maintainability
|
||||
- **Graph Input Name Preservation**: Sanitized ONNX names are preserved for easier development and
|
||||
troubleshooting
|
||||
|
||||
**Separation of Concerns:**
|
||||
|
||||
- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process
|
||||
- Ensure operator behavior consistency across different OpSet versions
|
||||
- Exclude any ONNX/Protobuf-specific logic from the Burn graph
|
||||
- **Feature Support Validation**: The `onnx-ir` crate should extract and preserve all ONNX attributes
|
||||
faithfully, even if Burn does not yet support them. Rejection of unsupported features should happen
|
||||
in `burn-onnx` during code generation, not in `onnx-ir` during configuration extraction. This
|
||||
allows `onnx-ir` to be reused by other projects that may have different feature support
|
||||
|
||||
The conversion process involves three main stages:
|
||||
|
||||
1. Convert ONNX model to Intermediate Representation (IR) via 5-phase pipeline.
|
||||
2. Translate IR to a Burn graph.
|
||||
3. Generate Rust source code from the Burn graph.
|
||||
|
||||
## Adding New Operators
|
||||
|
||||
To extend `burn-onnx` with support for new ONNX operators, follow these steps:
|
||||
|
||||
1. **Create PyTorch Script**: Place a PyTorch script using the new operator under
|
||||
`crates/burn-onnx/onnx-tests/tests/<op>/<op>.py`. Make sure to print both input and output
|
||||
tensors for end-to-end testing.
|
||||
|
||||
2. **Generate ONNX Model**: Run the PyTorch script to produce an ONNX model.
|
||||
|
||||
3. **Visualize ONNX Model**: Use [Netron](https://github.com/lutzroeder/netron) to verify the ONNX
|
||||
model contains the expected operators.
|
||||
|
||||
4. **Generate IR and Burn Graph**: Navigate to
|
||||
[crates/burn-onnx/](https://github.com/tracel-ai/burn/tree/main/crates/burn-onnx) and run:
|
||||
|
||||
```
|
||||
cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
|
||||
```
|
||||
|
||||
5. **Implement Missing Operators**: If you encounter an error stating that an operator is
|
||||
unsupported, [implement it](#implementing-a-new-operator). The `./out/my-model.graph.txt` should
|
||||
provide relevant information.
|
||||
|
||||
6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds
|
||||
the Burn model in Rust code, and `my-model.burnpack` contains the model weights.
|
||||
|
||||
7. **Integration Test**: Include the test in the `tests/<op_name>/mod.rs` file in the
|
||||
[crates/burn-onnx/onnx-tests/tests/](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/tests/)
|
||||
directory. Further details can be found in the
|
||||
[onnx-tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/README.md).
|
||||
|
||||
## Implementing a New Operator
|
||||
|
||||
To extend the capabilities of the Burn library by supporting new operations imported from ONNX
|
||||
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
|
||||
implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths
|
||||
are relative to the root of the burn repository.
|
||||
|
||||
### Step 1: Node Processor Implementation in onnx-ir
|
||||
|
||||
The `onnx-ir` crate handles the Intermediate Representation (IR) of ONNX models using a
|
||||
processor-based architecture. For each operation:
|
||||
|
||||
1. **Create a node module** in `crates/onnx-ir/src/node/<operation_name>.rs`. This file should
|
||||
contain:
|
||||
- **Configuration struct**: Define operation-specific parameters (e.g., `SqueezeConfig`)
|
||||
- **Processor struct**: Implement `NodeProcessor` trait (marked as `pub(crate)`)
|
||||
- The processor handles:
|
||||
- **Input/output specification**: Define expected inputs and outputs via `NodeSpec`
|
||||
- **Type inference**: Infer output types from inputs and configuration
|
||||
- **Configuration extraction**: Extract operation parameters from ONNX attributes
|
||||
- **Node construction**: Build the final `Node` enum variant with config
|
||||
|
||||
2. **Make the module visible** in `crates/onnx-ir/src/node/mod.rs`:
|
||||
|
||||
```rust
|
||||
pub mod squeeze;
|
||||
```
|
||||
|
||||
3. **Create a node struct** in your module file (e.g., `squeeze.rs`) with the standard fields:
|
||||
|
||||
```rust
|
||||
use onnx_ir_derive::NodeBuilder;
|
||||
|
||||
#[derive(Debug, Clone, NodeBuilder)]
|
||||
pub struct SqueezeNode {
|
||||
pub name: String,
|
||||
pub inputs: Vec<Argument>,
|
||||
pub outputs: Vec<Argument>,
|
||||
pub config: SqueezeConfig,
|
||||
}
|
||||
```
|
||||
|
||||
The `NodeBuilder` derive macro generates a test builder (e.g., `SqueezeNodeBuilder`) with methods
|
||||
for constructing nodes in tests.
|
||||
|
||||
4. **Add to the macro invocation** in `crates/onnx-ir/src/ir/node.rs` by adding a mapping to the
|
||||
`define_node_enum!` macro:
|
||||
|
||||
```rust
|
||||
define_node_enum! {
|
||||
// ... other variants
|
||||
Squeeze => squeeze::SqueezeNode,
|
||||
// ... more variants
|
||||
}
|
||||
```
|
||||
|
||||
This single macro invocation generates both the `NodeType` enum (for parsing) and the `Node` enum
|
||||
(with tuple variants wrapping node structs) from a single source of truth.
|
||||
|
||||
5. **Register your processor** in `crates/onnx-ir/src/registry.rs` by adding it to the
|
||||
`with_standard_processors()` function:
|
||||
```rust
|
||||
registry.register("Squeeze", Box::new(squeeze::SqueezeProcessor));
|
||||
```
|
||||
|
||||
For example, the squeeze operation in `crates/onnx-ir/src/node/squeeze.rs` contains:
|
||||
|
||||
- A `SqueezeConfig` struct with operation parameters (axes)
|
||||
- A `SqueezeProcessor` struct (marked `pub(crate)`) that implements `NodeProcessor`
|
||||
- The `node_spec()` method defines input/output requirements
|
||||
- The `process()` method extracts config and constructs the `Node::Squeeze` variant
|
||||
|
||||
### Step 2: Code Generation in burn-onnx
|
||||
|
||||
1. Create a new file named `<operation_name>.rs` in the `crates/burn-onnx/src/burn/node/`
|
||||
directory. This file implements code generation for your operation by implementing the
|
||||
`NodeCodegen` trait directly on the onnx-ir node type.
|
||||
|
||||
2. Implement the `NodeCodegen<PS>` trait for the onnx-ir node type. This trait defines how the node
|
||||
generates Rust code during the graph compilation process:
|
||||
|
||||
```rust
|
||||
use super::prelude::*;
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for onnx_ir::squeeze::SqueezeNode {
|
||||
fn inputs(&self) -> &[Argument] {
|
||||
&self.inputs
|
||||
}
|
||||
|
||||
fn outputs(&self) -> &[Argument] {
|
||||
&self.outputs
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
|
||||
let input_arg = self.inputs.first().unwrap();
|
||||
let output_arg = self.outputs.first().unwrap();
|
||||
|
||||
// Use scope.arg() to handle Tensor/Scalar/Shape arguments automatically
|
||||
let input = scope.arg(input_arg);
|
||||
let output = arg_to_ident(output_arg);
|
||||
|
||||
// Access node configuration
|
||||
match &self.config.axes {
|
||||
Some(axes) => {
|
||||
let axes_values: Vec<_> = axes.iter().map(|&i| {
|
||||
proc_macro2::Literal::i64_suffixed(i)
|
||||
}).collect();
|
||||
quote! {
|
||||
let #output = #input.squeeze_dims(&[#(#axes_values),*]);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Get output rank from type inference
|
||||
let output_rank = match &output_arg.ty {
|
||||
ArgType::Tensor(t) => t.rank,
|
||||
_ => panic!("Expected tensor output"),
|
||||
};
|
||||
quote! {
|
||||
let #output = #input.squeeze::<#output_rank>();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Key methods to implement:
|
||||
- `inputs(&self)` - Returns references to input arguments (usually just `&self.inputs`)
|
||||
- `outputs(&self)` - Returns references to output arguments (usually just `&self.outputs`)
|
||||
- `forward(&self, scope)` - Generates Rust code for the operation using the `quote!` macro
|
||||
- `field(&self)` - (Optional) Declares module fields for parameters like weights
|
||||
- `collect_snapshots(&self, field_name)` - (Optional) Collects tensor snapshots for burnpack serialization
|
||||
|
||||
3. Use helper utilities from `argument_helpers.rs`:
|
||||
- `scope.arg(argument)` - Automatically handles Tensor/Scalar/Shape with proper cloning
|
||||
- `arg_to_ident(argument)` - Converts argument to identifier for code generation
|
||||
|
||||
4. Add unit tests using snapshot testing to verify the generated code. These tests typically use the
|
||||
`insta` crate and test helper functions to validate the generated code:
|
||||
|
||||
```rust
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::test_helpers::*;
|
||||
use insta::assert_snapshot;
|
||||
use onnx_ir::squeeze::SqueezeNodeBuilder;
|
||||
|
||||
#[test]
|
||||
fn test_squeeze_forward() {
|
||||
let node = SqueezeNodeBuilder::new("squeeze1")
|
||||
.input_tensor("input", 3, DType::F32)
|
||||
.output_tensor("output", 2, DType::F32)
|
||||
.axes(vec![1])
|
||||
.build();
|
||||
let code = codegen_forward_default(&node);
|
||||
assert_snapshot!(code, @"let output = input.squeeze_dims(&[1i64]);");
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 3: Register in Module System
|
||||
|
||||
Add the module declaration to `crates/burn-onnx/src/burn/node/mod.rs`:
|
||||
|
||||
```rust
|
||||
// ... other node modules
|
||||
pub(crate) mod squeeze;
|
||||
// ... more node modules
|
||||
```
|
||||
|
||||
The modules are automatically made visible through re-exports in the same file.
|
||||
|
||||
### Step 4: Register in Code Generation Dispatch
|
||||
|
||||
Add your operation to the dispatch macro in `crates/burn-onnx/src/burn/node_codegen.rs`. The
|
||||
`impl_node_codegen_dispatch!` macro generates the trait implementation that dispatches to your
|
||||
node-specific code.
|
||||
|
||||
Add the node variant name (as defined in `onnx-ir`'s `Node` enum) to the macro invocation:
|
||||
|
||||
```rust
|
||||
impl_node_codegen_dispatch! {
|
||||
// ... other operations
|
||||
Squeeze, // Add your operation here (matches Node::Squeeze variant)
|
||||
// ... more operations
|
||||
}
|
||||
```
|
||||
|
||||
The macro automatically generates:
|
||||
|
||||
- Dispatch implementation for `NodeCodegen<PS>` on `onnx_ir::Node`
|
||||
- All required trait methods (`inputs`, `outputs`, `forward`, `field`, etc.)
|
||||
- Pattern matching to route to your node-specific implementation
|
||||
|
||||
### Step 5: Processor Implementation
|
||||
|
||||
The `NodeProcessor` trait defines how operations are processed in onnx-ir. Each processor must
|
||||
implement:
|
||||
|
||||
1. **Associated type**: `type Config` - Define your configuration struct (use `()` if no config)
|
||||
2. **`infer_types()`** - Infer output types from inputs and config (required)
|
||||
3. **`build_node()`** - Construct the node struct and wrap it in the `Node` enum variant (required)
|
||||
4. **`extract_config()`** - Extract config from attributes/inputs (override if Config != `()`)
|
||||
5. **`spec()`** - Define opset and input/output requirements (optional)
|
||||
6. **`lift_constants()`** - Request constant lifting for inputs (optional)
|
||||
|
||||
Example `build_node()` implementation:
|
||||
|
||||
```rust
|
||||
fn build_node(&self, builder: RawNode, opset: usize) -> Node {
|
||||
let config = self.extract_config(&builder, opset).expect("Config extraction failed");
|
||||
Node::Squeeze(SqueezeNode {
|
||||
name: builder.name,
|
||||
inputs: builder.inputs,
|
||||
outputs: builder.outputs,
|
||||
config,
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
Note: `RawNode` is the intermediate node representation used during processing. The `build_node()`
|
||||
method converts it into the final typed `Node` enum variant.
|
||||
|
||||
For complete examples, see existing processors:
|
||||
|
||||
- **Simple operation**: `crates/onnx-ir/src/node/softmax.rs`
|
||||
- **With constant inputs**: `crates/onnx-ir/src/node/squeeze.rs`
|
||||
- **Complex operation**: `crates/onnx-ir/src/node/conv2d.rs`
|
||||
|
||||
See [NodeProcessor Trait](#nodeprocessor-trait) for the complete trait definition.
|
||||
|
||||
### Step 6: Add Newly Supported Op!
|
||||
|
||||
As a reward, add an extra check to `crates/burn-onnx/SUPPORTED-ONNX-OPS.md`!
|
||||
|
||||
### Constant Lifting
|
||||
|
||||
The onnx-ir pipeline automatically handles constant lifting during the post-processing phase.
|
||||
"Lifting" constants means making constant values directly accessible on node inputs via
|
||||
`Argument::value()`, instead of requiring a separate graph traversal to find a Constant node.
|
||||
|
||||
**When to use**: If your operation takes constant inputs (e.g., weights in Conv1d, shape tensors in
|
||||
Reshape, axes in Squeeze), access them via `node.inputs[N].value()` in your `extract_config()`
|
||||
method. See the [Configuration Extraction example](#example-configuration-extraction) in Step 5.
|
||||
|
||||
**Optional optimization**: Implement `lift_constants()` to explicitly request constant lifting for
|
||||
specific inputs before `extract_config()` is called. The pipeline handles this automatically during
|
||||
post-processing.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### ONNX-IR Pipeline
|
||||
|
||||
The `onnx-ir` crate converts ONNX models to an Intermediate Representation through a 5-phase
|
||||
pipeline:
|
||||
|
||||
#### Phase 1: Initialization
|
||||
|
||||
- Creates `GraphState` from ONNX proto structures
|
||||
- **Constants-first approach**: Converts all ONNX initializers into Constant nodes, providing a
|
||||
uniform starting point for processing
|
||||
- Sets up the value store for tensor data using `burn_tensor::TensorData`
|
||||
- Preserves sanitized graph input names for debugging
|
||||
|
||||
#### Phase 2: Node Conversion
|
||||
|
||||
- Converts ONNX nodes to IR nodes using registered processors
|
||||
- Creates `RawNode` instances from ONNX proto nodes (intermediate representation)
|
||||
- Processors extract configuration and construct typed `Node` enum variants
|
||||
- Handles constant nodes specially (extracting values from attributes into tensor store)
|
||||
- Each processor is responsible for its own type inference and node construction
|
||||
|
||||
#### Phase 3: Type Inference
|
||||
|
||||
- Type inference happens within each processor's `process()` method during Phase 2
|
||||
- Processors infer output types based on input types and configuration
|
||||
- Multi-pass processing handles dependencies between nodes
|
||||
- The pipeline may need multiple iterations for complex type dependencies (e.g., control flow)
|
||||
|
||||
#### Phase 4: Post-processing
|
||||
|
||||
- Lifts constants: Makes constant values accessible on downstream node inputs
|
||||
- Eliminates Identity nodes: Removes no-op nodes and rewires the graph
|
||||
- Re-runs constant lifting after Identity elimination
|
||||
|
||||
#### Phase 5: Finalization
|
||||
|
||||
- Removes unreferenced constant nodes
|
||||
- Constructs the final `OnnxGraph` with inputs, outputs, and nodes
|
||||
|
||||
### NodeProcessor Trait
|
||||
|
||||
The `NodeProcessor` trait (defined in `crates/onnx-ir/src/processor.rs`) is the core abstraction for
|
||||
handling ONNX operations. Each processor implements:
|
||||
|
||||
**Required:**
|
||||
|
||||
- `type Config` - Associated type for configuration (use `()` if no config needed)
|
||||
- `infer_types()` - Infer output types from inputs and configuration
|
||||
- `build_node()` - Construct the final `Node` enum variant
|
||||
|
||||
**Optional (have defaults):**
|
||||
|
||||
- `spec()` - Define opset requirements and input/output count validation (`NodeSpec`, `InputSpec`,
|
||||
`OutputSpec`)
|
||||
- `extract_config()` - Extract configuration from attributes/inputs (default returns
|
||||
`Default::default()`)
|
||||
- `lift_constants()` - Request constant lifting for specific inputs (default does nothing)
|
||||
- `input_preferences()` - Declare preferred input types from producers (default returns `None`)
|
||||
|
||||
Design principles: Each processor is self-contained, handling type inference, config extraction, and
|
||||
node construction. Processors return strongly-typed `Node` enum variants, ensuring type safety
|
||||
throughout the pipeline.
|
||||
|
||||
## Testing
|
||||
|
||||
When implementing a new operator, there are several levels of testing to consider:
|
||||
|
||||
### Unit Testing
|
||||
|
||||
- **Processor Methods**: Write unit tests in `crates/onnx-ir/src/node/<operation_name>.rs` to
|
||||
verify:
|
||||
- `extract_config()` - Correctly extracts configuration from attributes and inputs
|
||||
- `infer_types()` - Correctly infers output types (element type, rank, static shapes)
|
||||
- `build_node()` - Constructs correct `Node` enum variant
|
||||
- `spec()` - Defines correct opset and input/output requirements
|
||||
- Error handling for invalid inputs or configurations
|
||||
|
||||
See existing tests in `crates/onnx-ir/src/node/squeeze.rs` for examples.
|
||||
|
||||
- **Code Generation**: Test the burn-onnx Node implementation to verify correct Rust code
|
||||
generation. Each node file typically includes unit tests using `assert_tokens()` to validate
|
||||
generated code against expected output.
|
||||
|
||||
### Integration Testing
|
||||
|
||||
- **Test Path**: Write integration tests in `crates/burn-onnx/onnx-tests/tests/<op_name>/mod.rs` where `<op_name>` is the name of the new operator.
|
||||
|
||||
- **What to Test**:
|
||||
- Create ONNX models that use your operator and test the end-to-end conversion process
|
||||
- Ensure the generated Rust code compiles
|
||||
- Test with realistic ONNX models that use your operator in conjunction with others
|
||||
- Include models that test edge cases (e.g., different input shapes, parameter combinations)
|
||||
- Verify that inputs and outputs match between the original ONNX model and the converted Burn model
|
||||
- Further details can be found in the
|
||||
[onnx-tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/README.md).
|
||||
|
||||
Testing the processor implementation is particularly important as it directly affects the
|
||||
correctness of the conversion process. Incorrect type inference can lead to mismatched tensor shapes
|
||||
or wrong element types, while incorrect configuration extraction can cause runtime errors or produce
|
||||
incorrect results.
|
||||
|
||||
## Node Enum Architecture
|
||||
|
||||
The ONNX-IR uses an enum-based node representation where each ONNX operation is a variant of the
|
||||
`Node` enum (defined in `crates/onnx-ir/src/ir/node.rs`). Each variant wraps an operation-specific
|
||||
node struct (e.g., `SoftmaxNode`, `Conv2dNode`) that contains `name`, `inputs`, `outputs`, and
|
||||
optionally a `config` field.
|
||||
|
||||
The `define_node_enum!` macro generates both enums from a single source using the syntax
|
||||
`VariantName => module::NodeStructType`:
|
||||
|
||||
```rust
|
||||
define_node_enum! {
|
||||
Softmax => softmax::SoftmaxNode,
|
||||
Conv2d => conv2d::Conv2dNode,
|
||||
Squeeze => squeeze::SqueezeNode,
|
||||
// ... 200+ more variants
|
||||
}
|
||||
```
|
||||
|
||||
This macro generates:
|
||||
|
||||
1. **`NodeType` enum**: Simple unit variants for ONNX parsing (`Softmax`, `Conv2d`, etc.)
|
||||
2. **`Node` enum**: Tuple variants wrapping node structs (`Softmax(SoftmaxNode)`,
|
||||
`Conv2d(Conv2dNode)`, etc.)
|
||||
3. **Accessor methods**: `name()`, `inputs()`, `outputs()` automatically generated for the `Node`
|
||||
enum
|
||||
|
||||
This design provides:
|
||||
|
||||
- **Type safety**: Each operation has its own struct type
|
||||
- **Trait implementations**: Operations can implement specific traits on their node structs
|
||||
- **Single source of truth**: Both enums are guaranteed to stay in sync
|
||||
- **Pattern matching**: Easy to match on specific operations and access their configuration
|
||||
|
||||
## Resources
|
||||
|
||||
1. [PyTorch to ONNX](https://pytorch.org/docs/stable/onnx.html)
|
||||
2. [ONNX to PyTorch](https://github.com/ENOT-AutoDL/onnx2torch)
|
||||
3. [ONNX Introduction](https://onnx.ai/onnx/intro/)
|
||||
4. [ONNX Operators](https://onnx.ai/onnx/operators/index.html)
|
||||
5. [ONNX Protos](https://onnx.ai/onnx/api/classes.html)
|
||||
6. [ONNX Optimizer](https://github.com/onnx/optimizer)
|
||||
7. [Netron](https://github.com/lutzroeder/netron)
|
||||
@@ -1,48 +0,0 @@
|
||||
[package]
|
||||
authors = [
|
||||
"Dilshod Tadjibaev (@antimora)",
|
||||
"Nathaniel Simard (@nathanielsimard)",
|
||||
]
|
||||
description = "Library for importing datamodels into the Burn framework"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "burn-import"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-import"
|
||||
documentation = "https://docs.rs/burn-import"
|
||||
version.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["pytorch", "safetensors", "onnx", "burn-onnx?/default"]
|
||||
|
||||
onnx = ["burn-onnx"]
|
||||
onnx-mmap = ["burn-onnx", "burn-onnx/mmap"]
|
||||
pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
|
||||
safetensors = [
|
||||
"burn/record-item-custom-serde",
|
||||
"thiserror",
|
||||
"zip",
|
||||
"candle-core",
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../burn", version = "=0.21.0", default-features = false, features = [
|
||||
"std",
|
||||
] }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0", default-features = false, optional = true }
|
||||
burn-store = { path = "../burn-store", version = "=0.21.0", default-features = false, features = ["std", "pytorch", "burnpack"] }
|
||||
burn-onnx = { path = "../burn-onnx", version = "=0.21.0", default-features = false, optional = true }
|
||||
candle-core = { workspace = true, optional = true }
|
||||
derive-new = { workspace = true }
|
||||
regex = { workspace = true, features = ["default"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true, features = ["std"] }
|
||||
thiserror = { workspace = true, optional = true }
|
||||
zip = { workspace = true, optional = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["default"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -1,24 +0,0 @@
|
||||
# Burn Import
|
||||
|
||||
The `burn-import` crate enables seamless integration of pre-trained models from popular machine
|
||||
learning frameworks into the Burn ecosystem. This functionality allows you to leverage existing
|
||||
models while benefiting from Burn's performance optimizations and native Rust integration.
|
||||
|
||||
## Supported Import Formats
|
||||
|
||||
Burn currently supports three primary model import formats, each serving different use cases:
|
||||
|
||||
| Format | Description | Use Case |
|
||||
| ----------------------------------------------------------------------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------ |
|
||||
| [**ONNX** (Guide)](https://burn.dev/books/burn/import/onnx-model.html) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
|
||||
| [**PyTorch** (Guide)](https://burn.dev/books/burn/import/pytorch-model.html) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
|
||||
| [**Safetensors** (Guide)](https://burn.dev/books/burn/import/safetensors-model.html) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |
|
||||
|
||||
## ONNX Contributor Resources
|
||||
|
||||
- [ONNX to Burn conversion guide](https://burn.dev/books/contributor/guides/onnx-to-burn-conversion-tool.html) -
|
||||
Instructions for adding support for additional ONNX operators
|
||||
- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) -
|
||||
Testing procedures for ONNX operators
|
||||
- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) -
|
||||
Complete list of currently supported ONNX operators
|
||||
@@ -1,32 +0,0 @@
|
||||
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(unused)]
|
||||
pub struct Net<B: Backend> {
|
||||
do_not_exist_in_pt: Conv2d<B>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::backend::TestBackend;
|
||||
|
||||
use burn::record::{FullPrecisionSettings, Recorder};
|
||||
use burn_import::pytorch::PyTorchFileRecorder;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
#[should_panic(
|
||||
expected = "Missing source values for the 'do_not_exist_in_pt' field of type 'Conv2dRecordItem'. Please verify the source data and ensure the field name is correct"
|
||||
)]
|
||||
fn should_fail_if_struct_field_is_missing() {
|
||||
let device = Default::default();
|
||||
let _record: NetRecord<TestBackend> =
|
||||
PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(
|
||||
"tests/missing_module_field/missing_module_field.pt".into(),
|
||||
&device,
|
||||
)
|
||||
.expect("Should decode state successfully");
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
#[allow(unused)]
|
||||
pub struct Net<B: Backend> {
|
||||
conv1: Conv2d<B>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::backend::TestBackend;
|
||||
|
||||
use burn::record::{FullPrecisionSettings, Recorder};
|
||||
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn should_fail_if_not_found() {
|
||||
let device = Default::default();
|
||||
let _record: NetRecord<TestBackend> =
|
||||
PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load("tests/top_level_key/top_level_key.pt".into(), &device)
|
||||
.expect("Should decode state successfully");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_load() {
|
||||
let device = Default::default();
|
||||
let load_args = LoadArgs::new("tests/top_level_key/top_level_key.pt".into())
|
||||
.with_top_level_key("my_state_dict");
|
||||
|
||||
let _record: NetRecord<TestBackend> =
|
||||
PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
.load(load_args, &device)
|
||||
.expect("Should decode state successfully");
|
||||
}
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
use burn::{
|
||||
module::Param,
|
||||
record::{PrecisionSettings, Record},
|
||||
tensor::{Tensor, backend::Backend},
|
||||
};
|
||||
|
||||
use burn::record::serde::{
|
||||
adapter::{BurnModuleAdapter, DefaultAdapter},
|
||||
data::NestedValue,
|
||||
ser::Serializer,
|
||||
};
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
/// A PyTorch adapter for the Burn module used during deserialization.
|
||||
///
|
||||
/// Not all Burn module correspond to a PyTorch module. Therefore,
|
||||
/// we need to adapt the Burn module to a PyTorch module. We implement
|
||||
/// only those that differ.
|
||||
pub struct PyTorchAdapter<PS: PrecisionSettings, B: Backend> {
|
||||
_precision_settings: std::marker::PhantomData<(PS, B)>,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings, B: Backend> BurnModuleAdapter for PyTorchAdapter<PS, B> {
|
||||
fn adapt_linear(data: NestedValue) -> NestedValue {
|
||||
// Get the current module in the form of map.
|
||||
let mut map = data.as_map().expect("Failed to get map from NestedValue");
|
||||
|
||||
// Get/remove the weight parameter.
|
||||
let weight = map
|
||||
.remove("weight")
|
||||
.expect("Failed to find 'weight' key in map");
|
||||
|
||||
// Convert the weight parameter to a tensor (use default device, since it's quick operation).
|
||||
let weight: Param<Tensor<B, 2>> = weight
|
||||
.try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
|
||||
.expect("Failed to deserialize weight");
|
||||
|
||||
// Do not capture transpose op when using autodiff backend
|
||||
let weight = weight.set_require_grad(false);
|
||||
// Transpose the weight tensor.
|
||||
let weight_transposed = Param::from_tensor(weight.val().transpose());
|
||||
|
||||
// Insert the transposed weight tensor back into the map.
|
||||
map.insert(
|
||||
"weight".to_owned(),
|
||||
serialize::<PS, _, 2>(weight_transposed),
|
||||
);
|
||||
|
||||
// Return the modified map.
|
||||
NestedValue::Map(map)
|
||||
}
|
||||
|
||||
fn adapt_group_norm(data: NestedValue) -> NestedValue {
|
||||
rename_weight_bias(data)
|
||||
}
|
||||
|
||||
fn adapt_batch_norm(data: NestedValue) -> NestedValue {
|
||||
rename_weight_bias(data)
|
||||
}
|
||||
|
||||
fn adapt_layer_norm(data: NestedValue) -> NestedValue {
|
||||
rename_weight_bias(data)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to serialize a param tensor.
|
||||
fn serialize<PS, B, const D: usize>(val: Param<Tensor<B, D>>) -> NestedValue
|
||||
where
|
||||
B: Backend,
|
||||
PS: PrecisionSettings,
|
||||
{
|
||||
let serializer = Serializer::new();
|
||||
|
||||
val.into_item::<PS>()
|
||||
.serialize(serializer)
|
||||
.expect("Failed to serialize the item")
|
||||
}
|
||||
|
||||
/// Helper function to rename the weight and bias parameters to gamma and beta.
|
||||
///
|
||||
/// This is needed because PyTorch uses different names for the normalizer parameter
|
||||
/// than Burn. Burn uses gamma and beta, while PyTorch uses weight and bias.
|
||||
fn rename_weight_bias(data: NestedValue) -> NestedValue {
|
||||
// Get the current module in the form of map.
|
||||
let mut map = data.as_map().expect("Failed to get map from NestedValue");
|
||||
|
||||
// Rename the weight parameter to gamma.
|
||||
let weight = map
|
||||
.remove("weight")
|
||||
.expect("Failed to find 'weight' key in map");
|
||||
|
||||
map.insert("gamma".to_owned(), weight);
|
||||
|
||||
// Rename the bias parameter to beta.
|
||||
let bias = map
|
||||
.remove("bias")
|
||||
.expect("Failed to find 'bias' key in map");
|
||||
|
||||
map.insert("beta".to_owned(), bias);
|
||||
|
||||
// Return the modified map.
|
||||
NestedValue::Map(map)
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
use core::ops::Deref;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use burn::record::serde::{
|
||||
data::{NestedValue, Serializable},
|
||||
error,
|
||||
ser::Serializer,
|
||||
};
|
||||
use burn::{
|
||||
module::ParamId,
|
||||
record::PrecisionSettings,
|
||||
tensor::{Element, ElementConversion, TensorData, bf16, f16},
|
||||
};
|
||||
|
||||
use candle_core::WithDType;
|
||||
use serde::Serialize;
|
||||
|
||||
use burn::record::RecorderError;
|
||||
use zip::result::ZipError;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Serde error: {0}")]
|
||||
Serde(#[from] error::Error),
|
||||
|
||||
#[error("Candle Tensor error: {0}")]
|
||||
CandleTensor(#[from] candle_core::Error),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Zip error: {0}")]
|
||||
Zip(#[from] ZipError),
|
||||
|
||||
// Add other kinds of errors as needed
|
||||
#[error("other error: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
// Implement From trait for Error to RecorderError
|
||||
impl From<Error> for RecorderError {
|
||||
fn from(error: Error) -> Self {
|
||||
RecorderError::DeserializeError(error.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializes a candle tensor.
|
||||
///
|
||||
/// Tensors are wrapped in a `Param` struct (learnable parameters) and serialized as a `TensorData` struct.
|
||||
///
|
||||
/// Values are serialized as `FloatElem` or `IntElem` depending on the precision settings.
|
||||
impl Serializable for CandleTensor {
|
||||
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, error::Error>
|
||||
where
|
||||
PS: PrecisionSettings,
|
||||
{
|
||||
let shape = self.shape().clone().into_dims();
|
||||
let flatten = CandleTensor(self.flatten_all().expect("Failed to flatten the tensor"));
|
||||
let param_id = ParamId::new();
|
||||
|
||||
match self.dtype() {
|
||||
candle_core::DType::U8 => {
|
||||
serialize_data::<u8, PS::IntElem>(flatten, shape, param_id, serializer)
|
||||
}
|
||||
candle_core::DType::U32 => {
|
||||
serialize_data::<u32, PS::IntElem>(flatten, shape, param_id, serializer)
|
||||
}
|
||||
candle_core::DType::I64 => {
|
||||
serialize_data::<i64, PS::IntElem>(flatten, shape, param_id, serializer)
|
||||
}
|
||||
candle_core::DType::BF16 => {
|
||||
serialize_data::<bf16, PS::FloatElem>(flatten, shape, param_id, serializer)
|
||||
}
|
||||
candle_core::DType::F16 => {
|
||||
serialize_data::<f16, PS::FloatElem>(flatten, shape, param_id, serializer)
|
||||
}
|
||||
candle_core::DType::F32 => {
|
||||
serialize_data::<f32, PS::FloatElem>(flatten, shape, param_id, serializer)
|
||||
}
|
||||
candle_core::DType::F64 => {
|
||||
serialize_data::<f64, PS::FloatElem>(flatten, shape, param_id, serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to serialize a candle tensor data.
|
||||
fn serialize_data<T, E>(
|
||||
tensor: CandleTensor,
|
||||
shape: Vec<usize>,
|
||||
param_id: ParamId,
|
||||
serializer: Serializer,
|
||||
) -> Result<NestedValue, error::Error>
|
||||
where
|
||||
E: Element + Serialize,
|
||||
T: WithDType + ElementConversion,
|
||||
{
|
||||
let data: Vec<E> = tensor
|
||||
.to_vec1::<T>()
|
||||
.map_err(|err| error::Error::Other(format!("Candle to vec1 error: {err}")))?
|
||||
.into_iter()
|
||||
.map(ElementConversion::elem)
|
||||
.collect();
|
||||
|
||||
let data = TensorData::new(data, shape.clone());
|
||||
let (dtype, bytes) = (data.dtype, data.into_bytes());
|
||||
|
||||
// Manually serialize the tensor instead of using the `ParamSerde` struct, such as:
|
||||
// ParamSerde::new(param_id, TensorData::new(data, shape)).serialize(serializer)
|
||||
// Because serializer copies individual elements of TensorData `value` into a new Vec<u8>,
|
||||
// which is not necessary and inefficient.
|
||||
let mut tensor_data: HashMap<String, NestedValue> = HashMap::new();
|
||||
tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes));
|
||||
tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?);
|
||||
tensor_data.insert("dtype".into(), dtype.serialize(serializer)?);
|
||||
|
||||
let mut param: HashMap<String, NestedValue> = HashMap::new();
|
||||
param.insert("id".into(), NestedValue::String(param_id.serialize()));
|
||||
param.insert("param".into(), NestedValue::Map(tensor_data));
|
||||
|
||||
Ok(NestedValue::Map(param))
|
||||
}
|
||||
|
||||
/// New type struct for Candle tensors because we need to implement the `Serializable` trait for it.
|
||||
pub struct CandleTensor(pub candle_core::Tensor);
|
||||
|
||||
impl Deref for CandleTensor {
|
||||
type Target = candle_core::Tensor;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
pub fn print_debug_info(
|
||||
tensors: &HashMap<String, CandleTensor>,
|
||||
remapped_keys: Vec<(String, String)>,
|
||||
) {
|
||||
let mut remapped_keys = remapped_keys;
|
||||
remapped_keys.sort();
|
||||
println!("Debug information of keys and tensor shapes:\n---");
|
||||
for (new_key, old_key) in remapped_keys {
|
||||
if old_key != new_key {
|
||||
println!("Original Key: {old_key}");
|
||||
println!("Remapped Key: {new_key}");
|
||||
} else {
|
||||
println!("Key: {new_key}");
|
||||
}
|
||||
|
||||
let shape = tensors[&new_key].shape();
|
||||
let dtype = tensors[&new_key].dtype();
|
||||
println!("Shape: {shape:?}");
|
||||
println!("Dtype: {dtype:?}");
|
||||
println!("---");
|
||||
}
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
pub mod adapter;
|
||||
#[cfg(feature = "safetensors")]
|
||||
pub mod candle;
|
||||
pub mod tensor_snapshot;
|
||||
@@ -1,81 +0,0 @@
|
||||
//! TensorSnapshot support for burn-import.
|
||||
|
||||
use burn::record::PrecisionSettings;
|
||||
use burn::record::serde::{
|
||||
data::{NestedValue, Serializable},
|
||||
error,
|
||||
ser::Serializer,
|
||||
};
|
||||
use burn_store::TensorSnapshot;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Deref;
|
||||
|
||||
/// Wrapper for TensorSnapshot to implement Serializable
|
||||
pub struct TensorSnapshotWrapper(pub TensorSnapshot);
|
||||
|
||||
impl Deref for TensorSnapshotWrapper {
|
||||
type Target = TensorSnapshot;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Serializes a TensorSnapshot.
|
||||
///
|
||||
/// Tensors are wrapped in a `Param` struct (learnable parameters) and serialized as a `TensorData` struct.
|
||||
impl Serializable for TensorSnapshotWrapper {
|
||||
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, error::Error>
|
||||
where
|
||||
PS: PrecisionSettings,
|
||||
{
|
||||
// Get the tensor data
|
||||
let data = self
|
||||
.0
|
||||
.to_data()
|
||||
.map_err(|e| error::Error::Other(format!("Failed to get tensor data: {:?}", e)))?;
|
||||
let shape = data.shape.clone();
|
||||
let dtype = data.dtype;
|
||||
let bytes = data.into_bytes();
|
||||
|
||||
// Create the tensor data structure
|
||||
let mut tensor_data: HashMap<String, NestedValue> = HashMap::new();
|
||||
tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes));
|
||||
tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?);
|
||||
tensor_data.insert("dtype".into(), dtype.serialize(serializer)?);
|
||||
|
||||
// Create the param structure
|
||||
let param_id = self.0.tensor_id.unwrap_or_default();
|
||||
let mut param: HashMap<String, NestedValue> = HashMap::new();
|
||||
param.insert("id".into(), NestedValue::String(param_id.serialize()));
|
||||
param.insert("param".into(), NestedValue::Map(tensor_data));
|
||||
|
||||
Ok(NestedValue::Map(param))
|
||||
}
|
||||
}
|
||||
|
||||
/// Print debug information about tensors
|
||||
pub fn print_debug_info(
|
||||
tensors: &HashMap<String, TensorSnapshotWrapper>,
|
||||
remapped_keys: Vec<(String, String)>,
|
||||
) {
|
||||
let mut remapped_keys = remapped_keys;
|
||||
remapped_keys.sort();
|
||||
println!("Debug information of keys and tensor shapes:\n---");
|
||||
for (new_key, old_key) in remapped_keys {
|
||||
if old_key != new_key {
|
||||
println!("Original Key: {old_key}");
|
||||
println!("Remapped Key: {new_key}");
|
||||
} else {
|
||||
println!("Key: {new_key}");
|
||||
}
|
||||
|
||||
let snapshot = &tensors[&new_key].0;
|
||||
let shape = &snapshot.shape;
|
||||
let dtype = &snapshot.dtype;
|
||||
println!("Shape: {shape:?}");
|
||||
println!("Dtype: {dtype:?}");
|
||||
println!("---");
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! `burn-import` is a crate designed to simplify the process of importing models trained in other
|
||||
//! machine learning frameworks into the Burn framework.
|
||||
|
||||
#[cfg(any(feature = "pytorch", feature = "safetensors"))]
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// The onnx module.
|
||||
#[cfg(feature = "onnx")]
|
||||
#[deprecated(
|
||||
since = "0.21.0",
|
||||
note = "ONNX import was moved to `burn-onnx`. Use that crate instead."
|
||||
)]
|
||||
pub mod onnx {
|
||||
#[deprecated(
|
||||
since = "0.21.0",
|
||||
note = "ONNX import was moved to `burn-onnx`. Use that crate instead."
|
||||
)]
|
||||
#[allow(missing_docs)]
|
||||
pub type ModelGen = burn_onnx::ModelGen;
|
||||
}
|
||||
|
||||
/// The module for generating the burn code.
|
||||
#[cfg(feature = "onnx")]
|
||||
#[deprecated(
|
||||
since = "0.21.0",
|
||||
note = "ONNX import was moved to `burn-onnx`. Use that crate instead."
|
||||
)]
|
||||
pub mod burn {
|
||||
pub use burn_onnx::burn::*;
|
||||
}
|
||||
|
||||
/// The PyTorch module for recorder.
|
||||
#[cfg(feature = "pytorch")]
|
||||
pub mod pytorch;
|
||||
|
||||
/// The Safetensors module for recorder.
|
||||
#[cfg(feature = "safetensors")]
|
||||
pub mod safetensors;
|
||||
|
||||
// Enabled when the `pytorch` or `safetensors` feature is enabled.
|
||||
#[cfg(any(feature = "pytorch", feature = "safetensors"))]
|
||||
mod common;
|
||||
@@ -1,51 +0,0 @@
|
||||
use std::path::Path;
|
||||
|
||||
use burn_store::pytorch::PytorchReader;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use super::reader::Error;
|
||||
|
||||
/// Loads configuration data from a PyTorch `.pth` file.
|
||||
///
|
||||
/// This function reads specific configuration or metadata stored in PyTorch checkpoint files.
|
||||
/// It's particularly useful for extracting model configurations that might be saved alongside
|
||||
/// the model weights.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `file` - Path to the PyTorch `.pth` file.
|
||||
/// * `key` - Optional key to filter specific data within the pickle file.
|
||||
/// If `None`, the entire content is deserialized.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `D` - The target type to deserialize into. Must implement `DeserializeOwned`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `Result` containing the deserialized configuration data, or an `Error` if
|
||||
/// reading or deserialization fails.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use burn_import::pytorch::config::load_config_from_file;
|
||||
/// use serde::Deserialize;
|
||||
///
|
||||
/// #[derive(Debug, Deserialize)]
|
||||
/// struct ModelConfig {
|
||||
/// hidden_size: usize,
|
||||
/// num_layers: usize,
|
||||
/// // ... other configuration fields
|
||||
/// }
|
||||
///
|
||||
/// let config: ModelConfig = load_config_from_file("model.pth", Some("config"))?;
|
||||
/// ```
|
||||
pub fn load_config_from_file<D, P>(file: P, key: Option<&str>) -> Result<D, Error>
|
||||
where
|
||||
D: DeserializeOwned,
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
// Use burn-store's PytorchReader to load and deserialize config
|
||||
PytorchReader::load_config(file, key).map_err(Error::Store)
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
mod config;
|
||||
mod reader;
|
||||
mod recorder;
|
||||
pub use config::load_config_from_file;
|
||||
pub use recorder::{LoadArgs, PyTorchFileRecorder};
|
||||
@@ -1,110 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::common::{
|
||||
adapter::PyTorchAdapter,
|
||||
tensor_snapshot::{TensorSnapshotWrapper, print_debug_info},
|
||||
};
|
||||
|
||||
use burn::record::PrecisionSettings;
|
||||
use burn::{
|
||||
record::serde::{
|
||||
data::{remap, unflatten},
|
||||
de::Deserializer,
|
||||
},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
use burn_store::pytorch::PytorchReader;
|
||||
use regex::Regex;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
/// Error type for PyTorch file operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("Store error: {0}")]
|
||||
Store(#[from] burn_store::pytorch::PytorchError),
|
||||
|
||||
#[error("Serde error: {0}")]
|
||||
Serde(#[from] burn::record::serde::error::Error),
|
||||
|
||||
#[error("Other error: {0}")]
|
||||
Other(String),
|
||||
}
|
||||
|
||||
/// Deserializes tensor data from a PyTorch file (`.pt` or `.pth`) into a Burn record.
|
||||
///
|
||||
/// This function reads tensors from a pickle file using burn-store's PyTorch reader,
|
||||
/// optionally remaps their keys, and then deserializes them into the specified record type `D`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - The path to the PyTorch file to load.
|
||||
/// * `key_remap` - A list of rules for renaming tensor keys. Each rule is a tuple
|
||||
/// containing a regular expression to match the original key and a replacement string.
|
||||
/// * `top_level_key` - An optional key within the pickle file if the tensors are nested
|
||||
/// under a specific dictionary key (e.g., "state_dict").
|
||||
/// * `debug` - If `true`, prints information about the loaded tensors and remapped keys.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `PS` - The precision settings to use during deserialization.
|
||||
/// * `D` - The target Burn record type to deserialize into.
|
||||
/// * `B` - The backend to use for tensor operations (primarily for type context).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `Result` containing the deserialized record `D` on success, or an `Error` if
|
||||
/// reading, remapping, or deserialization fails.
|
||||
pub fn from_file<PS, D, B>(
|
||||
path: &Path,
|
||||
key_remap: Vec<(Regex, String)>,
|
||||
top_level_key: Option<&str>,
|
||||
debug: bool,
|
||||
) -> Result<D, Error>
|
||||
where
|
||||
D: DeserializeOwned,
|
||||
PS: PrecisionSettings,
|
||||
B: Backend,
|
||||
{
|
||||
// Use burn-store's PyTorch reader to load tensors
|
||||
let reader = if let Some(key) = top_level_key {
|
||||
PytorchReader::with_top_level_key(path, key)?
|
||||
} else {
|
||||
PytorchReader::new(path)?
|
||||
};
|
||||
|
||||
// Get the tensors as TensorSnapshots and wrap them
|
||||
let tensors: HashMap<String, TensorSnapshotWrapper> = reader
|
||||
.into_tensors()
|
||||
.into_iter()
|
||||
.map(|(key, snapshot)| (key, TensorSnapshotWrapper(snapshot)))
|
||||
.collect();
|
||||
|
||||
// Remap the tensor keys based on the provided rules
|
||||
let (tensors, remapped_keys) = remap(tensors, key_remap);
|
||||
|
||||
// Print debug information if enabled
|
||||
if debug {
|
||||
print_debug_info(&tensors, remapped_keys);
|
||||
}
|
||||
|
||||
// Convert the flat map of tensors into a nested data structure suitable for deserialization
|
||||
let nested_value = unflatten::<PS, _>(tensors)?;
|
||||
|
||||
// Create a deserializer using the PyTorch adapter and the nested tensor data
|
||||
let deserializer = Deserializer::<PyTorchAdapter<PS, B>>::new(nested_value, true);
|
||||
|
||||
// Deserialize the nested data structure into the target record type
|
||||
let value = D::deserialize(deserializer)?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
// Re-export burn-store's PyTorch reader types for convenience
|
||||
|
||||
// Implement conversion to RecorderError for compatibility with the Recorder trait
|
||||
impl From<Error> for burn::record::RecorderError {
|
||||
fn from(error: Error) -> Self {
|
||||
burn::record::RecorderError::DeserializeError(error.to_string())
|
||||
}
|
||||
}
|
||||
@@ -1,173 +0,0 @@
|
||||
use core::marker::PhantomData;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use burn::{
|
||||
record::{PrecisionSettings, Record, Recorder, RecorderError},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
use regex::Regex;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
use super::reader::from_file;
|
||||
|
||||
/// Recorder for loading PyTorch (`.pt`) files into Burn modules.
|
||||
///
|
||||
/// Load arguments ([`LoadArgs`]) can be used to specify the file path and
|
||||
/// remap parameter keys during loading.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct PyTorchFileRecorder<PS: PrecisionSettings> {
|
||||
_settings: PhantomData<PS>,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS> {
|
||||
type Settings = PS;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = LoadArgs;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
_item: I,
|
||||
_file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
unimplemented!("Save operations are not supported by PyTorchFileRecorder.")
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
_file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
unimplemented!("load_item is not implemented for PyTorchFileRecorder; use load instead.")
|
||||
}
|
||||
|
||||
fn load<R: Record<B>>(
|
||||
&self,
|
||||
args: Self::LoadArgs,
|
||||
device: &B::Device,
|
||||
) -> Result<R, RecorderError> {
|
||||
let item = from_file::<PS, R::Item<Self::Settings>, B>(
|
||||
&args.file,
|
||||
args.key_remap,
|
||||
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
|
||||
args.debug,
|
||||
)?;
|
||||
Ok(R::from_item(item, device))
|
||||
}
|
||||
}
|
||||
|
||||
/// Arguments for loading PyTorch model weights.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// Parameter keys within a PyTorch file (`.pt` extension) can be inspected using
|
||||
/// tools like [Netron](https://github.com/lutzroeder/netron).
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
|
||||
/// use burn::record::{FullPrecisionSettings, Recorder};
|
||||
///
|
||||
/// // Create load arguments, specifying the file and a key remapping rule.
|
||||
/// let args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
|
||||
/// // Remove "conv." prefix, e.g., "conv.weight" -> "weight"
|
||||
/// .with_key_remap("conv\\.(.*)", "$1");
|
||||
///
|
||||
/// // Load the record using the default recorder.
|
||||
/// let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
|
||||
/// .load(args, &burn::backend::NdArray::default().device()) // Provide a device
|
||||
/// .expect("Failed to decode state from file"); // Example assertion
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoadArgs {
|
||||
/// The path to the PyTorch file (`.pt`).
|
||||
pub file: PathBuf,
|
||||
|
||||
/// A list of key remapping rules applied to the state dictionary keys.
|
||||
/// Each rule consists of a regular expression and a replacement string.
|
||||
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
|
||||
/// for more details.
|
||||
pub key_remap: Vec<(Regex, String)>,
|
||||
|
||||
/// Optional top-level key under which the state dictionary is nested within the file.
|
||||
/// If `None`, the root object is assumed to be the state dictionary.
|
||||
pub top_level_key: Option<String>,
|
||||
|
||||
/// If `true`, prints debug information during the loading process.
|
||||
pub debug: bool,
|
||||
}
|
||||
|
||||
impl LoadArgs {
|
||||
/// Creates new load arguments with the given file path.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `file` - The path to the PyTorch file to load.
|
||||
pub fn new(file: PathBuf) -> Self {
|
||||
Self {
|
||||
file,
|
||||
key_remap: Vec::new(),
|
||||
top_level_key: None,
|
||||
debug: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a key remapping rule.
|
||||
///
|
||||
/// Keys from the PyTorch state dictionary are modified if they match the pattern.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pattern` - The regular expression pattern to match against state dictionary keys.
|
||||
/// * `replacement` - The replacement string. Capture groups can be used (e.g., `$1`).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the provided `pattern` is an invalid regular expression.
|
||||
///
|
||||
/// See the [regex crate documentation](https://docs.rs/regex/latest/regex/) for pattern syntax
|
||||
/// and [replacement string syntax](https://docs.rs/regex/latest/regex/struct.Regex.html#replacement-string-syntax).
|
||||
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
|
||||
let regex = Regex::new(pattern).expect("Invalid regex pattern provided to with_key_remap");
|
||||
self.key_remap.push((regex, replacement.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Specifies a top-level key in the file under which the state dictionary is nested.
|
||||
///
|
||||
/// Some PyTorch files store the state dictionary within a larger structure (e.g., a dictionary).
|
||||
/// Use this method if the weights are not at the root level of the file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `key` - The top-level key to access the state dictionary.
|
||||
pub fn with_top_level_key(mut self, key: &str) -> Self {
|
||||
self.top_level_key = Some(key.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables printing of debug information during loading.
|
||||
pub fn with_debug_print(mut self) -> Self {
|
||||
self.debug = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for LoadArgs {
|
||||
fn from(val: PathBuf) -> Self {
|
||||
LoadArgs::new(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LoadArgs {
|
||||
fn from(val: String) -> Self {
|
||||
LoadArgs::new(val.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for LoadArgs {
|
||||
fn from(val: &str) -> Self {
|
||||
LoadArgs::new(val.into())
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
mod reader;
|
||||
mod recorder;
|
||||
pub use recorder::{AdapterType, LoadArgs, SafetensorsFileRecorder};
|
||||
@@ -1,72 +0,0 @@
|
||||
use std::{collections::HashMap, path::Path};
|
||||
|
||||
use burn::{
|
||||
record::{
|
||||
PrecisionSettings,
|
||||
serde::{
|
||||
adapter::DefaultAdapter,
|
||||
data::{remap, unflatten},
|
||||
de::Deserializer,
|
||||
},
|
||||
},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
use candle_core::{Device, safetensors};
|
||||
use regex::Regex;
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use super::super::common::adapter::PyTorchAdapter;
|
||||
use super::recorder::AdapterType;
|
||||
use crate::common::candle::{CandleTensor, Error, print_debug_info};
|
||||
|
||||
/// Deserializes model state from a safetensors file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the safetensors file.
|
||||
/// * `key_remap` - A vector of tuples containing regular expressions and replacement strings
|
||||
/// for remapping tensor keys.
|
||||
/// * `debug` - If true, prints debug information about the loaded tensors and remapped keys.
|
||||
/// * `adapter_type` - Specifies the adapter to use for deserialization (e.g., PyTorch, None).
|
||||
pub fn from_file<PS, D, B>(
|
||||
path: &Path,
|
||||
key_remap: Vec<(Regex, String)>,
|
||||
debug: bool,
|
||||
adapter_type: AdapterType,
|
||||
) -> Result<D, Error>
|
||||
where
|
||||
D: DeserializeOwned,
|
||||
PS: PrecisionSettings,
|
||||
B: Backend,
|
||||
{
|
||||
// Load tensors from the safetensors file into a HashMap.
|
||||
let tensors: HashMap<String, CandleTensor> = safetensors::load(path, &Device::Cpu)?
|
||||
.into_iter()
|
||||
.map(|(key, tensor)| (key, CandleTensor(tensor)))
|
||||
.collect();
|
||||
|
||||
// Remap tensor keys based on the provided patterns.
|
||||
let (tensors, remapped_keys) = remap(tensors, key_remap);
|
||||
|
||||
// Optionally print debug information about tensors and key remapping.
|
||||
if debug {
|
||||
print_debug_info(&tensors, remapped_keys);
|
||||
}
|
||||
|
||||
// Convert the flat map of tensors into a nested data structure suitable for deserialization.
|
||||
let nested_value = unflatten::<PS, _>(tensors)?;
|
||||
|
||||
// Deserialize the nested data structure into the target type using the specified adapter.
|
||||
let value = match adapter_type {
|
||||
AdapterType::PyTorch => D::deserialize(Deserializer::<PyTorchAdapter<PS, B>>::new(
|
||||
nested_value,
|
||||
true, // Allow unexpected fields by default? Might need clarification.
|
||||
))?,
|
||||
AdapterType::NoAdapter => {
|
||||
D::deserialize(Deserializer::<DefaultAdapter>::new(nested_value, true))?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
@@ -1,187 +0,0 @@
|
||||
use core::marker::PhantomData;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use burn::{
|
||||
record::{PrecisionSettings, Record, Recorder, RecorderError},
|
||||
tensor::backend::Backend,
|
||||
};
|
||||
|
||||
use regex::Regex;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
use super::reader::from_file;
|
||||
|
||||
/// Recorder for loading HuggingFace Safetensors files (`.safetensors`) into Burn modules.
|
||||
///
|
||||
/// This recorder uses [LoadArgs] to configure loading behavior, such as key remapping.
|
||||
#[derive(new, Debug, Default, Clone)]
|
||||
pub struct SafetensorsFileRecorder<PS: PrecisionSettings> {
|
||||
_settings: PhantomData<PS>,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings, B: Backend> Recorder<B> for SafetensorsFileRecorder<PS> {
|
||||
type Settings = PS;
|
||||
type RecordArgs = PathBuf;
|
||||
type RecordOutput = ();
|
||||
type LoadArgs = LoadArgs;
|
||||
|
||||
fn save_item<I: Serialize>(
|
||||
&self,
|
||||
_item: I,
|
||||
_file: Self::RecordArgs,
|
||||
) -> Result<(), RecorderError> {
|
||||
unimplemented!("save_item not implemented for SafetensorsFileRecorder")
|
||||
}
|
||||
|
||||
fn load_item<I: DeserializeOwned>(
|
||||
&self,
|
||||
_file: &mut Self::LoadArgs,
|
||||
) -> Result<I, RecorderError> {
|
||||
unimplemented!("load_item not implemented for SafetensorsFileRecorder")
|
||||
}
|
||||
|
||||
fn load<R: Record<B>>(
|
||||
&self,
|
||||
args: Self::LoadArgs,
|
||||
device: &B::Device,
|
||||
) -> Result<R, RecorderError> {
|
||||
let item = from_file::<PS, R::Item<Self::Settings>, B>(
|
||||
&args.file,
|
||||
args.key_remap,
|
||||
args.debug,
|
||||
args.adapter_type,
|
||||
)?;
|
||||
Ok(R::from_item(item, device))
|
||||
}
|
||||
}
|
||||
|
||||
/// Arguments for loading a Safetensors file using [SafetensorsFileRecorder].
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use burn_import::safetensors::{AdapterType, LoadArgs, SafetensorsFileRecorder};
|
||||
/// use burn::record::{FullPrecisionSettings, Recorder};
|
||||
/// use std::path::PathBuf;
|
||||
///
|
||||
/// // Dummy model record structure
|
||||
/// #[derive(Record, Default)]
|
||||
/// struct MyModelRecord<B: Backend> {
|
||||
/// // fields matching the tensor names in the file
|
||||
/// }
|
||||
///
|
||||
/// let device = Default::default(); // Replace with your actual device
|
||||
///
|
||||
/// // Example assuming a file named 'model.safetensors' exists
|
||||
/// let args = LoadArgs::new(PathBuf::from("model.safetensors"))
|
||||
/// // Example: Remove "model.encoder." prefix from keys
|
||||
/// .with_key_remap("model\\.encoder\\.(.*)", "$1")
|
||||
/// .with_adapter_type(AdapterType::PyTorch) // Specify if adaptation is needed
|
||||
/// .with_debug_print(); // Enable debug output
|
||||
///
|
||||
/// let record: MyModelRecord<MyBackend> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
|
||||
/// .load(args, &device)
|
||||
/// .expect("Should decode state successfully");
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LoadArgs {
|
||||
/// The path to the Safetensors file to load.
|
||||
pub file: PathBuf,
|
||||
|
||||
/// A list of key remapping rules applied sequentially. Each tuple contains a
|
||||
/// regular expression ([`Regex`]) to match keys and a replacement string.
|
||||
/// See [regex::Regex::replace_all](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace_all)
|
||||
/// for replacement syntax details.
|
||||
pub key_remap: Vec<(Regex, String)>,
|
||||
|
||||
/// If true, prints debug information during the loading process.
|
||||
pub debug: bool,
|
||||
|
||||
/// The type of adapter to apply for potential framework-specific tensor transformations
|
||||
/// (e.g., transposing certain weights).
|
||||
pub adapter_type: AdapterType,
|
||||
}
|
||||
|
||||
/// Specifies the type of adapter to use for tensor loading.
|
||||
///
|
||||
/// Adapters handle potential differences in tensor formats or naming conventions
|
||||
/// between the source framework and Burn.
|
||||
#[derive(Debug, Clone, Default, Copy)]
|
||||
pub enum AdapterType {
|
||||
/// Adapts tensors assuming they originated from PyTorch.
|
||||
#[default]
|
||||
PyTorch,
|
||||
|
||||
/// Loads tensors directly without any specific adaptation.
|
||||
NoAdapter,
|
||||
}
|
||||
|
||||
impl LoadArgs {
|
||||
/// Creates new `LoadArgs` for the given file path.
|
||||
///
|
||||
/// By default, no key remapping is applied, debug printing is off,
|
||||
/// and the adapter type is [AdapterType::PyTorch].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `file` - The path to the Safetensors file.
|
||||
pub fn new(file: PathBuf) -> Self {
|
||||
Self {
|
||||
file,
|
||||
key_remap: Vec::new(),
|
||||
debug: false,
|
||||
adapter_type: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a key remapping rule.
|
||||
///
|
||||
/// Rules are applied in the order they are added.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pattern` - The regular expression pattern to match tensor keys.
|
||||
/// * `replacement` - The replacement string. Capture groups like `$1`, `$2` can be used.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `pattern` is not a valid regular expression.
|
||||
///
|
||||
/// See [Regex syntax](https://docs.rs/regex/latest/regex/#syntax) and
|
||||
/// [replacement string syntax](https://docs.rs/regex/latest/regex/struct.Regex.html#replacement-string-syntax).
|
||||
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
|
||||
let regex = Regex::new(pattern).expect("Invalid regex pattern provided");
|
||||
self.key_remap.push((regex, replacement.to_string()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Enables printing of debug information during loading.
|
||||
pub fn with_debug_print(mut self) -> Self {
|
||||
self.debug = true;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the adapter type to use for loading tensors.
|
||||
pub fn with_adapter_type(mut self, adapter_type: AdapterType) -> Self {
|
||||
self.adapter_type = adapter_type;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PathBuf> for LoadArgs {
|
||||
fn from(val: PathBuf) -> Self {
|
||||
LoadArgs::new(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for LoadArgs {
|
||||
fn from(val: String) -> Self {
|
||||
LoadArgs::new(val.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for LoadArgs {
|
||||
fn from(val: &str) -> Self {
|
||||
LoadArgs::new(val.into())
|
||||
}
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
[package]
|
||||
authors = [
|
||||
"Dilshod Tadjibaev (@antimora)",
|
||||
"Nathaniel Simard (@nathanielsimard)",
|
||||
]
|
||||
description = "Library for importing ONNX models into the Burn framework"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
name = "burn-onnx"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-onnx"
|
||||
documentation = "https://docs.rs/burn-onnx"
|
||||
version.workspace = true
|
||||
|
||||
default-run = "onnx2burn"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = ["mmap"]
|
||||
mmap = ["onnx-ir/mmap"]
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../burn", version = "=0.21.0", default-features = false, features = [
|
||||
"std",
|
||||
] }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0", default-features = false }
|
||||
burn-store = { path = "../burn-store", version = "=0.21.0", default-features = false, features = [
|
||||
"std",
|
||||
"burnpack",
|
||||
] }
|
||||
onnx-ir = { path = "../onnx-ir", version = "=0.21.0", default-features = false }
|
||||
derive-new = { workspace = true }
|
||||
log = { workspace = true }
|
||||
proc-macro2 = { workspace = true }
|
||||
quote = { workspace = true }
|
||||
rust-format = { workspace = true, features = ["pretty_please", "post_process"] }
|
||||
tracing-core = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = [
|
||||
"default",
|
||||
"fmt",
|
||||
"env-filter",
|
||||
] }
|
||||
syn = { workspace = true, features = ["parsing"] }
|
||||
|
||||
[dev-dependencies]
|
||||
insta = { workspace = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["default"]
|
||||
rustdoc-args = ["--cfg", "docsrs"]
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-APACHE
|
||||
@@ -1 +0,0 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -1,13 +0,0 @@
|
||||
# Burn ONNX
|
||||
|
||||
The `burn-onnx` crate enables seamless integration of pre-trained models from popular machine
|
||||
learning frameworks into the Burn ecosystem via the ONNX format.
|
||||
|
||||
## ONNX Contributor Resources
|
||||
|
||||
- [ONNX to Burn conversion guide](https://burn.dev/books/contributor/guides/onnx-to-burn-conversion-tool.html) -
|
||||
Instructions for adding support for additional ONNX operators
|
||||
- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/README.md) -
|
||||
Testing procedures for ONNX operators
|
||||
- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/SUPPORTED-ONNX-OPS.md) -
|
||||
Complete list of currently supported ONNX operators
|
||||
@@ -1,416 +0,0 @@
|
||||
# Supported ONNX Operators
|
||||
|
||||
This table lists the support status for ONNX operators in Burn. Note that some
|
||||
entries marked with dimensional suffixes (such as `Conv1d`, `Conv2d`, etc.) or
|
||||
other specialized names like `Linear` are not standard ONNX operators. These
|
||||
represent Burn's implementation of dimension-specific versions of the
|
||||
corresponding ONNX operators to make the mapping clearer between ONNX and Burn
|
||||
functionality.
|
||||
|
||||
| ONNX OP | Import Support | Burn Support |
|
||||
|----------------------------------|:--------------:|:------------:|
|
||||
| [Abs][1] | ✅ | ✅ |
|
||||
| [Acos][2] | ❌ | ✅ |
|
||||
| [Acosh][3] | ❌ | ✅ |
|
||||
| [Add][4] | ✅ | ✅ |
|
||||
| [AffineGrid][195] | ❌ | ❌ |
|
||||
| [And][5] | ✅ | ✅ |
|
||||
| [ArgMax][6] | ✅ | ✅ |
|
||||
| [ArgMin][7] | ✅ | ✅ |
|
||||
| [Asin][8] | ❌ | ✅ |
|
||||
| [Asinh][9] | ❌ | ✅ |
|
||||
| [Atan][10] | ❌ | ✅ |
|
||||
| [Atanh][11] | ❌ | ✅ |
|
||||
| [Attention][194] | ✅ | ✅ |
|
||||
| [AveragePool1d][12] | ✅ | ✅ |
|
||||
| [AveragePool2d][12] | ✅ | ✅ |
|
||||
| [BatchNormalization][14] | ✅ | ✅ |
|
||||
| [Bernoulli][15] | ✅ | ✅ |
|
||||
| [BitShift][16] | ✅ | ✅ |
|
||||
| [BitwiseAnd][17] | ✅ | ✅ |
|
||||
| [BitwiseNot][18] | ✅ | ✅ |
|
||||
| [BitwiseOr][19] | ✅ | ✅ |
|
||||
| [BitwiseXor][20] | ✅ | ✅ |
|
||||
| [BlackmanWindow][21] | ❌ | ❌ |
|
||||
| [Cast][22] | ✅ | ✅ |
|
||||
| [CastLike][23] | ❌ | ❌ |
|
||||
| [Ceil][24] | ✅ | ✅ |
|
||||
| [Celu][25] | ❌ | ❌ |
|
||||
| [CenterCropPad][26] | ❌ | ❌ |
|
||||
| [Clip][27] | ✅ | ✅ |
|
||||
| [Col2Im][28] | ❌ | ❌ |
|
||||
| [Compress][29] | ❌ | ❌ |
|
||||
| [Concat][30] | ✅ | ✅ |
|
||||
| [ConcatFromSequence][31] | ❌ | ❌ |
|
||||
| [Constant][32] | ✅ | ✅ |
|
||||
| [ConstantOfShape][33] | ✅ | ✅ |
|
||||
| [Conv1d][34] | ✅ | ✅ |
|
||||
| [Conv2d][34] | ✅ | ✅ |
|
||||
| [Conv3d][34] | ✅ | ✅ |
|
||||
| [ConvInteger][37] | ❌ | ❌ |
|
||||
| [ConvTranspose1d][38] | ✅ | ✅ |
|
||||
| [ConvTranspose2d][38] | ✅ | ✅ |
|
||||
| [ConvTranspose3d][38] | ✅ | ✅ |
|
||||
| [Cos][39] | ✅ | ✅ |
|
||||
| [Cosh][40] | ✅ | ✅ |
|
||||
| [CumSum][41] | ✅ | ✅ |
|
||||
| [DeformConv][196] | ❌ | ❌ |
|
||||
| [DepthToSpace][42] | ✅ | ✅ |
|
||||
| [DequantizeLinear][43] | ❌ | ❌ |
|
||||
| [Det][44] | ❌ | ❌ |
|
||||
| [DFT][45] | ❌ | ❌ |
|
||||
| [Div][46] | ✅ | ✅ |
|
||||
| [Dropout][47] | ✅ | ✅ |
|
||||
| [DynamicQuantizeLinear][48] | ❌ | ❌ |
|
||||
| [Einsum][49] | ❌ | ❌ |
|
||||
| [Elu][50] | ❌ | ❌ |
|
||||
| [Equal][51] | ✅ | ✅ |
|
||||
| [Erf][52] | ✅ | ✅ |
|
||||
| [Exp][53] | ✅ | ✅ |
|
||||
| [Expand][54] | ✅ | ✅ |
|
||||
| [EyeLike][55] | ✅ | ✅ |
|
||||
| [Flatten][56] | ✅ | ✅ |
|
||||
| [Floor][57] | ✅ | ✅ |
|
||||
| [Gather][58] | ✅ | ✅ |
|
||||
| [GatherElements][59] | ✅ | ✅ |
|
||||
| [GatherND][60] | ❌ | ❌ |
|
||||
| [Gelu][61] | ✅ | ✅ |
|
||||
| [Gemm][62] | ✅ | ✅ |
|
||||
| [GlobalAveragePool][63] | ✅ | ✅ |
|
||||
| [GlobalLpPool][64] | ❌ | ❌ |
|
||||
| [GlobalMaxPool][65] | ❌ | ❌ |
|
||||
| [Greater][66] | ✅ | ✅ |
|
||||
| [GreaterOrEqual][67] | ✅ | ✅ |
|
||||
| [GridSample][68] | ✅ | ✅ |
|
||||
| [GroupNormalization][69] | ✅ | ✅ |
|
||||
| [GRU][70] | ❌ | ✅ |
|
||||
| [HammingWindow][71] | ❌ | ❌ |
|
||||
| [HannWindow][72] | ❌ | ❌ |
|
||||
| [Hardmax][73] | ❌ | ❌ |
|
||||
| [HardSigmoid][74] | ✅ | ✅ |
|
||||
| [HardSwish][75] | ✅ | ✅ |
|
||||
| [Identity][76] | ✅ | ✅ |
|
||||
| [If][77] | ❌ | ✅ |
|
||||
| [Im][78] | ❌ | ❌ |
|
||||
| [ImageDecoder][197] | ❌ | ❌ |
|
||||
| [InstanceNormalization][79] | ✅ | ✅ |
|
||||
| [IsInf][80] | ✅ | ✅ |
|
||||
| [IsNaN][81] | ✅ | ✅ |
|
||||
| [LayerNormalization][82] | ✅ | ✅ |
|
||||
| [LeakyRelu][83] | ✅ | ✅ |
|
||||
| [Less][84] | ✅ | ✅ |
|
||||
| [LessOrEqual][85] | ✅ | ✅ |
|
||||
| Linear | ✅ | ✅ |
|
||||
| [Log][87] | ✅ | ✅ |
|
||||
| [LogSoftmax][88] | ✅ | ✅ |
|
||||
| [Loop][89] | ✅ | ✅ |
|
||||
| [LpNormalization][90] | ❌ | ❌ |
|
||||
| [LpPool][91] | ❌ | ❌ |
|
||||
| [LRN][92] | ❌ | ❌ |
|
||||
| [LSTM][93] | ✅ | ✅ |
|
||||
| [MatMul][94] | ✅ | ✅ |
|
||||
| [MatMulInteger][95] | ✅ | ✅ |
|
||||
| [Max][96] | ✅ | ✅ |
|
||||
| [MaxPool1d][97] | ✅ | ✅ |
|
||||
| [MaxPool2d][98] | ✅ | ✅ |
|
||||
| [MaxRoiPool][99] | ❌ | ❌ |
|
||||
| [MaxUnpool][100] | ❌ | ❌ |
|
||||
| [Mean][101] | ✅ | ✅ |
|
||||
| [MeanVarianceNormalization][102] | ❌ | ❌ |
|
||||
| [MelWeightMatrix][103] | ❌ | ❌ |
|
||||
| [Min][104] | ✅ | ✅ |
|
||||
| [Mish][105] | ❌ | ❌ |
|
||||
| [Mod][106] | ✅ | ✅ |
|
||||
| [Mul][107] | ✅ | ✅ |
|
||||
| [Multinomial][108] | ❌ | ❌ |
|
||||
| [Neg][109] | ✅ | ✅ |
|
||||
| [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
|
||||
| [NonMaxSuppression][112] | ❌ | ❌ |
|
||||
| [NonZero][113] | ✅ | ✅ |
|
||||
| [Not][114] | ✅ | ✅ |
|
||||
| [OneHot][115] | ✅ | ✅ |
|
||||
| [Optional][116] | ❌ | ❌ |
|
||||
| [OptionalGetElement][117] | ❌ | ❌ |
|
||||
| [OptionalHasElement][118] | ❌ | ❌ |
|
||||
| [Or][119] | ✅ | ✅ |
|
||||
| [Pad][120] | ✅ | ✅ |
|
||||
| [Pow][121] | ✅ | ✅ |
|
||||
| [PRelu][122] | ✅ | ✅ |
|
||||
| [QLinearConv][123] | ❌ | ❌ |
|
||||
| [QLinearMatMul][124] | ❌ | ❌ |
|
||||
| [QuantizeLinear][125] | ❌ | ❌ |
|
||||
| [RMSNormalization][198] | ❌ | ❌ |
|
||||
| [RNN][145] | ❌ | ✅ |
|
||||
| [RandomNormal][126] | ✅ | ✅ |
|
||||
| [RandomNormalLike][127] | ✅ | ✅ |
|
||||
| [RandomUniform][128] | ✅ | ✅ |
|
||||
| [RandomUniformLike][129] | ✅ | ✅ |
|
||||
| [Range][130] | ✅ | ✅ |
|
||||
| [Reciprocal][131] | ✅ | ✅ |
|
||||
| [ReduceL][132] | ✅ | ✅ |
|
||||
| [ReduceLogSum][133] | ✅ | ✅ |
|
||||
| [ReduceLogSumExp][134] | ✅ | ✅ |
|
||||
| [ReduceMax][135] | ✅ | ✅ |
|
||||
| [ReduceMean][136] | ✅ | ✅ |
|
||||
| [ReduceMin][137] | ✅ | ✅ |
|
||||
| [ReduceProd][138] | ✅ | ✅ |
|
||||
| [ReduceSum][139] | ✅ | ✅ |
|
||||
| [ReduceSumSquare][140] | ✅ | ✅ |
|
||||
| [RegexFullMatch][199] | ❌ | ❌ |
|
||||
| [Relu][141] | ✅ | ✅ |
|
||||
| [Reshape][142] | ✅ | ✅ |
|
||||
| [Resize][143] | ✅ | ✅ |
|
||||
| [ReverseSequence][144] | ❌ | ❌ |
|
||||
| [RoiAlign][146] | ❌ | ❌ |
|
||||
| [RotaryEmbedding][200] | ❌ | ❌ |
|
||||
| [Round][147] | ✅ | ✅ |
|
||||
| [Scan][148] | ✅ | ✅ |
|
||||
| [Scatter][149] | ❌ | ✅ |
|
||||
| [ScatterElements][150] | ❌ | ❌ |
|
||||
| [ScatterND][151] | ❌ | ❌ |
|
||||
| [Selu][152] | ❌ | ❌ |
|
||||
| [SequenceAt][153] | ❌ | ❌ |
|
||||
| [SequenceConstruct][154] | ❌ | ❌ |
|
||||
| [SequenceEmpty][155] | ❌ | ❌ |
|
||||
| [SequenceErase][156] | ❌ | ❌ |
|
||||
| [SequenceInsert][157] | ❌ | ❌ |
|
||||
| [SequenceLength][158] | ❌ | ❌ |
|
||||
| [SequenceMap][159] | ❌ | ❌ |
|
||||
| [Shape][160] | ✅ | ✅ |
|
||||
| [Shrink][161] | ❌ | ❌ |
|
||||
| [Sigmoid][162] | ✅ | ✅ |
|
||||
| [Sign][163] | ✅ | ✅ |
|
||||
| [Sin][164] | ✅ | ✅ |
|
||||
| [Sinh][165] | ✅ | ✅ |
|
||||
| [Size][166] | ✅ | ✅ |
|
||||
| [Slice][167] | ✅ | ✅ |
|
||||
| [Softmax][168] | ✅ | ✅ |
|
||||
| [SoftmaxCrossEntropyLoss][169] | ❌ | ❌ |
|
||||
| [Softplus][170] | ❌ | ❌ |
|
||||
| [Softsign][171] | ❌ | ❌ |
|
||||
| [SpaceToDepth][172] | ✅ | ✅ |
|
||||
| [Split][173] | ✅ | ✅ |
|
||||
| [SplitToSequence][174] | ❌ | ❌ |
|
||||
| [Sqrt][175] | ✅ | ✅ |
|
||||
| [Squeeze][176] | ✅ | ✅ |
|
||||
| [STFT][177] | ❌ | ❌ |
|
||||
| [StringConcat][201] | ❌ | ❌ |
|
||||
| [StringNormalizer][178] | ❌ | ❌ |
|
||||
| [StringSplit][202] | ❌ | ❌ |
|
||||
| [Sub][179] | ✅ | ✅ |
|
||||
| [Sum][180] | ✅ | ✅ |
|
||||
| [Swish][203] | ❌ | ❌ |
|
||||
| [Tan][181] | ✅ | ✅ |
|
||||
| [Tanh][182] | ✅ | ✅ |
|
||||
| [TensorScatter][204] | ❌ | ❌ |
|
||||
| [TfIdfVectorizer][183] | ❌ | ❌ |
|
||||
| [ThresholdedRelu][184] | ❌ | ❌ |
|
||||
| [Tile][185] | ✅ | ✅ |
|
||||
| [TopK][186] | ✅ | ✅ |
|
||||
| [Transpose][187] | ✅ | ✅ |
|
||||
| [Trilu][188] | ✅ | ✅ |
|
||||
| [Unique][189] | ❌ | ❌ |
|
||||
| [Upsample][190] | ❌ | ❌ |
|
||||
| [Where][191] | ✅ | ✅ |
|
||||
| [Xor][192] | ✅ | ✅ |
|
||||
| [Unsqueeze][193] | ✅ | ✅ |
|
||||
|
||||
[1]: https://onnx.ai/onnx/operators/onnx__Abs.html "ONNX Abs"
|
||||
[2]: https://onnx.ai/onnx/operators/onnx__Acos.html "ONNX Acos"
|
||||
[3]: https://onnx.ai/onnx/operators/onnx__Acosh.html "ONNX Acosh"
|
||||
[4]: https://onnx.ai/onnx/operators/onnx__Add.html "ONNX Add"
|
||||
[5]: https://onnx.ai/onnx/operators/onnx__And.html "ONNX And"
|
||||
[6]: https://onnx.ai/onnx/operators/onnx__ArgMax.html "ONNX ArgMax"
|
||||
[7]: https://onnx.ai/onnx/operators/onnx__ArgMin.html "ONNX ArgMin"
|
||||
[8]: https://onnx.ai/onnx/operators/onnx__Asin.html "ONNX Asin"
|
||||
[9]: https://onnx.ai/onnx/operators/onnx__Asinh.html "ONNX Asinh"
|
||||
[10]: https://onnx.ai/onnx/operators/onnx__Atan.html "ONNX Atan"
|
||||
[11]: https://onnx.ai/onnx/operators/onnx__Atanh.html "ONNX Atanh"
|
||||
[12]: https://onnx.ai/onnx/operators/onnx__AveragePool.html "ONNX AveragePool"
|
||||
[14]: https://onnx.ai/onnx/operators/onnx__BatchNormalization.html "ONNX BatchNormalization"
|
||||
[15]: https://onnx.ai/onnx/operators/onnx__Bernoulli.html "ONNX Bernoulli"
|
||||
[16]: https://onnx.ai/onnx/operators/onnx__BitShift.html "ONNX BitShift"
|
||||
[17]: https://onnx.ai/onnx/operators/onnx__BitwiseAnd.html "ONNX BitwiseAnd"
|
||||
[18]: https://onnx.ai/onnx/operators/onnx__BitwiseNot.html "ONNX BitwiseNot"
|
||||
[19]: https://onnx.ai/onnx/operators/onnx__BitwiseOr.html "ONNX BitwiseOr"
|
||||
[20]: https://onnx.ai/onnx/operators/onnx__BitwiseXor.html "ONNX BitwiseXor"
|
||||
[21]: https://onnx.ai/onnx/operators/onnx__BlackmanWindow.html "ONNX BlackmanWindow"
|
||||
[22]: https://onnx.ai/onnx/operators/onnx__Cast.html "ONNX Cast"
|
||||
[23]: https://onnx.ai/onnx/operators/onnx__CastLike.html "ONNX CastLike"
|
||||
[24]: https://onnx.ai/onnx/operators/onnx__Ceil.html "ONNX Ceil"
|
||||
[25]: https://onnx.ai/onnx/operators/onnx__Celu.html "ONNX Celu"
|
||||
[26]: https://onnx.ai/onnx/operators/onnx__CenterCropPad.html "ONNX CenterCropPad"
|
||||
[27]: https://onnx.ai/onnx/operators/onnx__Clip.html "ONNX Clip"
|
||||
[28]: https://onnx.ai/onnx/operators/onnx__Col2Im.html "ONNX Col2Im"
|
||||
[29]: https://onnx.ai/onnx/operators/onnx__Compress.html "ONNX Compress"
|
||||
[30]: https://onnx.ai/onnx/operators/onnx__Concat.html "ONNX Concat"
|
||||
[31]: https://onnx.ai/onnx/operators/onnx__ConcatFromSequence.html "ONNX ConcatFromSequence"
|
||||
[32]: https://onnx.ai/onnx/operators/onnx__Constant.html "ONNX Constant"
|
||||
[33]: https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html "ONNX ConstantOfShape"
|
||||
[34]: https://onnx.ai/onnx/operators/onnx__Conv.html "ONNX Conv"
|
||||
[37]: https://onnx.ai/onnx/operators/onnx__ConvInteger.html "ONNX ConvInteger"
|
||||
[38]: https://onnx.ai/onnx/operators/onnx__ConvTranspose.html "ONNX ConvTranspose"
|
||||
[39]: https://onnx.ai/onnx/operators/onnx__Cos.html "ONNX Cos"
|
||||
[40]: https://onnx.ai/onnx/operators/onnx__Cosh.html "ONNX Cosh"
|
||||
[41]: https://onnx.ai/onnx/operators/onnx__CumSum.html "ONNX CumSum"
|
||||
[42]: https://onnx.ai/onnx/operators/onnx__DepthToSpace.html "ONNX DepthToSpace"
|
||||
[43]: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html "ONNX DequantizeLinear"
|
||||
[44]: https://onnx.ai/onnx/operators/onnx__Det.html "ONNX Det"
|
||||
[45]: https://onnx.ai/onnx/operators/onnx__DFT.html "ONNX DFT"
|
||||
[46]: https://onnx.ai/onnx/operators/onnx__Div.html "ONNX Div"
|
||||
[47]: https://onnx.ai/onnx/operators/onnx__Dropout.html "ONNX Dropout"
|
||||
[48]: https://onnx.ai/onnx/operators/onnx__DynamicQuantizeLinear.html "ONNX DynamicQuantizeLinear"
|
||||
[49]: https://onnx.ai/onnx/operators/onnx__Einsum.html "ONNX Einsum"
|
||||
[50]: https://onnx.ai/onnx/operators/onnx__Elu.html "ONNX Elu"
|
||||
[51]: https://onnx.ai/onnx/operators/onnx__Equal.html "ONNX Equal"
|
||||
[52]: https://onnx.ai/onnx/operators/onnx__Erf.html "ONNX Erf"
|
||||
[53]: https://onnx.ai/onnx/operators/onnx__Exp.html "ONNX Exp"
|
||||
[54]: https://onnx.ai/onnx/operators/onnx__Expand.html "ONNX Expand"
|
||||
[55]: https://onnx.ai/onnx/operators/onnx__EyeLike.html "ONNX EyeLike"
|
||||
[56]: https://onnx.ai/onnx/operators/onnx__Flatten.html "ONNX Flatten"
|
||||
[57]: https://onnx.ai/onnx/operators/onnx__Floor.html "ONNX Floor"
|
||||
[58]: https://onnx.ai/onnx/operators/onnx__Gather.html "ONNX Gather"
|
||||
[59]: https://onnx.ai/onnx/operators/onnx__GatherElements.html "ONNX GatherElements"
|
||||
[60]: https://onnx.ai/onnx/operators/onnx__GatherND.html "ONNX GatherND"
|
||||
[61]: https://onnx.ai/onnx/operators/onnx__Gelu.html "ONNX Gelu"
|
||||
[62]: https://onnx.ai/onnx/operators/onnx__Gemm.html "ONNX Gemm (Linear Layer)"
|
||||
[63]: https://onnx.ai/onnx/operators/onnx__GlobalAveragePool.html "ONNX GlobalAveragePool"
|
||||
[64]: https://onnx.ai/onnx/operators/onnx__GlobalLpPool.html "ONNX GlobalLpPool"
|
||||
[65]: https://onnx.ai/onnx/operators/onnx__GlobalMaxPool.html "ONNX GlobalMaxPool"
|
||||
[66]: https://onnx.ai/onnx/operators/onnx__Greater.html "ONNX Greater"
|
||||
[67]: https://onnx.ai/onnx/operators/onnx__GreaterOrEqual.html "ONNX GreaterOrEqual"
|
||||
[68]: https://onnx.ai/onnx/operators/onnx__GridSample.html "ONNX GridSample"
|
||||
[69]: https://onnx.ai/onnx/operators/onnx__GroupNormalization.html "ONNX GroupNormalization"
|
||||
[70]: https://onnx.ai/onnx/operators/onnx__GRU.html "ONNX GRU"
|
||||
[71]: https://onnx.ai/onnx/operators/onnx__HammingWindow.html "ONNX HammingWindow"
|
||||
[72]: https://onnx.ai/onnx/operators/onnx__HannWindow.html "ONNX HannWindow"
|
||||
[73]: https://onnx.ai/onnx/operators/onnx__Hardmax.html "ONNX Hardmax"
|
||||
[74]: https://onnx.ai/onnx/operators/onnx__HardSigmoid.html "ONNX HardSigmoid"
|
||||
[75]: https://onnx.ai/onnx/operators/onnx__HardSwish.html "ONNX HardSwish"
|
||||
[76]: https://onnx.ai/onnx/operators/onnx__Identity.html "ONNX Identity"
|
||||
[77]: https://onnx.ai/onnx/operators/onnx__If.html "ONNX If"
|
||||
[78]: https://onnx.ai/onnx/operators/onnx__Im.html "ONNX Im"
|
||||
[79]: https://onnx.ai/onnx/operators/onnx__InstanceNormalization.html "ONNX InstanceNormalization"
|
||||
[80]: https://onnx.ai/onnx/operators/onnx__IsInf.html "ONNX IsInf"
|
||||
[81]: https://onnx.ai/onnx/operators/onnx__IsNaN.html "ONNX IsNaN"
|
||||
[82]: https://onnx.ai/onnx/operators/onnx__LayerNormalization.html "ONNX LayerNormalization"
|
||||
[83]: https://onnx.ai/onnx/operators/onnx__LeakyRelu.html "ONNX LeakyRelu"
|
||||
[84]: https://onnx.ai/onnx/operators/onnx__Less.html "ONNX Less"
|
||||
[85]: https://onnx.ai/onnx/operators/onnx__LessOrEqual.html "ONNX LessOrEqual"
|
||||
[87]: https://onnx.ai/onnx/operators/onnx__Log.html "ONNX Log"
|
||||
[88]: https://onnx.ai/onnx/operators/onnx__LogSoftmax.html "ONNX LogSoftmax"
|
||||
[89]: https://onnx.ai/onnx/operators/onnx__Loop.html "ONNX Loop"
|
||||
[90]: https://onnx.ai/onnx/operators/onnx__LpNormalization.html "ONNX LpNormalization"
|
||||
[91]: https://onnx.ai/onnx/operators/onnx__LpPool.html "ONNX LpPool"
|
||||
[92]: https://onnx.ai/onnx/operators/onnx__LRN.html "ONNX LRN"
|
||||
[93]: https://onnx.ai/onnx/operators/onnx__LSTM.html "ONNX LSTM"
|
||||
[94]: https://onnx.ai/onnx/operators/onnx__MatMul.html "ONNX MatMul"
|
||||
[95]: https://onnx.ai/onnx/operators/onnx__MatMulInteger.html "ONNX MatMulInteger"
|
||||
[96]: https://onnx.ai/onnx/operators/onnx__Max.html "ONNX Max"
|
||||
[97]: https://onnx.ai/onnx/operators/onnx__MaxPool.html "ONNX MaxPool1d"
|
||||
[98]: https://onnx.ai/onnx/operators/onnx__MaxPool.html "ONNX MaxPool2d"
|
||||
[99]: https://onnx.ai/onnx/operators/onnx__MaxRoiPool.html "ONNX MaxRoiPool"
|
||||
[100]: https://onnx.ai/onnx/operators/onnx__MaxUnpool.html "ONNX MaxUnpool"
|
||||
[101]: https://onnx.ai/onnx/operators/onnx__Mean.html "ONNX Mean"
|
||||
[102]: https://onnx.ai/onnx/operators/onnx__MeanVarianceNormalization.html "ONNX MeanVarianceNormalization"
|
||||
[103]: https://onnx.ai/onnx/operators/onnx__MelWeightMatrix.html "ONNX MelWeightMatrix"
|
||||
[104]: https://onnx.ai/onnx/operators/onnx__Min.html "ONNX Min"
|
||||
[105]: https://onnx.ai/onnx/operators/onnx__Mish.html "ONNX Mish"
|
||||
[106]: https://onnx.ai/onnx/operators/onnx__Mod.html "ONNX Mod"
|
||||
[107]: https://onnx.ai/onnx/operators/onnx__Mul.html "ONNX Mul"
|
||||
[108]: https://onnx.ai/onnx/operators/onnx__Multinomial.html "ONNX Multinomial"
|
||||
[109]: https://onnx.ai/onnx/operators/onnx__Neg.html "ONNX Neg"
|
||||
[110]: https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html "ONNX NegativeLogLikelihoodLoss"
|
||||
[112]: https://onnx.ai/onnx/operators/onnx__NonMaxSuppression.html "ONNX NonMaxSuppression"
|
||||
[113]: https://onnx.ai/onnx/operators/onnx__NonZero.html "ONNX NonZero"
|
||||
[114]: https://onnx.ai/onnx/operators/onnx__Not.html "ONNX Not"
|
||||
[115]: https://onnx.ai/onnx/operators/onnx__OneHot.html "ONNX OneHot"
|
||||
[116]: https://onnx.ai/onnx/operators/onnx__Optional.html "ONNX Optional"
|
||||
[117]: https://onnx.ai/onnx/operators/onnx__OptionalGetElement.html "ONNX OptionalGetElement"
|
||||
[118]: https://onnx.ai/onnx/operators/onnx__OptionalHasElement.html "ONNX OptionalHasElement"
|
||||
[119]: https://onnx.ai/onnx/operators/onnx__Or.html "ONNX Or"
|
||||
[120]: https://onnx.ai/onnx/operators/onnx__Pad.html "ONNX Pad"
|
||||
[121]: https://onnx.ai/onnx/operators/onnx__Pow.html "ONNX Pow"
|
||||
[122]: https://onnx.ai/onnx/operators/onnx__PRelu.html "ONNX PRelu"
|
||||
[123]: https://onnx.ai/onnx/operators/onnx__QLinearConv "ONNX QLinearConv"
|
||||
[124]: https://onnx.ai/onnx/operators/onnx__QLinearMatMul.html "ONNX QLinearMatMul"
|
||||
[125]: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html "ONNX QuantizeLinear"
|
||||
[126]: https://onnx.ai/onnx/operators/onnx__RandomNormal.html "ONNX RandomNormal"
|
||||
[127]: https://onnx.ai/onnx/operators/onnx__RandomNormalLike.html "ONNX RandomNormalLike"
|
||||
[128]: https://onnx.ai/onnx/operators/onnx__RandomUniform.html "ONNX RandomUniform"
|
||||
[129]: https://onnx.ai/onnx/operators/onnx__RandomUniformLike.html "ONNX RandomUniformLike"
|
||||
[130]: https://onnx.ai/onnx/operators/onnx__Range.html "ONNX Range"
|
||||
[131]: https://onnx.ai/onnx/operators/onnx__Reciprocal.html "ONNX Reciprocal"
|
||||
[132]: https://onnx.ai/onnx/operators/onnx__ReduceL1.html "ONNX ReduceL"
|
||||
[133]: https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html "ONNX ReduceLogSum"
|
||||
[134]: https://onnx.ai/onnx/operators/onnx__ReduceLogSumExp.html "ONNX ReduceLogSumExp"
|
||||
[135]: https://onnx.ai/onnx/operators/onnx__ReduceMax.html "ONNX ReduceMax"
|
||||
[136]: https://onnx.ai/onnx/operators/onnx__ReduceMean.html "ONNX ReduceMean"
|
||||
[137]: https://onnx.ai/onnx/operators/onnx__ReduceMin.html "ONNX ReduceMin"
|
||||
[138]: https://onnx.ai/onnx/operators/onnx__ReduceProd.html "ONNX ReduceProd"
|
||||
[139]: https://onnx.ai/onnx/operators/onnx__ReduceSum.html "ONNX ReduceSum"
|
||||
[140]: https://onnx.ai/onnx/operators/onnx__ReduceSumSquare.html "ONNX ReduceSumSquare"
|
||||
[141]: https://onnx.ai/onnx/operators/onnx__Relu.html "ONNX Relu"
|
||||
[142]: https://onnx.ai/onnx/operators/onnx__Reshape.html "ONNX Reshape"
|
||||
[143]: https://onnx.ai/onnx/operators/onnx__Resize.html "ONNX Resize"
|
||||
[144]: https://onnx.ai/onnx/operators/onnx__ReverseSequence.html "ONNX ReverseSequence"
|
||||
[145]: https://onnx.ai/onnx/operators/onnx__RNN.html "ONNX RNN"
|
||||
[146]: https://onnx.ai/onnx/operators/onnx__RoiAlign.html "ONNX RoiAlign"
|
||||
[147]: https://onnx.ai/onnx/operators/onnx__Round.html "ONNX Round"
|
||||
[148]: https://onnx.ai/onnx/operators/onnx__Scan.html "ONNX Scan"
|
||||
[149]: https://onnx.ai/onnx/operators/onnx__Scatter.html "ONNX Scatter"
|
||||
[150]: https://onnx.ai/onnx/operators/onnx__ScatterElements.html "ONNX ScatterElements"
|
||||
[151]: https://onnx.ai/onnx/operators/onnx__ScatterND.html "ONNX ScatterND"
|
||||
[152]: https://onnx.ai/onnx/operators/onnx__Selu.html "ONNX Selu"
|
||||
[153]: https://onnx.ai/onnx/operators/onnx__SequenceAt.html "ONNX SequenceAt"
|
||||
[154]: https://onnx.ai/onnx/operators/onnx__SequenceConstruct.html "ONNX SequenceConstruct"
|
||||
[155]: https://onnx.ai/onnx/operators/onnx__SequenceEmpty.html "ONNX SequenceEmpty"
|
||||
[156]: https://onnx.ai/onnx/operators/onnx__SequenceErase.html "ONNX SequenceErase"
|
||||
[157]: https://onnx.ai/onnx/operators/onnx__SequenceInsert.html "ONNX SequenceInsert"
|
||||
[158]: https://onnx.ai/onnx/operators/onnx__SequenceLength.html "ONNX SequenceLength"
|
||||
[159]: https://onnx.ai/onnx/operators/onnx__SequenceMap.html "ONNX SequenceMap"
|
||||
[160]: https://onnx.ai/onnx/operators/onnx__Shape.html "ONNX Shape"
|
||||
[161]: https://onnx.ai/onnx/operators/onnx__Shrink.html "ONNX Shrink"
|
||||
[162]: https://onnx.ai/onnx/operators/onnx__Sigmoid.html "ONNX Sigmoid"
|
||||
[163]: https://onnx.ai/onnx/operators/onnx__Sign.html "ONNX Sign"
|
||||
[164]: https://onnx.ai/onnx/operators/onnx__Sin.html "ONNX Sin"
|
||||
[165]: https://onnx.ai/onnx/operators/onnx__Sinh.html "ONNX Sinh"
|
||||
[166]: https://onnx.ai/onnx/operators/onnx__Size.html "ONNX Size"
|
||||
[167]: https://onnx.ai/onnx/operators/onnx__Slice.html "ONNX Slice"
|
||||
[168]: https://onnx.ai/onnx/operators/onnx__Softmax.html "ONNX Softmax"
|
||||
[169]: https://onnx.ai/onnx/operators/onnx__SoftmaxCrossEntropyLoss.html "ONNX SoftmaxCrossEntropyLoss"
|
||||
[170]: https://onnx.ai/onnx/operators/onnx__Softplus.html "ONNX Softplus"
|
||||
[171]: https://onnx.ai/onnx/operators/onnx__Softsign.html "ONNX Softsign"
|
||||
[172]: https://onnx.ai/onnx/operators/onnx__SpaceToDepth.html "ONNX SpaceToDepth"
|
||||
[173]: https://onnx.ai/onnx/operators/onnx__Split.html "ONNX Split"
|
||||
[174]: https://onnx.ai/onnx/operators/onnx__SplitToSequence.html "ONNX SplitToSequence"
|
||||
[175]: https://onnx.ai/onnx/operators/onnx__Sqrt.html "ONNX Sqrt"
|
||||
[176]: https://onnx.ai/onnx/operators/onnx__Squeeze.html "ONNX Squeeze"
|
||||
[177]: https://onnx.ai/onnx/operators/onnx__STFT.html "ONNX STFT"
|
||||
[178]: https://onnx.ai/onnx/operators/onnx__StringNormalizer.html "ONNX StringNormalizer"
|
||||
[179]: https://onnx.ai/onnx/operators/onnx__Sub.html "ONNX Sub"
|
||||
[180]: https://onnx.ai/onnx/operators/onnx__Sum.html "ONNX Sum"
|
||||
[181]: https://onnx.ai/onnx/operators/onnx__Tan.html "ONNX Tan"
|
||||
[182]: https://onnx.ai/onnx/operators/onnx__Tanh.html "ONNX Tanh"
|
||||
[183]: https://onnx.ai/onnx/operators/onnx__TfIdfVectorizer.html "ONNX TfIdfVectorizer"
|
||||
[184]: https://onnx.ai/onnx/operators/onnx__ThresholdedRelu.html "ONNX ThresholdedRelu"
|
||||
[185]: https://onnx.ai/onnx/operators/onnx__Tile.html "ONNX Tile"
|
||||
[186]: https://onnx.ai/onnx/operators/onnx__TopK.html "ONNX TopK"
|
||||
[187]: https://onnx.ai/onnx/operators/onnx__Transpose.html "ONNX Transpose"
|
||||
[188]: https://onnx.ai/onnx/operators/onnx__Trilu.html "ONNX Trilu"
|
||||
[189]: https://onnx.ai/onnx/operators/onnx__Unique.html "ONNX Unique"
|
||||
[190]: https://onnx.ai/onnx/operators/onnx__Upsample.html "ONNX Upsample"
|
||||
[191]: https://onnx.ai/onnx/operators/onnx__Where.html "ONNX Where"
|
||||
[192]: https://onnx.ai/onnx/operators/onnx__Xor.html "ONNX Xor"
|
||||
[193]: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html "ONNX Unsqueeze"
|
||||
[194]: https://onnx.ai/onnx/operators/onnx__Attention.html "ONNX Attention"
|
||||
[195]: https://onnx.ai/onnx/operators/onnx__AffineGrid.html "ONNX AffineGrid"
|
||||
[196]: https://onnx.ai/onnx/operators/onnx__DeformConv.html "ONNX DeformConv"
|
||||
[197]: https://onnx.ai/onnx/operators/onnx__ImageDecoder.html "ONNX ImageDecoder"
|
||||
[198]: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html "ONNX RMSNormalization"
|
||||
[199]: https://onnx.ai/onnx/operators/onnx__RegexFullMatch.html "ONNX RegexFullMatch"
|
||||
[200]: https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html "ONNX RotaryEmbedding"
|
||||
[201]: https://onnx.ai/onnx/operators/onnx__StringConcat.html "ONNX StringConcat"
|
||||
[202]: https://onnx.ai/onnx/operators/onnx__StringSplit.html "ONNX StringSplit"
|
||||
[203]: https://onnx.ai/onnx/operators/onnx__Swish.html "ONNX Swish"
|
||||
[204]: https://onnx.ai/onnx/operators/onnx__TensorScatter.html "ONNX TensorScatter"
|
||||
25
crates/burn-onnx/model-checks/.gitignore
vendored
25
crates/burn-onnx/model-checks/.gitignore
vendored
@@ -1,25 +0,0 @@
|
||||
Cargo.lock
|
||||
|
||||
# Ignore model artifacts and temporary resources in all subdirectories
|
||||
*/artifacts/
|
||||
|
||||
# Ignore downloaded model files
|
||||
**/*.onnx
|
||||
**/*.pt
|
||||
**/*.pth
|
||||
**/*.ckpt
|
||||
**/*.safetensors
|
||||
**/*.pb
|
||||
**/*.h5
|
||||
|
||||
# Python cache
|
||||
**/__pycache__/
|
||||
**/*.pyc
|
||||
**/*.pyo
|
||||
**/*.pyd
|
||||
**/.Python
|
||||
|
||||
# UV/pip
|
||||
**/.venv/
|
||||
**/venv/
|
||||
**/*.egg-info/
|
||||
@@ -1,66 +0,0 @@
|
||||
# Model Checks
|
||||
|
||||
This directory contains model verification and validation tests for burn-onnx. Each subdirectory
|
||||
represents a different model that we test to ensure burn-onnx can correctly:
|
||||
|
||||
1. Import ONNX models
|
||||
2. Generate Rust code from the models
|
||||
3. Build and run the generated code
|
||||
|
||||
## Purpose
|
||||
|
||||
The model-checks serve as integration tests to verify that burn-onnx works correctly with
|
||||
real-world models. These tests help catch regressions and ensure compatibility with various ONNX
|
||||
operators and model architectures.
|
||||
|
||||
## Structure
|
||||
|
||||
Each model directory typically contains:
|
||||
|
||||
- Model download/preparation script (e.g., `get_model.py`)
|
||||
- `build.rs` - Build script that uses burn-onnx to generate Rust code
|
||||
- `src/main.rs` - Test code that runs the generated model
|
||||
- `Cargo.toml` - Package configuration
|
||||
- `artifacts/` - Directory for downloaded ONNX models (created by the script)
|
||||
|
||||
Generated files (not tracked in git):
|
||||
|
||||
- `target/` - Build artifacts and generated model code
|
||||
|
||||
## Two-Step Process
|
||||
|
||||
### Step 1: Download and Prepare the Model
|
||||
|
||||
First, download the model and convert it to the required ONNX format:
|
||||
|
||||
```bash
|
||||
cd model-checks/<model-name>
|
||||
python get_model.py
|
||||
# or using uv:
|
||||
uv run get_model.py
|
||||
```
|
||||
|
||||
The model preparation script typically:
|
||||
|
||||
- Downloads the model (if not already present)
|
||||
- Converts it to ONNX format with the appropriate opset version
|
||||
- Validates the model structure
|
||||
- Saves the prepared model to `artifacts/`
|
||||
|
||||
Scripts are designed to skip downloading if the ONNX model already exists, saving time and
|
||||
bandwidth.
|
||||
|
||||
### Step 2: Build and Run the Model
|
||||
|
||||
Once the ONNX model is ready, build and run the Rust code:
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
cargo run
|
||||
```
|
||||
|
||||
The build process will:
|
||||
|
||||
- Check that the ONNX model exists (with helpful error messages if not)
|
||||
- Generate Rust code from the ONNX model using burn-onnx
|
||||
- Compile the generated code
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-albert"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["tch"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,69 +0,0 @@
|
||||
# ALBERT Model Checks
|
||||
|
||||
This crate provides a unified interface for testing ALBERT model variants with Burn.
|
||||
|
||||
## Supported Models
|
||||
|
||||
- `albert-base-v2` - ALBERT Base v2 from HuggingFace (https://huggingface.co/albert/albert-base-v2)
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Download and prepare a model
|
||||
|
||||
```bash
|
||||
# Using Python directly
|
||||
python get_model.py --model albert-base-v2
|
||||
|
||||
# Or using uv
|
||||
uv run get_model.py --model albert-base-v2
|
||||
|
||||
# List available models
|
||||
uv run get_model.py --list
|
||||
```
|
||||
|
||||
### 2. Build and run the model test
|
||||
|
||||
```bash
|
||||
# Build the model
|
||||
ALBERT_MODEL=albert-base-v2 cargo build
|
||||
|
||||
# Run the test
|
||||
ALBERT_MODEL=albert-base-v2 cargo run --release
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
albert/
|
||||
├── artifacts/ # Downloaded ONNX models and test data
|
||||
│ ├── albert-base-v2_opset16.onnx
|
||||
│ ├── albert-base-v2_test_data.pt
|
||||
│ └── ...
|
||||
├── src/
|
||||
│ └── main.rs # Test runner
|
||||
├── build.rs # Build script that generates model code
|
||||
├── get_model.py # Model download and preparation script
|
||||
└── Cargo.toml
|
||||
```
|
||||
|
||||
## Model Architecture
|
||||
|
||||
ALBERT (A Lite BERT) is a lighter version of BERT that uses parameter-sharing techniques to reduce
|
||||
model size while maintaining performance. The model has:
|
||||
|
||||
- **Inputs**:
|
||||
- `input_ids`: Token IDs (shape: [batch_size, sequence_length])
|
||||
- `attention_mask`: Attention mask (shape: [batch_size, sequence_length])
|
||||
- `token_type_ids`: Token type IDs (shape: [batch_size, sequence_length])
|
||||
|
||||
- **Outputs**:
|
||||
- `last_hidden_state`: Sequence of hidden states (shape: [batch_size, sequence_length,
|
||||
hidden_size])
|
||||
- `pooler_output`: Pooled output for classification tasks (shape: [batch_size, hidden_size])
|
||||
|
||||
## Notes
|
||||
|
||||
- The default sequence length is 128 tokens
|
||||
- ALBERT Base v2 has a hidden size of 768
|
||||
- The model uses ONNX opset 16
|
||||
- Test data is generated with random inputs for reproducibility (seed=42)
|
||||
@@ -1,85 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
// Supported models
|
||||
let supported_models = vec!["albert-base-v2"];
|
||||
|
||||
// Get the model name from environment variable (required)
|
||||
let model_name = env::var("ALBERT_MODEL").unwrap_or_else(|_| {
|
||||
eprintln!("Error: ALBERT_MODEL environment variable is not set.");
|
||||
eprintln!();
|
||||
eprintln!("Please specify which ALBERT model to build:");
|
||||
eprintln!(" ALBERT_MODEL=albert-base-v2 cargo build");
|
||||
eprintln!();
|
||||
eprintln!("Available models: {}", supported_models.join(", "));
|
||||
std::process::exit(1);
|
||||
});
|
||||
|
||||
if !supported_models.contains(&model_name.as_str()) {
|
||||
eprintln!(
|
||||
"Error: Unsupported model '{}'. Supported models: {:?}",
|
||||
model_name, supported_models
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let onnx_path = format!("artifacts/{}_opset16.onnx", model_name);
|
||||
let test_data_path = format!("artifacts/{}_test_data.pt", model_name);
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed={}", test_data_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rerun-if-env-changed=ALBERT_MODEL");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(&onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!(
|
||||
"Please run the following command to download and prepare the {} model:",
|
||||
model_name
|
||||
);
|
||||
eprintln!(" python get_model.py --model {}", model_name);
|
||||
eprintln!();
|
||||
eprintln!("Or if you prefer using uv:");
|
||||
eprintln!(" uv run get_model.py --model {}", model_name);
|
||||
eprintln!();
|
||||
eprintln!("Available models: {}", supported_models.join(", "));
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
ModelGen::new()
|
||||
.input(&onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
|
||||
// Write the model name to a file so main.rs can access it
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let model_info_path = Path::new(&out_dir).join("model_info.rs");
|
||||
|
||||
// Generate the include path for the model
|
||||
let model_include = format!(
|
||||
"include!(concat!(env!(\"OUT_DIR\"), \"/model/{}_opset16.rs\"));",
|
||||
model_name
|
||||
);
|
||||
|
||||
fs::write(
|
||||
model_info_path,
|
||||
format!(
|
||||
r#"pub const MODEL_NAME: &str = "{}";
|
||||
pub const TEST_DATA_FILE: &str = "{}_test_data.pt";
|
||||
|
||||
// Include the generated model
|
||||
pub mod albert_model {{
|
||||
{}
|
||||
}}"#,
|
||||
model_name, model_name, model_include
|
||||
),
|
||||
)
|
||||
.expect("Failed to write model info");
|
||||
}
|
||||
@@ -1,234 +0,0 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "onnx>=1.17.0",
|
||||
# "onnxruntime>=1.18.0",
|
||||
# "transformers>=4.44.0",
|
||||
# "sentencepiece>=0.2.0",
|
||||
# "numpy",
|
||||
# "torch",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import os
|
||||
import sys
|
||||
import onnx
|
||||
from onnx import shape_inference, version_converter
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
|
||||
# Supported ALBERT models configuration
|
||||
SUPPORTED_MODELS = {
|
||||
'albert-base-v2': {
|
||||
'hf_name': 'albert/albert-base-v2',
|
||||
'display_name': 'ALBERT Base v2',
|
||||
'seq_length': 128,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def download_and_convert_model(model_name, output_path):
|
||||
"""Download ALBERT model from HuggingFace and export to ONNX format."""
|
||||
from transformers import AlbertModel, AlbertTokenizer
|
||||
import torch
|
||||
|
||||
model_config = SUPPORTED_MODELS[model_name]
|
||||
display_name = model_config['display_name']
|
||||
hf_name = model_config['hf_name']
|
||||
seq_length = model_config['seq_length']
|
||||
|
||||
print(f"Downloading {display_name} model from HuggingFace...")
|
||||
tokenizer = AlbertTokenizer.from_pretrained(hf_name)
|
||||
model = AlbertModel.from_pretrained(hf_name)
|
||||
model.eval()
|
||||
|
||||
print("Exporting to ONNX format...")
|
||||
|
||||
# Create dummy inputs
|
||||
dummy_text = "This is a sample text for ONNX export."
|
||||
inputs = tokenizer(
|
||||
dummy_text,
|
||||
padding='max_length',
|
||||
max_length=seq_length,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
input_ids = inputs['input_ids']
|
||||
attention_mask = inputs['attention_mask']
|
||||
token_type_ids = inputs['token_type_ids']
|
||||
|
||||
# Export to ONNX
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_ids, attention_mask, token_type_ids),
|
||||
output_path,
|
||||
input_names=['input_ids', 'attention_mask', 'token_type_ids'],
|
||||
output_names=['last_hidden_state', 'pooler_output'],
|
||||
dynamic_axes={
|
||||
'input_ids': {0: 'batch_size', 1: 'sequence'},
|
||||
'attention_mask': {0: 'batch_size', 1: 'sequence'},
|
||||
'token_type_ids': {0: 'batch_size', 1: 'sequence'},
|
||||
'last_hidden_state': {0: 'batch_size', 1: 'sequence'},
|
||||
'pooler_output': {0: 'batch_size'},
|
||||
},
|
||||
opset_version=16,
|
||||
do_constant_folding=True,
|
||||
)
|
||||
|
||||
if not output_path.exists():
|
||||
raise FileNotFoundError(f"Failed to create ONNX file at {output_path}")
|
||||
|
||||
|
||||
def process_model(input_path, output_path, target_opset=16):
|
||||
"""Load, upgrade opset, and apply shape inference to model."""
|
||||
print(f"Loading model from {input_path}...")
|
||||
model = onnx.load(input_path)
|
||||
|
||||
# Check and upgrade opset if needed
|
||||
current_opset = model.opset_import[0].version
|
||||
if current_opset < target_opset:
|
||||
print(f"Upgrading opset from {current_opset} to {target_opset}...")
|
||||
model = version_converter.convert_version(model, target_opset)
|
||||
|
||||
# Apply shape inference
|
||||
print("Applying shape inference...")
|
||||
model = shape_inference.infer_shapes(model)
|
||||
|
||||
# Save processed model
|
||||
onnx.save(model, output_path)
|
||||
print(f"✓ Processed model saved to: {output_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def generate_test_data(model_path, output_path, model_name):
|
||||
"""Generate test input/output data and save as PyTorch tensors."""
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
print("\nGenerating test data...")
|
||||
|
||||
model_config = SUPPORTED_MODELS[model_name]
|
||||
seq_length = model_config['seq_length']
|
||||
|
||||
# Create reproducible test input
|
||||
np.random.seed(42)
|
||||
batch_size = 1
|
||||
|
||||
# Generate random token IDs (typical vocabulary size is 30000 for ALBERT)
|
||||
input_ids = np.random.randint(0, 30000, size=(batch_size, seq_length), dtype=np.int64)
|
||||
attention_mask = np.ones((batch_size, seq_length), dtype=np.int64)
|
||||
token_type_ids = np.zeros((batch_size, seq_length), dtype=np.int64)
|
||||
|
||||
print(f" Input shapes:")
|
||||
print(f" input_ids: {input_ids.shape}")
|
||||
print(f" attention_mask: {attention_mask.shape}")
|
||||
print(f" token_type_ids: {token_type_ids.shape}")
|
||||
|
||||
# Run inference to get output
|
||||
session = ort.InferenceSession(model_path)
|
||||
outputs = session.run(
|
||||
None,
|
||||
{
|
||||
'input_ids': input_ids,
|
||||
'attention_mask': attention_mask,
|
||||
'token_type_ids': token_type_ids,
|
||||
}
|
||||
)
|
||||
|
||||
# Save as PyTorch tensors
|
||||
test_data = {
|
||||
'input_ids': torch.from_numpy(input_ids),
|
||||
'attention_mask': torch.from_numpy(attention_mask),
|
||||
'token_type_ids': torch.from_numpy(token_type_ids),
|
||||
'last_hidden_state': torch.from_numpy(outputs[0]),
|
||||
'pooler_output': torch.from_numpy(outputs[1]),
|
||||
}
|
||||
|
||||
torch.save(test_data, output_path)
|
||||
|
||||
print(f" ✓ Test data saved to: {output_path}")
|
||||
print(f" last_hidden_state shape: {outputs[0].shape}")
|
||||
print(f" pooler_output shape: {outputs[1].shape}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='ALBERT Model Preparation Tool')
|
||||
parser.add_argument('--model', type=str, default='albert-base-v2',
|
||||
choices=list(SUPPORTED_MODELS.keys()),
|
||||
help=f'ALBERT model to download and prepare (default: albert-base-v2). Choices: {", ".join(SUPPORTED_MODELS.keys())}')
|
||||
parser.add_argument('--list', action='store_true',
|
||||
help='List all supported models')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list:
|
||||
print("Supported ALBERT models:")
|
||||
for model_id, config in SUPPORTED_MODELS.items():
|
||||
print(f" - {model_id:20s} ({config['display_name']})")
|
||||
return
|
||||
|
||||
model_name = args.model
|
||||
display_name = SUPPORTED_MODELS[model_name]['display_name']
|
||||
|
||||
print("=" * 60)
|
||||
print(f"{display_name} Model Preparation Tool")
|
||||
print("=" * 60)
|
||||
|
||||
# Setup paths
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
original_path = artifacts_dir / f"{model_name}.onnx"
|
||||
processed_path = artifacts_dir / f"{model_name}_opset16.onnx"
|
||||
test_data_path = artifacts_dir / f"{model_name}_test_data.pt"
|
||||
|
||||
# Check if we already have everything
|
||||
if processed_path.exists() and test_data_path.exists():
|
||||
print(f"\n✓ All files already exist for {display_name}:")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print("\nNothing to do!")
|
||||
return
|
||||
|
||||
# Download and convert if needed
|
||||
if not original_path.exists() and not processed_path.exists():
|
||||
print(f"\nStep 1: Downloading and converting {display_name} model...")
|
||||
download_and_convert_model(model_name, original_path)
|
||||
|
||||
# Process model if needed
|
||||
if not processed_path.exists():
|
||||
print("\nStep 2: Processing model...")
|
||||
process_model(original_path, processed_path, target_opset=16)
|
||||
|
||||
# Clean up original if we have the processed version
|
||||
if original_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Generate test data if needed
|
||||
if not test_data_path.exists():
|
||||
print("\nStep 3: Generating test data...")
|
||||
generate_test_data(processed_path, test_data_path, model_name)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"✓ {display_name} model preparation completed!")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠ Operation cancelled by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -1,209 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::module::{Initializer, Param};
|
||||
use burn::prelude::*;
|
||||
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Import model info generated by build.rs (includes the albert_model module)
|
||||
include!(concat!(env!("OUT_DIR"), "/model_info.rs"));
|
||||
|
||||
// Use the albert_model module from model_info.rs
|
||||
use albert_model::Model;
|
||||
|
||||
#[derive(Debug, Module)]
|
||||
struct TestData<B: Backend> {
|
||||
input_ids: Param<Tensor<B, 2, Int>>,
|
||||
attention_mask: Param<Tensor<B, 2, Int>>,
|
||||
token_type_ids: Param<Tensor<B, 2, Int>>,
|
||||
last_hidden_state: Param<Tensor<B, 3>>,
|
||||
pooler_output: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TestData<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
use burn::module::ParamId;
|
||||
// Initialize with correct shapes matching the test data
|
||||
// ALBERT base uses sequence_length=128, hidden_size=768
|
||||
// Note: Initializer only works for float tensors, Int tensors need manual init
|
||||
Self {
|
||||
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
|
||||
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
|
||||
token_type_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
|
||||
last_hidden_state: Initializer::Zeros.init([1, 128, 768], device),
|
||||
pooler_output: Initializer::Zeros.init([1, 768], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_model_display_name(model_name: &str) -> &str {
|
||||
match model_name {
|
||||
"albert-base-v2" => "ALBERT Base v2",
|
||||
_ => model_name,
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// MODEL_NAME is set at build time from ALBERT_MODEL env var
|
||||
let model_name = MODEL_NAME;
|
||||
let display_name = get_model_display_name(model_name);
|
||||
|
||||
println!("========================================");
|
||||
println!("{} Burn Model Test", display_name);
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model and test data.");
|
||||
eprintln!("Example: uv run get_model.py --model {}", model_name);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Check if model files exist for this specific model
|
||||
let model_file = artifacts_dir.join(format!("{}_opset16.onnx", model_name));
|
||||
let test_data_file = artifacts_dir.join(format!("{}_test_data.pt", model_name));
|
||||
|
||||
if !model_file.exists() || !test_data_file.exists() {
|
||||
eprintln!("Error: Model files not found for {}!", display_name);
|
||||
eprintln!("Please run: uv run get_model.py --model {}", model_name);
|
||||
eprintln!();
|
||||
eprintln!("Available models:");
|
||||
eprintln!(" - albert-base-v2");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize the model (without weights for now)
|
||||
println!("Initializing {} model...", display_name);
|
||||
let start = Instant::now();
|
||||
let device = Default::default();
|
||||
let model: Model<MyBackend> = Model::default();
|
||||
let init_time = start.elapsed();
|
||||
println!(" Model initialized in {:.2?}", init_time);
|
||||
|
||||
// Save model structure to file
|
||||
let model_txt_path = artifacts_dir.join(format!("{}_model.txt", model_name));
|
||||
println!(
|
||||
"\nSaving model structure to {}...",
|
||||
model_txt_path.display()
|
||||
);
|
||||
let model_str = format!("{}", model);
|
||||
std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file");
|
||||
println!(" Model structure saved");
|
||||
|
||||
// Load test data from PyTorch file
|
||||
println!("\nLoading test data from {}...", test_data_file.display());
|
||||
let start = Instant::now();
|
||||
let mut test_data = TestData::<MyBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file(&test_data_file);
|
||||
test_data.load_from(&mut store).expect("Failed to load test data");
|
||||
let load_time = start.elapsed();
|
||||
println!(" Data loaded in {:.2?}", load_time);
|
||||
|
||||
// Get the input tensors from test data
|
||||
let input_ids = test_data.input_ids.val();
|
||||
let attention_mask = test_data.attention_mask.val();
|
||||
let token_type_ids = test_data.token_type_ids.val();
|
||||
|
||||
println!(" Loaded input tensors:");
|
||||
println!(" input_ids shape: {:?}", input_ids.shape().dims);
|
||||
println!(
|
||||
" attention_mask shape: {:?}",
|
||||
attention_mask.shape().dims
|
||||
);
|
||||
println!(
|
||||
" token_type_ids shape: {:?}",
|
||||
token_type_ids.shape().dims
|
||||
);
|
||||
|
||||
// Get the reference outputs from test data
|
||||
let reference_last_hidden = test_data.last_hidden_state.val();
|
||||
let reference_pooler = test_data.pooler_output.val();
|
||||
println!(" Loaded reference outputs:");
|
||||
println!(
|
||||
" last_hidden_state shape: {:?}",
|
||||
reference_last_hidden.shape().dims
|
||||
);
|
||||
println!(
|
||||
" pooler_output shape: {:?}",
|
||||
reference_pooler.shape().dims
|
||||
);
|
||||
|
||||
// Run inference with the loaded input
|
||||
println!("\nRunning model inference with test input...");
|
||||
let start = Instant::now();
|
||||
let outputs = model.forward(input_ids, attention_mask, token_type_ids);
|
||||
let inference_time = start.elapsed();
|
||||
println!(" Inference completed in {:.2?}", inference_time);
|
||||
|
||||
// ALBERT models typically return (last_hidden_state, pooler_output)
|
||||
// The outputs tuple should have 2 elements
|
||||
println!("\n Model output shapes:");
|
||||
println!(
|
||||
" output 0 (last_hidden_state): {:?}",
|
||||
outputs.0.shape().dims
|
||||
);
|
||||
println!(" output 1 (pooler_output): {:?}", outputs.1.shape().dims);
|
||||
|
||||
// Compare outputs
|
||||
println!("\nComparing model outputs with reference data...");
|
||||
|
||||
// Compare last_hidden_state
|
||||
println!(" Checking last_hidden_state...");
|
||||
if outputs
|
||||
.0
|
||||
.clone()
|
||||
.all_close(reference_last_hidden.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ last_hidden_state differs from reference data!");
|
||||
|
||||
let diff = outputs.0.clone() - reference_last_hidden.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
}
|
||||
|
||||
// Compare pooler_output
|
||||
println!(" Checking pooler_output...");
|
||||
if outputs
|
||||
.1
|
||||
.clone()
|
||||
.all_close(reference_pooler.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ pooler_output matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ pooler_output differs from reference data!");
|
||||
|
||||
let diff = outputs.1.clone() - reference_pooler.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
}
|
||||
|
||||
println!("\n========================================");
|
||||
println!("Model test completed!");
|
||||
println!("========================================");
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-all-minilm-l6-v2"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["tch"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,68 +0,0 @@
|
||||
# all-MiniLM-L6-v2 Model Check
|
||||
|
||||
This crate provides testing for the all-MiniLM-L6-v2 sentence transformer model with Burn.
|
||||
|
||||
## Model
|
||||
|
||||
- `all-MiniLM-L6-v2` - Sentence transformer model from HuggingFace
|
||||
(https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Download and prepare the model
|
||||
|
||||
```bash
|
||||
# Using Python directly
|
||||
python get_model.py
|
||||
|
||||
# Or using uv
|
||||
uv run get_model.py
|
||||
```
|
||||
|
||||
### 2. Build and run the model test
|
||||
|
||||
```bash
|
||||
# Build the model
|
||||
cargo build
|
||||
|
||||
# Run the test
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
all-minilm-l6-v2/
|
||||
├── artifacts/ # Downloaded ONNX model and test data
|
||||
│ ├── all-minilm-l6-v2_opset16.onnx
|
||||
│ ├── test_data.pt
|
||||
│ └── model-python.txt
|
||||
├── src/
|
||||
│ └── main.rs # Test runner
|
||||
├── build.rs # Build script that generates model code
|
||||
├── get_model.py # Model download and preparation script
|
||||
├── Cargo.toml
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Model Architecture
|
||||
|
||||
all-MiniLM-L6-v2 is a sentence transformer model based on Microsoft's MiniLM architecture. It maps
|
||||
sentences and paragraphs to a 384-dimensional dense vector space and is commonly used for semantic
|
||||
search, clustering, and similarity tasks.
|
||||
|
||||
- **Inputs**:
|
||||
- `input_ids`: Token IDs (shape: [batch_size, sequence_length])
|
||||
- `attention_mask`: Attention mask (shape: [batch_size, sequence_length])
|
||||
- `token_type_ids`: Token type IDs (shape: [batch_size, sequence_length])
|
||||
|
||||
- **Outputs**:
|
||||
- `last_hidden_state`: Sequence of hidden states (shape: [batch_size, sequence_length, 384])
|
||||
|
||||
## Notes
|
||||
|
||||
- The default sequence length is 128 tokens
|
||||
- The model has a hidden size of 384
|
||||
- The model uses ONNX opset 16
|
||||
- Test data is generated with random inputs for reproducibility (seed=42)
|
||||
- For sentence embeddings, you typically use mean pooling on the last_hidden_state
|
||||
@@ -1,32 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
let onnx_path = "artifacts/all-minilm-l6-v2_opset16.onnx";
|
||||
let test_data_path = "artifacts/test_data.pt";
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed={}", test_data_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!("Please run the following command to download and prepare the model:");
|
||||
eprintln!(" python get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("Or if you prefer using uv:");
|
||||
eprintln!(" uv run get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("This will download the all-MiniLM-L6-v2 model and convert it to ONNX format.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
ModelGen::new()
|
||||
.input(onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
}
|
||||
@@ -1,316 +0,0 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "onnx==1.19.0",
|
||||
# "onnxruntime>=1.22.0",
|
||||
# "huggingface-hub>=0.20.0",
|
||||
# "numpy",
|
||||
# "torch",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import os
|
||||
import sys
|
||||
import onnx
|
||||
from onnx import shape_inference, version_converter
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def download_minilm_model(output_path):
|
||||
"""Download all-MiniLM-L6-v2 model from Hugging Face."""
|
||||
print("Downloading all-MiniLM-L6-v2 model from Hugging Face...")
|
||||
|
||||
# Download the ONNX model from Hugging Face
|
||||
model_path = hf_hub_download(
|
||||
repo_id="Xenova/all-MiniLM-L6-v2",
|
||||
filename="onnx/model.onnx",
|
||||
cache_dir="./artifacts/cache",
|
||||
)
|
||||
|
||||
# Copy to artifacts
|
||||
import shutil
|
||||
|
||||
shutil.copy(model_path, output_path)
|
||||
|
||||
if not output_path.exists():
|
||||
raise FileNotFoundError(f"Failed to download ONNX file to {output_path}")
|
||||
|
||||
print(f"✓ Model downloaded to: {output_path}")
|
||||
|
||||
|
||||
def process_model(input_path, output_path, target_opset=16):
|
||||
"""Load, upgrade opset, and apply shape inference to model."""
|
||||
print(f"Loading model from {input_path}...")
|
||||
model = onnx.load(input_path)
|
||||
|
||||
# Check and upgrade opset if needed
|
||||
current_opset = model.opset_import[0].version
|
||||
if current_opset < target_opset:
|
||||
print(f"Upgrading opset from {current_opset} to {target_opset}...")
|
||||
model = version_converter.convert_version(model, target_opset)
|
||||
|
||||
# Apply shape inference
|
||||
print("Applying shape inference...")
|
||||
model = shape_inference.infer_shapes(model)
|
||||
|
||||
# Save processed model
|
||||
onnx.save(model, output_path)
|
||||
print(f"✓ Processed model saved to: {output_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_input_info(model):
|
||||
"""Extract input information from ONNX model."""
|
||||
inputs = []
|
||||
for input_info in model.graph.input:
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
# Use proper defaults for sentence transformers
|
||||
if (
|
||||
"input_ids" in input_info.name
|
||||
or "attention_mask" in input_info.name
|
||||
or "token_type_ids" in input_info.name
|
||||
):
|
||||
# Default sequence length for all text inputs
|
||||
shape.append(1 if len(shape) == 0 else 128)
|
||||
else:
|
||||
shape.append(1) # Default to 1 for other dynamic dimensions
|
||||
inputs.append(
|
||||
{
|
||||
"name": input_info.name,
|
||||
"shape": shape,
|
||||
"dtype": input_info.type.tensor_type.elem_type,
|
||||
}
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def generate_test_data(model_path, output_dir):
|
||||
"""Generate test input/output data and save as PyTorch tensors."""
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
print("\nGenerating test data...")
|
||||
|
||||
# Load model to get input shapes
|
||||
model = onnx.load(model_path)
|
||||
input_infos = get_input_info(model)
|
||||
|
||||
print(f" Model has {len(input_infos)} inputs:")
|
||||
for info in input_infos:
|
||||
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
|
||||
|
||||
# Create reproducible test inputs
|
||||
np.random.seed(42)
|
||||
test_inputs = {}
|
||||
|
||||
for info in input_infos:
|
||||
if info["dtype"] == onnx.TensorProto.INT64:
|
||||
# Handle different integer inputs appropriately
|
||||
if "attention_mask" in info["name"]:
|
||||
# Attention mask should be 1 for valid tokens, 0 for padding
|
||||
# For testing, use all 1s (all valid tokens)
|
||||
test_input = np.ones(info["shape"], dtype=np.int64)
|
||||
elif "token_type_ids" in info["name"]:
|
||||
# Token type IDs should be 0 or 1 (typically 0 for single-sequence tasks)
|
||||
test_input = np.zeros(info["shape"], dtype=np.int64)
|
||||
elif "input_ids" in info["name"]:
|
||||
# For input_ids, use random integers in vocabulary range
|
||||
test_input = np.random.randint(0, 30522, size=info["shape"], dtype=np.int64)
|
||||
else:
|
||||
# For other INT64 inputs, use random integers in vocabulary range
|
||||
test_input = np.random.randint(0, 30522, size=info["shape"], dtype=np.int64)
|
||||
else:
|
||||
# For float inputs, use random floats
|
||||
test_input = np.random.rand(*info["shape"]).astype(np.float32)
|
||||
test_inputs[info["name"]] = test_input
|
||||
|
||||
# Run inference to get output
|
||||
session = ort.InferenceSession(model_path)
|
||||
outputs = session.run(None, test_inputs)
|
||||
|
||||
# Save in a format that's easier to load in Rust
|
||||
# For all-MiniLM-L6-v2, we expect:
|
||||
# - Inputs: input_ids, attention_mask, token_type_ids
|
||||
# - Outputs: last_hidden_state (3D)
|
||||
|
||||
# Compute mean pooled embeddings (what sentence-transformers uses)
|
||||
last_hidden_state = outputs[0]
|
||||
attention_mask_np = test_inputs.get("attention_mask")
|
||||
|
||||
# Mean pooling - take attention mask into account for correct averaging
|
||||
input_mask_expanded = np.expand_dims(attention_mask_np, axis=-1).astype(np.float32)
|
||||
sum_embeddings = np.sum(last_hidden_state * input_mask_expanded, axis=1)
|
||||
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
|
||||
pooled_embeddings = sum_embeddings / sum_mask
|
||||
|
||||
# Create a more structured format for Rust
|
||||
test_data = {
|
||||
"input_ids": torch.from_numpy(
|
||||
test_inputs.get(
|
||||
"input_ids", test_inputs.get("inputs.0", list(test_inputs.values())[0])
|
||||
)
|
||||
),
|
||||
"attention_mask": torch.from_numpy(
|
||||
test_inputs.get(
|
||||
"attention_mask",
|
||||
test_inputs.get("inputs.1", list(test_inputs.values())[1]),
|
||||
)
|
||||
),
|
||||
"token_type_ids": torch.from_numpy(
|
||||
test_inputs.get(
|
||||
"token_type_ids",
|
||||
test_inputs.get("inputs.2", np.zeros((1, 128), dtype=np.int64)),
|
||||
)
|
||||
),
|
||||
"last_hidden_state": torch.from_numpy(outputs[0]),
|
||||
"pooled_embeddings": torch.from_numpy(pooled_embeddings),
|
||||
}
|
||||
|
||||
test_data_path = Path(output_dir) / "test_data.pt"
|
||||
torch.save(test_data, test_data_path)
|
||||
|
||||
print(f" ✓ Test data saved to: {test_data_path}")
|
||||
print(f" Input shapes:")
|
||||
print(f" input_ids: {test_data['input_ids'].shape}")
|
||||
print(f" attention_mask: {test_data['attention_mask'].shape}")
|
||||
print(f" token_type_ids: {test_data['token_type_ids'].shape}")
|
||||
print(f" Output shapes:")
|
||||
print(f" last_hidden_state: {test_data['last_hidden_state'].shape}")
|
||||
print(f" pooled_embeddings: {test_data['pooled_embeddings'].shape}")
|
||||
|
||||
|
||||
def save_model_info(model_path, output_dir):
|
||||
"""Save model structure information to a text file."""
|
||||
print("\nSaving model information...")
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
info_path = Path(output_dir) / "model-python.txt"
|
||||
with open(info_path, "w") as f:
|
||||
f.write("all-MiniLM-L6-v2 Model Information\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
# Input information
|
||||
f.write("Inputs:\n")
|
||||
for input_info in model.graph.input:
|
||||
f.write(f" - {input_info.name}\n")
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Output information
|
||||
f.write("\nOutputs:\n")
|
||||
for output_info in model.graph.output:
|
||||
f.write(f" - {output_info.name}\n")
|
||||
shape = []
|
||||
for dim in output_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Model statistics
|
||||
f.write(f"\nModel Statistics:\n")
|
||||
f.write(f" Opset version: {model.opset_import[0].version}\n")
|
||||
f.write(f" Number of nodes: {len(model.graph.node)}\n")
|
||||
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
|
||||
|
||||
# Node types summary
|
||||
node_types = {}
|
||||
for node in model.graph.node:
|
||||
op_type = node.op_type
|
||||
node_types[op_type] = node_types.get(op_type, 0) + 1
|
||||
|
||||
f.write(f"\nNode types:\n")
|
||||
for op_type, count in sorted(node_types.items()):
|
||||
f.write(f" {op_type}: {count}\n")
|
||||
|
||||
print(f" ✓ Model info saved to: {info_path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("all-MiniLM-L6-v2 Model Preparation Tool")
|
||||
print("=" * 60)
|
||||
|
||||
# Setup paths
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
original_path = artifacts_dir / "all-minilm-l6-v2.onnx"
|
||||
processed_path = artifacts_dir / "all-minilm-l6-v2_opset16.onnx"
|
||||
test_data_path = artifacts_dir / "test_data.pt"
|
||||
model_info_path = artifacts_dir / "model-python.txt"
|
||||
|
||||
# Check if we already have everything
|
||||
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
|
||||
print(f"\n✓ All files already exist:")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("\nNothing to do!")
|
||||
return
|
||||
|
||||
# Download model if needed
|
||||
if not original_path.exists() and not processed_path.exists():
|
||||
print("\nStep 1: Downloading all-MiniLM-L6-v2 model...")
|
||||
download_minilm_model(original_path)
|
||||
|
||||
# Process model if needed
|
||||
if not processed_path.exists():
|
||||
print("\nStep 2: Processing model...")
|
||||
process_model(original_path, processed_path, target_opset=16)
|
||||
|
||||
# Clean up original if we have the processed version
|
||||
if original_path.exists() and processed_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Generate test data if needed
|
||||
if not test_data_path.exists():
|
||||
print("\nStep 3: Generating test data...")
|
||||
generate_test_data(processed_path, artifacts_dir)
|
||||
|
||||
# Save model info if needed
|
||||
if not model_info_path.exists():
|
||||
print("\nStep 4: Saving model information...")
|
||||
save_model_info(processed_path, artifacts_dir)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ all-MiniLM-L6-v2 model preparation completed!")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠ Operation cancelled by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -1,284 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::module::{Initializer, Param};
|
||||
use burn::prelude::*;
|
||||
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Import the generated model code as a module
|
||||
pub mod all_minilm_l6_v2 {
|
||||
include!(concat!(
|
||||
env!("OUT_DIR"),
|
||||
"/model/all-minilm-l6-v2_opset16.rs"
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Debug, Module)]
|
||||
struct TestData<B: Backend> {
|
||||
input_ids: Param<Tensor<B, 2, Int>>,
|
||||
attention_mask: Param<Tensor<B, 2, Int>>,
|
||||
token_type_ids: Param<Tensor<B, 2, Int>>,
|
||||
last_hidden_state: Param<Tensor<B, 3>>,
|
||||
pooled_embeddings: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TestData<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
use burn::module::ParamId;
|
||||
// Initialize with correct shapes matching the test data
|
||||
// Note: Initializer only works for float tensors, Int tensors need manual init
|
||||
Self {
|
||||
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
|
||||
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
|
||||
token_type_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
|
||||
last_hidden_state: Initializer::Zeros.init([1, 128, 384], device),
|
||||
pooled_embeddings: Initializer::Zeros.init([1, 384], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply mean pooling to get sentence embeddings
|
||||
fn mean_pool<B: Backend>(
|
||||
last_hidden_state: Tensor<B, 3>,
|
||||
attention_mask: Tensor<B, 2, Int>,
|
||||
) -> Tensor<B, 2> {
|
||||
// Convert attention_mask to float and expand dimensions to match hidden_state
|
||||
let attention_mask_float = attention_mask.float().unsqueeze_dim::<3>(2);
|
||||
|
||||
// Multiply hidden states by attention mask
|
||||
let masked_embeddings = last_hidden_state * attention_mask_float.clone();
|
||||
|
||||
// Sum along sequence dimension (dim 1)
|
||||
let sum_embeddings = masked_embeddings.sum_dim(1);
|
||||
|
||||
// Sum attention mask to get count of non-padding tokens
|
||||
let sum_mask = attention_mask_float.sum_dim(1).clamp_min(1e-9);
|
||||
|
||||
// Divide to get mean - result is [batch, 1, hidden]
|
||||
let pooled = sum_embeddings / sum_mask;
|
||||
|
||||
// Get the shape to reshape to [batch, hidden]
|
||||
let shape = pooled.shape();
|
||||
let batch_size = shape.dims[0];
|
||||
let hidden_size = shape.dims[2];
|
||||
|
||||
pooled.reshape([batch_size, hidden_size])
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("========================================");
|
||||
println!("all-MiniLM-L6-v2 Burn Model Test");
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model and test data.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize the model (using default which includes the converted weights)
|
||||
println!("Initializing all-MiniLM-L6-v2 model...");
|
||||
let start = Instant::now();
|
||||
let device = Default::default();
|
||||
let model: all_minilm_l6_v2::Model<MyBackend> = all_minilm_l6_v2::Model::default();
|
||||
let init_time = start.elapsed();
|
||||
println!(" Model initialized in {:.2?}", init_time);
|
||||
|
||||
// Save model structure to file
|
||||
println!("\nSaving model structure to artifacts/model.txt...");
|
||||
let model_str = format!("{}", model);
|
||||
std::fs::write("artifacts/model.txt", &model_str)
|
||||
.expect("Failed to write model structure to file");
|
||||
println!(" Model structure saved");
|
||||
|
||||
// Load test data from PyTorch file
|
||||
println!("\nLoading test data from artifacts/test_data.pt...");
|
||||
let start = Instant::now();
|
||||
let mut test_data = TestData::<MyBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
|
||||
test_data.load_from(&mut store).expect("Failed to load test data");
|
||||
let load_time = start.elapsed();
|
||||
println!(" Data loaded in {:.2?}", load_time);
|
||||
|
||||
// Get the input tensors from test data
|
||||
let input_ids = test_data.input_ids.val();
|
||||
let attention_mask = test_data.attention_mask.val();
|
||||
let token_type_ids = test_data.token_type_ids.val();
|
||||
let input_ids_shape = input_ids.shape();
|
||||
let attention_mask_shape = attention_mask.shape();
|
||||
let token_type_ids_shape = token_type_ids.shape();
|
||||
println!(" Loaded input_ids with shape: {:?}", input_ids_shape.dims);
|
||||
println!(
|
||||
" Loaded attention_mask with shape: {:?}",
|
||||
attention_mask_shape.dims
|
||||
);
|
||||
println!(
|
||||
" Loaded token_type_ids with shape: {:?}",
|
||||
token_type_ids_shape.dims
|
||||
);
|
||||
|
||||
// Get the reference outputs from test data
|
||||
let reference_last_hidden_state = test_data.last_hidden_state.val();
|
||||
let reference_pooled_embeddings = test_data.pooled_embeddings.val();
|
||||
let ref_last_hidden_shape = reference_last_hidden_state.shape();
|
||||
let ref_pooled_shape = reference_pooled_embeddings.shape();
|
||||
println!(
|
||||
" Loaded reference last_hidden_state with shape: {:?}",
|
||||
ref_last_hidden_shape.dims
|
||||
);
|
||||
println!(
|
||||
" Loaded reference pooled_embeddings with shape: {:?}",
|
||||
ref_pooled_shape.dims
|
||||
);
|
||||
|
||||
// Run inference with the loaded input
|
||||
println!("\nRunning model inference with test input...");
|
||||
let start = Instant::now();
|
||||
|
||||
let last_hidden_state = model.forward(
|
||||
input_ids.clone(),
|
||||
attention_mask.clone(),
|
||||
token_type_ids.clone(),
|
||||
);
|
||||
|
||||
let inference_time = start.elapsed();
|
||||
println!(" Inference completed in {:.2?}", inference_time);
|
||||
|
||||
// Compute pooled embeddings (mean pooling)
|
||||
println!("\nComputing pooled embeddings (mean pooling)...");
|
||||
let start = Instant::now();
|
||||
let pooled_embeddings = mean_pool(last_hidden_state.clone(), attention_mask.clone());
|
||||
let pooling_time = start.elapsed();
|
||||
println!(" Pooling completed in {:.2?}", pooling_time);
|
||||
|
||||
// Display output shapes
|
||||
let last_hidden_shape = last_hidden_state.shape();
|
||||
let pooled_shape = pooled_embeddings.shape();
|
||||
println!("\n Model output shapes:");
|
||||
println!(" last_hidden_state: {:?}", last_hidden_shape.dims);
|
||||
println!(" pooled_embeddings: {:?}", pooled_shape.dims);
|
||||
|
||||
// Verify expected output shapes match
|
||||
if last_hidden_shape.dims == ref_last_hidden_shape.dims {
|
||||
println!(
|
||||
" ✓ last_hidden_state shape matches expected: {:?}",
|
||||
ref_last_hidden_shape.dims
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Warning: Expected last_hidden_state shape {:?}, got {:?}",
|
||||
ref_last_hidden_shape.dims, last_hidden_shape.dims
|
||||
);
|
||||
}
|
||||
|
||||
if pooled_shape.dims == ref_pooled_shape.dims {
|
||||
println!(
|
||||
" ✓ pooled_embeddings shape matches expected: {:?}",
|
||||
ref_pooled_shape.dims
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Warning: Expected pooled_embeddings shape {:?}, got {:?}",
|
||||
ref_pooled_shape.dims, pooled_shape.dims
|
||||
);
|
||||
}
|
||||
|
||||
// Compare outputs
|
||||
println!("\nComparing model outputs with reference data...");
|
||||
|
||||
// Check if last_hidden_state is close
|
||||
println!("\n Checking last_hidden_state:");
|
||||
if last_hidden_state.clone().all_close(
|
||||
reference_last_hidden_state.clone(),
|
||||
Some(1e-4),
|
||||
Some(1e-4),
|
||||
) {
|
||||
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ last_hidden_state differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = last_hidden_state.clone() - reference_last_hidden_state.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = last_hidden_state.clone().flatten::<1>(0, 2);
|
||||
let reference_flat = reference_last_hidden_state.clone().flatten::<1>(0, 2);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if pooled_embeddings is close
|
||||
println!("\n Checking pooled_embeddings:");
|
||||
if pooled_embeddings.clone().all_close(
|
||||
reference_pooled_embeddings.clone(),
|
||||
Some(1e-4),
|
||||
Some(1e-4),
|
||||
) {
|
||||
println!(" ✓ pooled_embeddings matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ pooled_embeddings differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = pooled_embeddings.clone() - reference_pooled_embeddings.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = pooled_embeddings.clone().flatten::<1>(0, 1);
|
||||
let reference_flat = reference_pooled_embeddings.clone().flatten::<1>(0, 1);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n========================================");
|
||||
println!("Model test completed!");
|
||||
println!("========================================");
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-clip-vit-b-32-text"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["tch"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,33 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
let onnx_path = "artifacts/clip-vit-b-32-text_opset16.onnx";
|
||||
let test_data_path = "artifacts/test_data.pt";
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed={}", test_data_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!("Please run the following command to download and prepare the model:");
|
||||
eprintln!(" python get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("Or if you prefer using uv:");
|
||||
eprintln!(" uv run get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("This will download the CLIP ViT-B-32-text model and convert it to ONNX format.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
// Use double precision to handle large Int64 constants in CLIP
|
||||
ModelGen::new()
|
||||
.input(onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
}
|
||||
@@ -1,288 +0,0 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "onnx==1.19.0",
|
||||
# "onnxruntime>=1.22.0",
|
||||
# "huggingface-hub>=0.20.0",
|
||||
# "numpy",
|
||||
# "torch",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import os
|
||||
import sys
|
||||
import onnx
|
||||
from onnx import shape_inference, version_converter
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def download_clip_model(output_path):
|
||||
"""Download CLIP ViT-B-32-text model from Hugging Face."""
|
||||
print("Downloading CLIP ViT-B-32-text model from Hugging Face...")
|
||||
|
||||
# Download the ONNX model from Hugging Face
|
||||
model_path = hf_hub_download(
|
||||
repo_id="Qdrant/clip-ViT-B-32-text",
|
||||
filename="model.onnx",
|
||||
cache_dir="./artifacts/cache",
|
||||
)
|
||||
|
||||
# Copy to artifacts
|
||||
import shutil
|
||||
|
||||
shutil.copy(model_path, output_path)
|
||||
|
||||
if not output_path.exists():
|
||||
raise FileNotFoundError(f"Failed to download ONNX file to {output_path}")
|
||||
|
||||
print(f"✓ Model downloaded to: {output_path}")
|
||||
|
||||
|
||||
def process_model(input_path, output_path, target_opset=16):
|
||||
"""Load, upgrade opset, and apply shape inference to model."""
|
||||
print(f"Loading model from {input_path}...")
|
||||
model = onnx.load(input_path)
|
||||
|
||||
# Check and upgrade opset if needed
|
||||
current_opset = model.opset_import[0].version
|
||||
if current_opset < target_opset:
|
||||
print(f"Upgrading opset from {current_opset} to {target_opset}...")
|
||||
model = version_converter.convert_version(model, target_opset)
|
||||
|
||||
# Apply shape inference
|
||||
print("Applying shape inference...")
|
||||
model = shape_inference.infer_shapes(model)
|
||||
|
||||
# Save processed model
|
||||
onnx.save(model, output_path)
|
||||
print(f"✓ Processed model saved to: {output_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_input_info(model):
|
||||
"""Extract input information from ONNX model."""
|
||||
inputs = []
|
||||
for input_info in model.graph.input:
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
# Use proper defaults for CLIP model
|
||||
if (
|
||||
"input_ids" in input_info.name
|
||||
or "attention_mask" in input_info.name
|
||||
):
|
||||
# CLIP uses sequence length of 77
|
||||
shape.append(1 if len(shape) == 0 else 77)
|
||||
else:
|
||||
shape.append(1) # Default to 1 for other dynamic dimensions
|
||||
inputs.append(
|
||||
{
|
||||
"name": input_info.name,
|
||||
"shape": shape,
|
||||
"dtype": input_info.type.tensor_type.elem_type,
|
||||
}
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def generate_test_data(model_path, output_dir):
|
||||
"""Generate test input/output data and save as PyTorch tensors."""
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
print("\nGenerating test data...")
|
||||
|
||||
# Load model to get input shapes
|
||||
model = onnx.load(model_path)
|
||||
input_infos = get_input_info(model)
|
||||
|
||||
print(f" Model has {len(input_infos)} inputs:")
|
||||
for info in input_infos:
|
||||
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
|
||||
|
||||
# Create reproducible test inputs
|
||||
np.random.seed(42)
|
||||
test_inputs = {}
|
||||
|
||||
for info in input_infos:
|
||||
if info["dtype"] == onnx.TensorProto.INT64:
|
||||
# For INT64 inputs (like input_ids), use random integers
|
||||
test_input = np.random.randint(0, 1000, size=info["shape"], dtype=np.int64)
|
||||
else:
|
||||
# For float inputs, use random floats
|
||||
test_input = np.random.rand(*info["shape"]).astype(np.float32)
|
||||
test_inputs[info["name"]] = test_input
|
||||
|
||||
# Run inference to get output
|
||||
session = ort.InferenceSession(model_path)
|
||||
outputs = session.run(None, test_inputs)
|
||||
|
||||
# Save in a format that's easier to load in Rust
|
||||
# For CLIP, we expect:
|
||||
# - Inputs: input_ids, attention_mask
|
||||
# - Outputs: text_embeds (2D), last_hidden_state (3D)
|
||||
|
||||
# Create a more structured format for Rust
|
||||
test_data = {
|
||||
"input_ids": torch.from_numpy(
|
||||
test_inputs.get(
|
||||
"input_ids", test_inputs.get("inputs.0", list(test_inputs.values())[0])
|
||||
)
|
||||
),
|
||||
"attention_mask": torch.from_numpy(
|
||||
test_inputs.get(
|
||||
"attention_mask",
|
||||
test_inputs.get("inputs.1", list(test_inputs.values())[1]),
|
||||
)
|
||||
),
|
||||
"text_embeds": torch.from_numpy(outputs[0]),
|
||||
"last_hidden_state": torch.from_numpy(outputs[1])
|
||||
if len(outputs) > 1
|
||||
else torch.zeros(1, 77, 512),
|
||||
}
|
||||
|
||||
test_data_path = Path(output_dir) / "test_data.pt"
|
||||
torch.save(test_data, test_data_path)
|
||||
|
||||
print(f" ✓ Test data saved to: {test_data_path}")
|
||||
print(f" Input shapes:")
|
||||
print(f" input_ids: {test_data['input_ids'].shape}")
|
||||
print(f" attention_mask: {test_data['attention_mask'].shape}")
|
||||
print(f" Output shapes:")
|
||||
print(f" text_embeds: {test_data['text_embeds'].shape}")
|
||||
print(f" last_hidden_state: {test_data['last_hidden_state'].shape}")
|
||||
|
||||
|
||||
def save_model_info(model_path, output_dir):
|
||||
"""Save model structure information to a text file."""
|
||||
print("\nSaving model information...")
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
info_path = Path(output_dir) / "model-python.txt"
|
||||
with open(info_path, "w") as f:
|
||||
f.write("CLIP ViT-B-32-text Model Information\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
# Input information
|
||||
f.write("Inputs:\n")
|
||||
for input_info in model.graph.input:
|
||||
f.write(f" - {input_info.name}\n")
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Output information
|
||||
f.write("\nOutputs:\n")
|
||||
for output_info in model.graph.output:
|
||||
f.write(f" - {output_info.name}\n")
|
||||
shape = []
|
||||
for dim in output_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Model statistics
|
||||
f.write(f"\nModel Statistics:\n")
|
||||
f.write(f" Opset version: {model.opset_import[0].version}\n")
|
||||
f.write(f" Number of nodes: {len(model.graph.node)}\n")
|
||||
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
|
||||
|
||||
# Node types summary
|
||||
node_types = {}
|
||||
for node in model.graph.node:
|
||||
op_type = node.op_type
|
||||
node_types[op_type] = node_types.get(op_type, 0) + 1
|
||||
|
||||
f.write(f"\nNode types:\n")
|
||||
for op_type, count in sorted(node_types.items()):
|
||||
f.write(f" {op_type}: {count}\n")
|
||||
|
||||
print(f" ✓ Model info saved to: {info_path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("CLIP ViT-B-32-text Model Preparation Tool")
|
||||
print("=" * 60)
|
||||
|
||||
# Setup paths
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
original_path = artifacts_dir / "clip-vit-b-32-text.onnx"
|
||||
processed_path = artifacts_dir / "clip-vit-b-32-text_opset16.onnx"
|
||||
test_data_path = artifacts_dir / "test_data.pt"
|
||||
model_info_path = artifacts_dir / "model-python.txt"
|
||||
|
||||
# Check if we already have everything
|
||||
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
|
||||
print(f"\n✓ All files already exist:")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("\nNothing to do!")
|
||||
return
|
||||
|
||||
# Download model if needed
|
||||
if not original_path.exists() and not processed_path.exists():
|
||||
print("\nStep 1: Downloading CLIP model...")
|
||||
download_clip_model(original_path)
|
||||
|
||||
# Process model if needed
|
||||
if not processed_path.exists():
|
||||
print("\nStep 2: Processing model...")
|
||||
process_model(original_path, processed_path, target_opset=16)
|
||||
|
||||
# Clean up original if we have the processed version
|
||||
if original_path.exists() and processed_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Generate test data if needed
|
||||
if not test_data_path.exists():
|
||||
print("\nStep 3: Generating test data...")
|
||||
generate_test_data(processed_path, artifacts_dir)
|
||||
|
||||
# Save model info if needed
|
||||
if not model_info_path.exists():
|
||||
print("\nStep 4: Saving model information...")
|
||||
save_model_info(processed_path, artifacts_dir)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ CLIP model preparation completed!")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠ Operation cancelled by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -1,236 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::module::{Initializer, Param};
|
||||
use burn::prelude::*;
|
||||
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Import the generated model code as a module
|
||||
pub mod clip_vit_b_32_text {
|
||||
include!(concat!(
|
||||
env!("OUT_DIR"),
|
||||
"/model/clip-vit-b-32-text_opset16.rs"
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Debug, Module)]
|
||||
struct TestData<B: Backend> {
|
||||
input_ids: Param<Tensor<B, 2, Int>>,
|
||||
attention_mask: Param<Tensor<B, 2, Int>>,
|
||||
text_embeds: Param<Tensor<B, 2>>,
|
||||
last_hidden_state: Param<Tensor<B, 3>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TestData<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
use burn::module::ParamId;
|
||||
// CLIP ViT-B-32 text: sequence_length=77, embed_dim=512
|
||||
// Note: Initializer only works for float tensors, Int tensors need manual init
|
||||
Self {
|
||||
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 77], device)),
|
||||
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 77], device)),
|
||||
text_embeds: Initializer::Zeros.init([1, 512], device),
|
||||
last_hidden_state: Initializer::Zeros.init([1, 77, 512], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("========================================");
|
||||
println!("CLIP ViT-B-32-text Burn Model Test");
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model and test data.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize the model (using default which includes the converted weights)
|
||||
println!("Initializing CLIP model...");
|
||||
let start = Instant::now();
|
||||
let device = Default::default();
|
||||
let model: clip_vit_b_32_text::Model<MyBackend> = clip_vit_b_32_text::Model::default();
|
||||
let init_time = start.elapsed();
|
||||
println!(" Model initialized in {:.2?}", init_time);
|
||||
|
||||
// Save model structure to file
|
||||
println!("\nSaving model structure to artifacts/model.txt...");
|
||||
let model_str = format!("{}", model);
|
||||
std::fs::write("artifacts/model.txt", &model_str)
|
||||
.expect("Failed to write model structure to file");
|
||||
println!(" Model structure saved");
|
||||
|
||||
// Load test data from PyTorch file
|
||||
println!("\nLoading test data from artifacts/test_data.pt...");
|
||||
let start = Instant::now();
|
||||
let mut test_data = TestData::<MyBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
|
||||
test_data.load_from(&mut store).expect("Failed to load test data");
|
||||
let load_time = start.elapsed();
|
||||
println!(" Data loaded in {:.2?}", load_time);
|
||||
|
||||
// Get the input tensors from test data
|
||||
let input_ids = test_data.input_ids.val();
|
||||
let attention_mask = test_data.attention_mask.val();
|
||||
let input_ids_shape = input_ids.shape();
|
||||
let attention_mask_shape = attention_mask.shape();
|
||||
println!(" Loaded input_ids with shape: {:?}", input_ids_shape.dims);
|
||||
println!(
|
||||
" Loaded attention_mask with shape: {:?}",
|
||||
attention_mask_shape.dims
|
||||
);
|
||||
|
||||
// Get the reference outputs from test data
|
||||
let reference_text_embeds = test_data.text_embeds.val();
|
||||
let reference_last_hidden_state = test_data.last_hidden_state.val();
|
||||
let ref_text_embeds_shape = reference_text_embeds.shape();
|
||||
let ref_last_hidden_shape = reference_last_hidden_state.shape();
|
||||
println!(
|
||||
" Loaded reference text_embeds with shape: {:?}",
|
||||
ref_text_embeds_shape.dims
|
||||
);
|
||||
println!(
|
||||
" Loaded reference last_hidden_state with shape: {:?}",
|
||||
ref_last_hidden_shape.dims
|
||||
);
|
||||
|
||||
// Run inference with the loaded input
|
||||
println!("\nRunning model inference with test input...");
|
||||
let start = Instant::now();
|
||||
|
||||
let (text_embeds, last_hidden_state) = model.forward(input_ids, attention_mask);
|
||||
|
||||
let inference_time = start.elapsed();
|
||||
println!(" Inference completed in {:.2?}", inference_time);
|
||||
|
||||
// Display output shapes
|
||||
let text_embeds_shape = text_embeds.shape();
|
||||
let last_hidden_shape = last_hidden_state.shape();
|
||||
println!("\n Model output shapes:");
|
||||
println!(" text_embeds: {:?}", text_embeds_shape.dims);
|
||||
println!(" last_hidden_state: {:?}", last_hidden_shape.dims);
|
||||
|
||||
// Verify expected output shapes match
|
||||
if text_embeds_shape.dims == ref_text_embeds_shape.dims {
|
||||
println!(
|
||||
" ✓ text_embeds shape matches expected: {:?}",
|
||||
ref_text_embeds_shape.dims
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Warning: Expected text_embeds shape {:?}, got {:?}",
|
||||
ref_text_embeds_shape.dims, text_embeds_shape.dims
|
||||
);
|
||||
}
|
||||
|
||||
if last_hidden_shape.dims == ref_last_hidden_shape.dims {
|
||||
println!(
|
||||
" ✓ last_hidden_state shape matches expected: {:?}",
|
||||
ref_last_hidden_shape.dims
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Warning: Expected last_hidden_state shape {:?}, got {:?}",
|
||||
ref_last_hidden_shape.dims, last_hidden_shape.dims
|
||||
);
|
||||
}
|
||||
|
||||
// Compare outputs
|
||||
println!("\nComparing model outputs with reference data...");
|
||||
|
||||
// Check if text_embeds are close
|
||||
println!("\n Checking text_embeds:");
|
||||
if text_embeds
|
||||
.clone()
|
||||
.all_close(reference_text_embeds.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ text_embeds matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ text_embeds differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = text_embeds.clone() - reference_text_embeds.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = text_embeds.clone().flatten::<1>(0, 1);
|
||||
let reference_flat = reference_text_embeds.clone().flatten::<1>(0, 1);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if last_hidden_state is close
|
||||
println!("\n Checking last_hidden_state:");
|
||||
if last_hidden_state.clone().all_close(
|
||||
reference_last_hidden_state.clone(),
|
||||
Some(1e-4),
|
||||
Some(1e-4),
|
||||
) {
|
||||
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ last_hidden_state differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = last_hidden_state.clone() - reference_last_hidden_state.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = last_hidden_state.clone().flatten::<1>(0, 2);
|
||||
let reference_flat = reference_last_hidden_state.clone().flatten::<1>(0, 2);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n========================================");
|
||||
println!("Model test completed!");
|
||||
println!("========================================");
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-clip-vit-b-32-vision"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["tch"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,35 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
let onnx_path = "artifacts/clip-vit-b-32-vision_opset16.onnx";
|
||||
let test_data_path = "artifacts/test_data.pt";
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed={}", test_data_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!("Please run the following command to download and prepare the model:");
|
||||
eprintln!(" python get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("Or if you prefer using uv:");
|
||||
eprintln!(" uv run get_model.py");
|
||||
eprintln!();
|
||||
eprintln!(
|
||||
"This will download the CLIP ViT-B-32-vision model and convert it to ONNX format."
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
// Use double precision to handle large Int64 constants in CLIP
|
||||
ModelGen::new()
|
||||
.input(onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
}
|
||||
@@ -1,279 +0,0 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "onnx==1.19.0",
|
||||
# "onnxruntime>=1.22.0",
|
||||
# "huggingface-hub>=0.20.0",
|
||||
# "numpy",
|
||||
# "torch",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import os
|
||||
import sys
|
||||
import onnx
|
||||
from onnx import shape_inference, version_converter
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def download_clip_model(output_path):
|
||||
"""Download CLIP ViT-B-32-vision model from Hugging Face."""
|
||||
print("Downloading CLIP ViT-B-32-vision model from Hugging Face...")
|
||||
|
||||
# Download the ONNX model from Hugging Face
|
||||
model_path = hf_hub_download(
|
||||
repo_id="Qdrant/clip-ViT-B-32-vision",
|
||||
filename="model.onnx",
|
||||
cache_dir="./artifacts/cache",
|
||||
)
|
||||
|
||||
# Copy to artifacts
|
||||
import shutil
|
||||
|
||||
shutil.copy(model_path, output_path)
|
||||
|
||||
if not output_path.exists():
|
||||
raise FileNotFoundError(f"Failed to download ONNX file to {output_path}")
|
||||
|
||||
print(f"✓ Model downloaded to: {output_path}")
|
||||
|
||||
|
||||
def process_model(input_path, output_path, target_opset=16):
|
||||
"""Load, upgrade opset, and apply shape inference to model."""
|
||||
print(f"Loading model from {input_path}...")
|
||||
model = onnx.load(input_path)
|
||||
|
||||
# Check and upgrade opset if needed
|
||||
current_opset = model.opset_import[0].version
|
||||
if current_opset < target_opset:
|
||||
print(f"Upgrading opset from {current_opset} to {target_opset}...")
|
||||
model = version_converter.convert_version(model, target_opset)
|
||||
|
||||
# Apply shape inference
|
||||
print("Applying shape inference...")
|
||||
model = shape_inference.infer_shapes(model)
|
||||
|
||||
# Save processed model
|
||||
onnx.save(model, output_path)
|
||||
print(f"✓ Processed model saved to: {output_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_input_info(model):
|
||||
"""Extract input information from ONNX model."""
|
||||
inputs = []
|
||||
for input_info in model.graph.input:
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
# Use proper defaults for CLIP vision model
|
||||
if "pixel_values" in input_info.name:
|
||||
# CLIP vision uses [batch, channels, height, width]
|
||||
if len(shape) == 0:
|
||||
shape.append(1) # batch
|
||||
elif len(shape) == 1:
|
||||
shape.append(3) # channels
|
||||
elif len(shape) == 2:
|
||||
shape.append(224) # height
|
||||
elif len(shape) == 3:
|
||||
shape.append(224) # width
|
||||
else:
|
||||
shape.append(1) # Default to 1 for other dynamic dimensions
|
||||
inputs.append(
|
||||
{
|
||||
"name": input_info.name,
|
||||
"shape": shape,
|
||||
"dtype": input_info.type.tensor_type.elem_type,
|
||||
}
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def generate_test_data(model_path, output_dir):
|
||||
"""Generate test input/output data and save as PyTorch tensors."""
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
print("\nGenerating test data...")
|
||||
|
||||
# Load model to get input shapes
|
||||
model = onnx.load(model_path)
|
||||
input_infos = get_input_info(model)
|
||||
|
||||
print(f" Model has {len(input_infos)} inputs:")
|
||||
for info in input_infos:
|
||||
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
|
||||
|
||||
# Create reproducible test inputs
|
||||
np.random.seed(42)
|
||||
test_inputs = {}
|
||||
|
||||
for info in input_infos:
|
||||
if info["dtype"] == onnx.TensorProto.INT64:
|
||||
# For INT64 inputs, use random integers
|
||||
test_input = np.random.randint(0, 1000, size=info["shape"], dtype=np.int64)
|
||||
else:
|
||||
# For float inputs (like pixel_values), use random floats
|
||||
test_input = np.random.rand(*info["shape"]).astype(np.float32)
|
||||
test_inputs[info["name"]] = test_input
|
||||
|
||||
# Run inference to get output
|
||||
session = ort.InferenceSession(model_path)
|
||||
outputs = session.run(None, test_inputs)
|
||||
|
||||
# Save in a format that's easier to load in Rust
|
||||
# For CLIP vision, we expect:
|
||||
# - Inputs: pixel_values
|
||||
# - Outputs: image_embeds (2D)
|
||||
|
||||
# Create a more structured format for Rust
|
||||
test_data = {
|
||||
"pixel_values": torch.from_numpy(
|
||||
test_inputs.get("pixel_values", list(test_inputs.values())[0])
|
||||
),
|
||||
"image_embeds": torch.from_numpy(outputs[0]),
|
||||
}
|
||||
|
||||
test_data_path = Path(output_dir) / "test_data.pt"
|
||||
torch.save(test_data, test_data_path)
|
||||
|
||||
print(f" ✓ Test data saved to: {test_data_path}")
|
||||
print(f" Input shapes:")
|
||||
print(f" pixel_values: {test_data['pixel_values'].shape}")
|
||||
print(f" Output shapes:")
|
||||
print(f" image_embeds: {test_data['image_embeds'].shape}")
|
||||
|
||||
|
||||
def save_model_info(model_path, output_dir):
|
||||
"""Save model structure information to a text file."""
|
||||
print("\nSaving model information...")
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
info_path = Path(output_dir) / "model-python.txt"
|
||||
with open(info_path, "w") as f:
|
||||
f.write("CLIP ViT-B-32-vision Model Information\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
# Input information
|
||||
f.write("Inputs:\n")
|
||||
for input_info in model.graph.input:
|
||||
f.write(f" - {input_info.name}\n")
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Output information
|
||||
f.write("\nOutputs:\n")
|
||||
for output_info in model.graph.output:
|
||||
f.write(f" - {output_info.name}\n")
|
||||
shape = []
|
||||
for dim in output_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Model statistics
|
||||
f.write(f"\nModel Statistics:\n")
|
||||
f.write(f" Opset version: {model.opset_import[0].version}\n")
|
||||
f.write(f" Number of nodes: {len(model.graph.node)}\n")
|
||||
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
|
||||
|
||||
# Node types summary
|
||||
node_types = {}
|
||||
for node in model.graph.node:
|
||||
op_type = node.op_type
|
||||
node_types[op_type] = node_types.get(op_type, 0) + 1
|
||||
|
||||
f.write(f"\nNode types:\n")
|
||||
for op_type, count in sorted(node_types.items()):
|
||||
f.write(f" {op_type}: {count}\n")
|
||||
|
||||
print(f" ✓ Model info saved to: {info_path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("CLIP ViT-B-32-vision Model Preparation Tool")
|
||||
print("=" * 60)
|
||||
|
||||
# Setup paths
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
original_path = artifacts_dir / "clip-vit-b-32-vision.onnx"
|
||||
processed_path = artifacts_dir / "clip-vit-b-32-vision_opset16.onnx"
|
||||
test_data_path = artifacts_dir / "test_data.pt"
|
||||
model_info_path = artifacts_dir / "model-python.txt"
|
||||
|
||||
# Check if we already have everything
|
||||
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
|
||||
print(f"\n✓ All files already exist:")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("\nNothing to do!")
|
||||
return
|
||||
|
||||
# Download model if needed
|
||||
if not original_path.exists() and not processed_path.exists():
|
||||
print("\nStep 1: Downloading CLIP model...")
|
||||
download_clip_model(original_path)
|
||||
|
||||
# Process model if needed
|
||||
if not processed_path.exists():
|
||||
print("\nStep 2: Processing model...")
|
||||
process_model(original_path, processed_path, target_opset=16)
|
||||
|
||||
# Clean up original if we have the processed version
|
||||
if original_path.exists() and processed_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Generate test data if needed
|
||||
if not test_data_path.exists():
|
||||
print("\nStep 3: Generating test data...")
|
||||
generate_test_data(processed_path, artifacts_dir)
|
||||
|
||||
# Save model info if needed
|
||||
if not model_info_path.exists():
|
||||
print("\nStep 4: Saving model information...")
|
||||
save_model_info(processed_path, artifacts_dir)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ CLIP model preparation completed!")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠ Operation cancelled by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -1,176 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::module::{Initializer, Param};
|
||||
use burn::prelude::*;
|
||||
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Import the generated model code as a module
|
||||
pub mod clip_vit_b_32_vision {
|
||||
include!(concat!(
|
||||
env!("OUT_DIR"),
|
||||
"/model/clip-vit-b-32-vision_opset16.rs"
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Debug, Module)]
|
||||
struct TestData<B: Backend> {
|
||||
pixel_values: Param<Tensor<B, 4>>,
|
||||
image_embeds: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TestData<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
// CLIP ViT-B-32 vision: image_size=224, embed_dim=512
|
||||
Self {
|
||||
pixel_values: Initializer::Zeros.init([1, 3, 224, 224], device),
|
||||
image_embeds: Initializer::Zeros.init([1, 512], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("========================================");
|
||||
println!("CLIP ViT-B-32-vision Burn Model Test");
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model and test data.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize the model (using default which includes the converted weights)
|
||||
println!("Initializing CLIP vision model...");
|
||||
let start = Instant::now();
|
||||
let device = Default::default();
|
||||
let model: clip_vit_b_32_vision::Model<MyBackend> = clip_vit_b_32_vision::Model::default();
|
||||
let init_time = start.elapsed();
|
||||
println!(" Model initialized in {:.2?}", init_time);
|
||||
|
||||
// Save model structure to file
|
||||
println!("\nSaving model structure to artifacts/model.txt...");
|
||||
let model_str = format!("{}", model);
|
||||
std::fs::write("artifacts/model.txt", &model_str)
|
||||
.expect("Failed to write model structure to file");
|
||||
println!(" Model structure saved");
|
||||
|
||||
// Load test data from PyTorch file
|
||||
println!("\nLoading test data from artifacts/test_data.pt...");
|
||||
let start = Instant::now();
|
||||
let mut test_data = TestData::<MyBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
|
||||
test_data.load_from(&mut store).expect("Failed to load test data");
|
||||
let load_time = start.elapsed();
|
||||
println!(" Data loaded in {:.2?}", load_time);
|
||||
|
||||
// Get the input tensors from test data
|
||||
let pixel_values = test_data.pixel_values.val();
|
||||
let pixel_values_shape = pixel_values.shape();
|
||||
println!(
|
||||
" Loaded pixel_values with shape: {:?}",
|
||||
pixel_values_shape.dims
|
||||
);
|
||||
|
||||
// Get the reference outputs from test data
|
||||
let reference_image_embeds = test_data.image_embeds.val();
|
||||
let ref_image_embeds_shape = reference_image_embeds.shape();
|
||||
println!(
|
||||
" Loaded reference image_embeds with shape: {:?}",
|
||||
ref_image_embeds_shape.dims
|
||||
);
|
||||
|
||||
// Run inference with the loaded input
|
||||
println!("\nRunning model inference with test input...");
|
||||
let start = Instant::now();
|
||||
|
||||
let image_embeds = model.forward(pixel_values);
|
||||
|
||||
let inference_time = start.elapsed();
|
||||
println!(" Inference completed in {:.2?}", inference_time);
|
||||
|
||||
// Display output shapes
|
||||
let image_embeds_shape = image_embeds.shape();
|
||||
println!("\n Model output shapes:");
|
||||
println!(" image_embeds: {:?}", image_embeds_shape.dims);
|
||||
|
||||
// Verify expected output shapes match
|
||||
if image_embeds_shape.dims == ref_image_embeds_shape.dims {
|
||||
println!(
|
||||
" ✓ image_embeds shape matches expected: {:?}",
|
||||
ref_image_embeds_shape.dims
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Warning: Expected image_embeds shape {:?}, got {:?}",
|
||||
ref_image_embeds_shape.dims, image_embeds_shape.dims
|
||||
);
|
||||
}
|
||||
|
||||
// Compare outputs
|
||||
println!("\nComparing model outputs with reference data...");
|
||||
|
||||
// Check if image_embeds are close
|
||||
println!("\n Checking image_embeds:");
|
||||
if image_embeds
|
||||
.clone()
|
||||
.all_close(reference_image_embeds.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ image_embeds matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ image_embeds differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = image_embeds.clone() - reference_image_embeds.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = image_embeds.clone().flatten::<1>(0, 1);
|
||||
let reference_flat = reference_image_embeds.clone().flatten::<1>(0, 1);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
println!("\n========================================");
|
||||
println!("Summary:");
|
||||
println!(" - Model initialization: {:.2?}", init_time);
|
||||
println!(" - Data loading: {:.2?}", load_time);
|
||||
println!(" - Inference time: {:.2?}", inference_time);
|
||||
println!(" - Output validation: ✓ Passed");
|
||||
println!("========================================");
|
||||
println!("Model test completed successfully!");
|
||||
println!("========================================");
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-modernbert-base"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["tch"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,72 +0,0 @@
|
||||
# ModernBERT-base Model Check
|
||||
|
||||
This crate provides testing for the ModernBERT-base model with Burn.
|
||||
|
||||
## Model
|
||||
|
||||
- `ModernBERT-base` - Modern BERT variant from Answer.AI
|
||||
(https://huggingface.co/answerdotai/ModernBERT-base)
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Download and prepare the model
|
||||
|
||||
```bash
|
||||
# Using Python directly
|
||||
python get_model.py
|
||||
|
||||
# Or using uv
|
||||
uv run get_model.py
|
||||
```
|
||||
|
||||
**Note:** This will download the model from HuggingFace and export it to ONNX format using PyTorch.
|
||||
Make sure you have the `transformers` library installed.
|
||||
|
||||
### 2. Build and run the model test
|
||||
|
||||
```bash
|
||||
# Build the model
|
||||
cargo build
|
||||
|
||||
# Run the test
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
modernbert-base/
|
||||
├── artifacts/ # Downloaded ONNX model and test data
|
||||
│ ├── modernbert-base_opset16.onnx
|
||||
│ ├── test_data.pt
|
||||
│ └── model-python.txt
|
||||
├── src/
|
||||
│ └── main.rs # Test runner
|
||||
├── build.rs # Build script that generates model code
|
||||
├── get_model.py # Model download and ONNX export script
|
||||
├── Cargo.toml
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Model Architecture
|
||||
|
||||
ModernBERT is a modern variant of BERT designed by Answer.AI with improved efficiency and
|
||||
performance. It features several architectural improvements over the original BERT, including better
|
||||
positional embeddings and optimized attention mechanisms.
|
||||
|
||||
- **Inputs**:
|
||||
- `input_ids`: Token IDs (shape: [batch_size, sequence_length])
|
||||
- `attention_mask`: Attention mask (shape: [batch_size, sequence_length])
|
||||
|
||||
- **Outputs**:
|
||||
- `last_hidden_state`: Sequence of hidden states (shape: [batch_size, sequence_length, 768])
|
||||
- `pooled_output`: Mean-pooled sentence embeddings (computed, shape: [batch_size, 768])
|
||||
|
||||
## Notes
|
||||
|
||||
- The default sequence length is 512 tokens
|
||||
- The model has a hidden size of 768
|
||||
- The model uses ONNX opset 16
|
||||
- Test data is generated with random inputs for reproducibility (seed=42)
|
||||
- Vocabulary size: 50,368 tokens
|
||||
- The pooled output is computed using mean pooling over the last hidden state
|
||||
@@ -1,32 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
let onnx_path = "artifacts/modernbert-base_opset16.onnx";
|
||||
let test_data_path = "artifacts/test_data.pt";
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed={}", test_data_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!("Please run the following command to download and prepare the model:");
|
||||
eprintln!(" python get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("Or if you prefer using uv:");
|
||||
eprintln!(" uv run get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("This will download the ModernBERT-base model and convert it to ONNX format.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
ModelGen::new()
|
||||
.input(onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
}
|
||||
@@ -1,335 +0,0 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "onnx==1.19.0",
|
||||
# "onnxruntime>=1.22.0",
|
||||
# "huggingface-hub>=0.20.0",
|
||||
# "numpy",
|
||||
# "torch",
|
||||
# "transformers>=4.46.0",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import os
|
||||
import sys
|
||||
import onnx
|
||||
from onnx import shape_inference, version_converter
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def download_modernbert_model(output_path):
|
||||
"""Download ModernBERT-base model from Hugging Face and export to ONNX."""
|
||||
print("Downloading ModernBERT-base model from Hugging Face...")
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
# Download the model
|
||||
model_name = "answerdotai/ModernBERT-base"
|
||||
print(f" Loading {model_name}...")
|
||||
model = AutoModel.from_pretrained(model_name, cache_dir="./artifacts/cache")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./artifacts/cache")
|
||||
|
||||
# Set model to evaluation mode
|
||||
model.eval()
|
||||
|
||||
# Create dummy input
|
||||
dummy_text = "This is a test sentence."
|
||||
inputs = tokenizer(
|
||||
dummy_text, return_tensors="pt", padding="max_length", max_length=512
|
||||
)
|
||||
|
||||
# Export to ONNX
|
||||
print(f" Exporting model to ONNX...")
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(inputs["input_ids"], inputs["attention_mask"]),
|
||||
str(output_path),
|
||||
input_names=["input_ids", "attention_mask"],
|
||||
output_names=["last_hidden_state"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch", 1: "sequence"},
|
||||
"attention_mask": {0: "batch", 1: "sequence"},
|
||||
"last_hidden_state": {0: "batch", 1: "sequence"},
|
||||
},
|
||||
opset_version=16,
|
||||
)
|
||||
|
||||
print(f"✓ Model exported to ONNX: {output_path}")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"\n✗ Error: Missing required package.")
|
||||
print(f" {str(e)}")
|
||||
print("\nPlease install transformers:")
|
||||
print(" uv pip install transformers")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def process_model(input_path, output_path, target_opset=16):
|
||||
"""Load, upgrade opset, and apply shape inference to model."""
|
||||
print(f"Loading model from {input_path}...")
|
||||
model = onnx.load(input_path)
|
||||
|
||||
# Check and upgrade opset if needed
|
||||
current_opset = model.opset_import[0].version
|
||||
if current_opset < target_opset:
|
||||
print(f"Upgrading opset from {current_opset} to {target_opset}...")
|
||||
model = version_converter.convert_version(model, target_opset)
|
||||
|
||||
# Apply shape inference
|
||||
print("Applying shape inference...")
|
||||
model = shape_inference.infer_shapes(model)
|
||||
|
||||
# Save processed model
|
||||
onnx.save(model, output_path)
|
||||
print(f"✓ Processed model saved to: {output_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_input_info(model):
|
||||
"""Extract input information from ONNX model."""
|
||||
inputs = []
|
||||
for input_info in model.graph.input:
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
# Use proper defaults for ModernBERT
|
||||
if (
|
||||
"input_ids" in input_info.name
|
||||
or "attention_mask" in input_info.name
|
||||
):
|
||||
# Default sequence length
|
||||
shape.append(1 if len(shape) == 0 else 512)
|
||||
else:
|
||||
shape.append(1) # Default to 1 for other dynamic dimensions
|
||||
inputs.append(
|
||||
{
|
||||
"name": input_info.name,
|
||||
"shape": shape,
|
||||
"dtype": input_info.type.tensor_type.elem_type,
|
||||
}
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
def generate_test_data(model_path, output_dir):
|
||||
"""Generate test input/output data and save as PyTorch tensors."""
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
print("\nGenerating test data...")
|
||||
|
||||
# Load model to get input shapes
|
||||
model = onnx.load(model_path)
|
||||
input_infos = get_input_info(model)
|
||||
|
||||
print(f" Model has {len(input_infos)} inputs:")
|
||||
for info in input_infos:
|
||||
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
|
||||
|
||||
# Create reproducible test inputs
|
||||
np.random.seed(42)
|
||||
test_inputs = {}
|
||||
|
||||
for info in input_infos:
|
||||
if info["dtype"] == onnx.TensorProto.INT64:
|
||||
# Handle different integer inputs appropriately
|
||||
if "attention_mask" in info["name"]:
|
||||
# Attention mask should be 1 for valid tokens, 0 for padding
|
||||
# For testing, use all 1s (all valid tokens)
|
||||
test_input = np.ones(info["shape"], dtype=np.int64)
|
||||
elif "input_ids" in info["name"]:
|
||||
# For input_ids, use random integers in vocabulary range
|
||||
# ModernBERT uses a 50368 vocabulary
|
||||
test_input = np.random.randint(0, 50368, size=info["shape"], dtype=np.int64)
|
||||
else:
|
||||
# For other INT64 inputs, use random integers
|
||||
test_input = np.random.randint(0, 50368, size=info["shape"], dtype=np.int64)
|
||||
else:
|
||||
# For float inputs, use random floats
|
||||
test_input = np.random.rand(*info["shape"]).astype(np.float32)
|
||||
test_inputs[info["name"]] = test_input
|
||||
|
||||
# Run inference to get output
|
||||
session = ort.InferenceSession(model_path)
|
||||
outputs = session.run(None, test_inputs)
|
||||
|
||||
# Save in a format that's easier to load in Rust
|
||||
# For ModernBERT, we expect:
|
||||
# - Inputs: input_ids, attention_mask
|
||||
# - Outputs: last_hidden_state (3D)
|
||||
|
||||
# Compute mean pooled output
|
||||
last_hidden_state = outputs[0]
|
||||
attention_mask_np = test_inputs.get("attention_mask")
|
||||
|
||||
# Mean pooling - take attention mask into account for correct averaging
|
||||
input_mask_expanded = np.expand_dims(attention_mask_np, axis=-1).astype(np.float32)
|
||||
sum_embeddings = np.sum(last_hidden_state * input_mask_expanded, axis=1)
|
||||
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
|
||||
pooled_output = sum_embeddings / sum_mask
|
||||
|
||||
# Create a more structured format for Rust
|
||||
test_data = {
|
||||
"input_ids": torch.from_numpy(
|
||||
test_inputs.get(
|
||||
"input_ids", test_inputs.get("inputs.0", list(test_inputs.values())[0])
|
||||
)
|
||||
),
|
||||
"attention_mask": torch.from_numpy(
|
||||
test_inputs.get(
|
||||
"attention_mask",
|
||||
test_inputs.get("inputs.1", list(test_inputs.values())[1]),
|
||||
)
|
||||
),
|
||||
"last_hidden_state": torch.from_numpy(outputs[0]),
|
||||
"pooled_output": torch.from_numpy(pooled_output),
|
||||
}
|
||||
|
||||
test_data_path = Path(output_dir) / "test_data.pt"
|
||||
torch.save(test_data, test_data_path)
|
||||
|
||||
print(f" ✓ Test data saved to: {test_data_path}")
|
||||
print(f" Input shapes:")
|
||||
print(f" input_ids: {test_data['input_ids'].shape}")
|
||||
print(f" attention_mask: {test_data['attention_mask'].shape}")
|
||||
print(f" Output shapes:")
|
||||
print(f" last_hidden_state: {test_data['last_hidden_state'].shape}")
|
||||
print(f" pooled_output: {test_data['pooled_output'].shape}")
|
||||
|
||||
|
||||
def save_model_info(model_path, output_dir):
|
||||
"""Save model structure information to a text file."""
|
||||
print("\nSaving model information...")
|
||||
|
||||
model = onnx.load(model_path)
|
||||
|
||||
info_path = Path(output_dir) / "model-python.txt"
|
||||
with open(info_path, "w") as f:
|
||||
f.write("ModernBERT-base Model Information\n")
|
||||
f.write("=" * 60 + "\n\n")
|
||||
|
||||
# Input information
|
||||
f.write("Inputs:\n")
|
||||
for input_info in model.graph.input:
|
||||
f.write(f" - {input_info.name}\n")
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Output information
|
||||
f.write("\nOutputs:\n")
|
||||
for output_info in model.graph.output:
|
||||
f.write(f" - {output_info.name}\n")
|
||||
shape = []
|
||||
for dim in output_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append("dynamic")
|
||||
f.write(f" Shape: {shape}\n")
|
||||
f.write(
|
||||
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
|
||||
)
|
||||
|
||||
# Model statistics
|
||||
f.write(f"\nModel Statistics:\n")
|
||||
f.write(f" Opset version: {model.opset_import[0].version}\n")
|
||||
f.write(f" Number of nodes: {len(model.graph.node)}\n")
|
||||
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
|
||||
|
||||
# Node types summary
|
||||
node_types = {}
|
||||
for node in model.graph.node:
|
||||
op_type = node.op_type
|
||||
node_types[op_type] = node_types.get(op_type, 0) + 1
|
||||
|
||||
f.write(f"\nNode types:\n")
|
||||
for op_type, count in sorted(node_types.items()):
|
||||
f.write(f" {op_type}: {count}\n")
|
||||
|
||||
print(f" ✓ Model info saved to: {info_path}")
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("ModernBERT-base Model Preparation Tool")
|
||||
print("=" * 60)
|
||||
|
||||
# Setup paths
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
original_path = artifacts_dir / "modernbert-base.onnx"
|
||||
processed_path = artifacts_dir / "modernbert-base_opset16.onnx"
|
||||
test_data_path = artifacts_dir / "test_data.pt"
|
||||
model_info_path = artifacts_dir / "model-python.txt"
|
||||
|
||||
# Check if we already have everything
|
||||
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
|
||||
print(f"\n✓ All files already exist:")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("\nNothing to do!")
|
||||
return
|
||||
|
||||
# Download and export model if needed
|
||||
if not original_path.exists() and not processed_path.exists():
|
||||
print("\nStep 1: Downloading and exporting ModernBERT-base model...")
|
||||
download_modernbert_model(original_path)
|
||||
|
||||
# Process model if needed (already at opset 16, but apply shape inference)
|
||||
if not processed_path.exists():
|
||||
if original_path.exists():
|
||||
print("\nStep 2: Processing model...")
|
||||
process_model(original_path, processed_path, target_opset=16)
|
||||
|
||||
# Clean up original if we have the processed version
|
||||
if original_path.exists() and processed_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Generate test data if needed
|
||||
if not test_data_path.exists():
|
||||
print("\nStep 3: Generating test data...")
|
||||
generate_test_data(processed_path, artifacts_dir)
|
||||
|
||||
# Save model info if needed
|
||||
if not model_info_path.exists():
|
||||
print("\nStep 4: Saving model information...")
|
||||
save_model_info(processed_path, artifacts_dir)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("✓ ModernBERT-base model preparation completed!")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print(f" Model info: {model_info_path}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠ Operation cancelled by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -1,271 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::module::{Initializer, Param};
|
||||
use burn::prelude::*;
|
||||
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Import the generated model code as a module
|
||||
pub mod modernbert_base {
|
||||
include!(concat!(
|
||||
env!("OUT_DIR"),
|
||||
"/model/modernbert-base_opset16.rs"
|
||||
));
|
||||
}
|
||||
|
||||
#[derive(Debug, Module)]
|
||||
struct TestData<B: Backend> {
|
||||
input_ids: Param<Tensor<B, 2, Int>>,
|
||||
attention_mask: Param<Tensor<B, 2, Int>>,
|
||||
last_hidden_state: Param<Tensor<B, 3>>,
|
||||
pooled_output: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TestData<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
use burn::module::ParamId;
|
||||
// ModernBERT-base: sequence_length=512, hidden_size=768
|
||||
// Note: Initializer only works for float tensors, Int tensors need manual init
|
||||
Self {
|
||||
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 512], device)),
|
||||
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 512], device)),
|
||||
last_hidden_state: Initializer::Zeros.init([1, 512, 768], device),
|
||||
pooled_output: Initializer::Zeros.init([1, 768], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply mean pooling to get sentence embeddings
|
||||
fn mean_pool<B: Backend>(
|
||||
last_hidden_state: Tensor<B, 3>,
|
||||
attention_mask: Tensor<B, 2, Int>,
|
||||
) -> Tensor<B, 2> {
|
||||
// Convert attention_mask to float and expand dimensions to match hidden_state
|
||||
let attention_mask_float = attention_mask.float().unsqueeze_dim::<3>(2);
|
||||
|
||||
// Multiply hidden states by attention mask
|
||||
let masked_embeddings = last_hidden_state * attention_mask_float.clone();
|
||||
|
||||
// Sum along sequence dimension (dim 1)
|
||||
let sum_embeddings = masked_embeddings.sum_dim(1);
|
||||
|
||||
// Sum attention mask to get count of non-padding tokens
|
||||
let sum_mask = attention_mask_float.sum_dim(1).clamp_min(1e-9);
|
||||
|
||||
// Divide to get mean - result is [batch, 1, hidden]
|
||||
let pooled = sum_embeddings / sum_mask;
|
||||
|
||||
// Get the shape to reshape to [batch, hidden]
|
||||
let shape = pooled.shape();
|
||||
let batch_size = shape.dims[0];
|
||||
let hidden_size = shape.dims[2];
|
||||
|
||||
pooled.reshape([batch_size, hidden_size])
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("========================================");
|
||||
println!("ModernBERT-base Burn Model Test");
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model and test data.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize the model (using default which includes the converted weights)
|
||||
println!("Initializing ModernBERT-base model...");
|
||||
let start = Instant::now();
|
||||
let device = Default::default();
|
||||
let model: modernbert_base::Model<MyBackend> = modernbert_base::Model::default();
|
||||
let init_time = start.elapsed();
|
||||
println!(" Model initialized in {:.2?}", init_time);
|
||||
|
||||
// Save model structure to file
|
||||
println!("\nSaving model structure to artifacts/model.txt...");
|
||||
let model_str = format!("{}", model);
|
||||
std::fs::write("artifacts/model.txt", &model_str)
|
||||
.expect("Failed to write model structure to file");
|
||||
println!(" Model structure saved");
|
||||
|
||||
// Load test data from PyTorch file
|
||||
println!("\nLoading test data from artifacts/test_data.pt...");
|
||||
let start = Instant::now();
|
||||
let mut test_data = TestData::<MyBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
|
||||
test_data.load_from(&mut store).expect("Failed to load test data");
|
||||
let load_time = start.elapsed();
|
||||
println!(" Data loaded in {:.2?}", load_time);
|
||||
|
||||
// Get the input tensors from test data
|
||||
let input_ids = test_data.input_ids.val();
|
||||
let attention_mask = test_data.attention_mask.val();
|
||||
let input_ids_shape = input_ids.shape();
|
||||
let attention_mask_shape = attention_mask.shape();
|
||||
println!(" Loaded input_ids with shape: {:?}", input_ids_shape.dims);
|
||||
println!(
|
||||
" Loaded attention_mask with shape: {:?}",
|
||||
attention_mask_shape.dims
|
||||
);
|
||||
|
||||
// Get the reference outputs from test data
|
||||
let reference_last_hidden_state = test_data.last_hidden_state.val();
|
||||
let reference_pooled_output = test_data.pooled_output.val();
|
||||
let ref_last_hidden_shape = reference_last_hidden_state.shape();
|
||||
let ref_pooled_shape = reference_pooled_output.shape();
|
||||
println!(
|
||||
" Loaded reference last_hidden_state with shape: {:?}",
|
||||
ref_last_hidden_shape.dims
|
||||
);
|
||||
println!(
|
||||
" Loaded reference pooled_output with shape: {:?}",
|
||||
ref_pooled_shape.dims
|
||||
);
|
||||
|
||||
// Run inference with the loaded input
|
||||
println!("\nRunning model inference with test input...");
|
||||
let start = Instant::now();
|
||||
|
||||
let last_hidden_state = model.forward(input_ids.clone(), attention_mask.clone());
|
||||
|
||||
let inference_time = start.elapsed();
|
||||
println!(" Inference completed in {:.2?}", inference_time);
|
||||
|
||||
// Compute pooled output (mean pooling)
|
||||
println!("\nComputing pooled output (mean pooling)...");
|
||||
let start = Instant::now();
|
||||
let pooled_output = mean_pool(last_hidden_state.clone(), attention_mask.clone());
|
||||
let pooling_time = start.elapsed();
|
||||
println!(" Pooling completed in {:.2?}", pooling_time);
|
||||
|
||||
// Display output shapes
|
||||
let last_hidden_shape = last_hidden_state.shape();
|
||||
let pooled_shape = pooled_output.shape();
|
||||
println!("\n Model output shapes:");
|
||||
println!(" last_hidden_state: {:?}", last_hidden_shape.dims);
|
||||
println!(" pooled_output: {:?}", pooled_shape.dims);
|
||||
|
||||
// Verify expected output shapes match
|
||||
if last_hidden_shape.dims == ref_last_hidden_shape.dims {
|
||||
println!(
|
||||
" ✓ last_hidden_state shape matches expected: {:?}",
|
||||
ref_last_hidden_shape.dims
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Warning: Expected last_hidden_state shape {:?}, got {:?}",
|
||||
ref_last_hidden_shape.dims, last_hidden_shape.dims
|
||||
);
|
||||
}
|
||||
|
||||
if pooled_shape.dims == ref_pooled_shape.dims {
|
||||
println!(
|
||||
" ✓ pooled_output shape matches expected: {:?}",
|
||||
ref_pooled_shape.dims
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Warning: Expected pooled_output shape {:?}, got {:?}",
|
||||
ref_pooled_shape.dims, pooled_shape.dims
|
||||
);
|
||||
}
|
||||
|
||||
// Compare outputs
|
||||
println!("\nComparing model outputs with reference data...");
|
||||
|
||||
// Check if last_hidden_state is close
|
||||
println!("\n Checking last_hidden_state:");
|
||||
if last_hidden_state.clone().all_close(
|
||||
reference_last_hidden_state.clone(),
|
||||
Some(1e-4),
|
||||
Some(1e-4),
|
||||
) {
|
||||
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ last_hidden_state differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = last_hidden_state.clone() - reference_last_hidden_state.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = last_hidden_state.clone().flatten::<1>(0, 2);
|
||||
let reference_flat = reference_last_hidden_state.clone().flatten::<1>(0, 2);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if pooled_output is close
|
||||
println!("\n Checking pooled_output:");
|
||||
if pooled_output
|
||||
.clone()
|
||||
.all_close(reference_pooled_output.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ pooled_output matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ pooled_output differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = pooled_output.clone() - reference_pooled_output.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = pooled_output.clone().flatten::<1>(0, 1);
|
||||
let reference_flat = reference_pooled_output.clone().flatten::<1>(0, 1);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n========================================");
|
||||
println!("Model test completed!");
|
||||
println!("========================================");
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-rf-detr"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["ndarray"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,85 +0,0 @@
|
||||
# RF-DETR Model Check
|
||||
|
||||
This crate tests burn-onnx's ability to handle the RF-DETR (Roboflow DETR) object detection model.
|
||||
|
||||
## About RF-DETR
|
||||
|
||||
RF-DETR is a real-time object detection model based on the DETR (Detection Transformer)
|
||||
architecture, developed by Roboflow. It combines transformer-based detection with optimizations for
|
||||
speed and accuracy.
|
||||
|
||||
Key features:
|
||||
|
||||
- Transformer-based architecture with multi-head attention
|
||||
- Deformable attention mechanisms
|
||||
- Object queries for detection
|
||||
- End-to-end trainable without anchor boxes or NMS (Non-Maximum Suppression)
|
||||
|
||||
## Related Issue
|
||||
|
||||
This model check was created to track and test the fix for:
|
||||
|
||||
- [Issue #4052](https://github.com/tracel-ai/burn/issues/4052): RF-DETR ONNX import fails with "axis
|
||||
2 is out of bounds for rank 1"
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Download and prepare the model
|
||||
|
||||
Requires Python 3.11:
|
||||
|
||||
```bash
|
||||
# Using uv (recommended)
|
||||
uv run --python 3.11 get_model.py
|
||||
```
|
||||
|
||||
### 2. Build and run the model test
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
cargo run
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
rf-detr/
|
||||
├── artifacts/ # Downloaded ONNX model and test data
|
||||
│ ├── rf_detr_small.onnx # ONNX model (119 MB)
|
||||
│ ├── rf_detr_small_test_data.pt # Test input/output tensors (3.1 MB)
|
||||
│ └── node_info.json # ONNX node analysis
|
||||
├── src/
|
||||
│ └── main.rs # Test runner with output comparison
|
||||
├── build.rs # Build script that generates model code
|
||||
├── get_model.py # Model download, export, and test data generation
|
||||
└── Cargo.toml
|
||||
```
|
||||
|
||||
## Model Details
|
||||
|
||||
- **Model**: RF-DETR Small
|
||||
- **Input**: `[1, 3, 512, 512]` (RGB image)
|
||||
- **Outputs**:
|
||||
- `dets`: `[1, 300, 4]` - 300 bounding boxes (x, y, w, h)
|
||||
- `labels`: `[1, 300, 91]` - 300 class scores (91 COCO classes)
|
||||
- **Architecture**: DETR with deformable attention
|
||||
|
||||
## Test Data
|
||||
|
||||
The `get_model.py` script generates reference test data by:
|
||||
|
||||
1. Creating a reproducible random input tensor (seed 42)
|
||||
2. Running inference with ONNX Runtime
|
||||
3. Saving both input and outputs as PyTorch tensors
|
||||
|
||||
When the ONNX import issue is fixed, `cargo run` will:
|
||||
|
||||
1. Load the test data
|
||||
2. Run inference with the burn-generated model
|
||||
3. Compare outputs against ONNX Runtime reference within tolerance (1e-4)
|
||||
|
||||
## Notes
|
||||
|
||||
- The model uses transformer layers which test burn-onnx's handling of attention mechanisms
|
||||
- The model includes complex operations like multi-head attention and deformable convolutions
|
||||
- Currently fails at build time due to issue #4052 (axis out of bounds during type inference)
|
||||
@@ -1,29 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
let onnx_path = "artifacts/rf_detr_small.onnx";
|
||||
let test_data_path = "artifacts/rf_detr_small_test_data.pt";
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed={}", test_data_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!("Please run the following command to download and prepare the RF-DETR model:");
|
||||
eprintln!(" uv run --python 3.11 get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("This will download and export the RF-DETR Small model to ONNX format.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
ModelGen::new()
|
||||
.input(onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
}
|
||||
@@ -1,299 +0,0 @@
|
||||
#!/usr/bin/env -S uv run --python 3.11 --script
|
||||
|
||||
# /// script
|
||||
# python = "3.11"
|
||||
# dependencies = [
|
||||
# "torch==2.6.*",
|
||||
# "torchvision==0.21.*",
|
||||
# "onnx<1.17",
|
||||
# "onnxruntime",
|
||||
# "numpy",
|
||||
# "rfdetr[onnxexport]",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Download and prepare the RF-DETR model for testing with burn-onnx.
|
||||
|
||||
RF-DETR (Roboflow DETR) is a real-time object detection model that combines
|
||||
the DETR (Detection Transformer) architecture with optimizations for speed.
|
||||
|
||||
This script exports the RFDETRSmall model to ONNX format for testing burn-onnx's
|
||||
ability to handle transformer-based detection models.
|
||||
|
||||
Related issue: https://github.com/tracel-ai/burn/issues/4052
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
|
||||
|
||||
def get_input_shape(model):
|
||||
"""Extract input shape from ONNX model."""
|
||||
input_info = model.graph.input[0]
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField("dim_value"):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append(1) # Default to 1 for dynamic dimensions
|
||||
|
||||
# Ensure valid RF-DETR input shape (batch, channels, height, width)
|
||||
# RF-DETR uses 560x560 by default
|
||||
if len(shape) != 4 or shape[2] == 0 or shape[2] > 2000:
|
||||
return [1, 3, 560, 560]
|
||||
return shape
|
||||
|
||||
|
||||
def extract_node_info(model_path, artifacts_dir):
|
||||
"""Extract node types and configurations from the ONNX model."""
|
||||
print("Extracting node information from ONNX model...")
|
||||
|
||||
# Load the ONNX model
|
||||
model = onnx.load(str(model_path))
|
||||
|
||||
# Collect node information
|
||||
node_types = defaultdict(int)
|
||||
node_details = []
|
||||
|
||||
def process_graph(graph, graph_name="main"):
|
||||
"""Recursively process a graph and its subgraphs."""
|
||||
for idx, node in enumerate(graph.node):
|
||||
node_types[node.op_type] += 1
|
||||
|
||||
# Extract node details
|
||||
node_info = {
|
||||
"graph": graph_name,
|
||||
"index": idx,
|
||||
"op_type": node.op_type,
|
||||
"name": node.name if node.name else f"{node.op_type}_{idx}",
|
||||
"inputs": list(node.input),
|
||||
"outputs": list(node.output),
|
||||
"attributes": {},
|
||||
}
|
||||
|
||||
# Extract attributes
|
||||
for attr in node.attribute:
|
||||
attr_name = attr.name
|
||||
# Get attribute value based on type
|
||||
if attr.HasField("f"):
|
||||
node_info["attributes"][attr_name] = float(attr.f)
|
||||
elif attr.HasField("i"):
|
||||
node_info["attributes"][attr_name] = int(attr.i)
|
||||
elif attr.HasField("s"):
|
||||
node_info["attributes"][attr_name] = (
|
||||
attr.s.decode("utf-8") if attr.s else ""
|
||||
)
|
||||
elif attr.HasField("t"):
|
||||
node_info["attributes"][attr_name] = "<tensor>"
|
||||
elif attr.floats:
|
||||
node_info["attributes"][attr_name] = list(attr.floats)
|
||||
elif attr.ints:
|
||||
node_info["attributes"][attr_name] = list(attr.ints)
|
||||
elif attr.strings:
|
||||
node_info["attributes"][attr_name] = [
|
||||
s.decode("utf-8") for s in attr.strings
|
||||
]
|
||||
elif attr.HasField("g"):
|
||||
# Subgraph - recursively process it
|
||||
subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}"
|
||||
node_info["attributes"][attr_name] = f"<subgraph: {subgraph_name}>"
|
||||
process_graph(attr.g, subgraph_name)
|
||||
elif attr.graphs:
|
||||
subgraph_names = []
|
||||
for g_idx, subgraph in enumerate(attr.graphs):
|
||||
subgraph_name = (
|
||||
f"{graph_name}.{node.op_type}_{idx}.{attr_name}_{g_idx}"
|
||||
)
|
||||
subgraph_names.append(subgraph_name)
|
||||
process_graph(subgraph, subgraph_name)
|
||||
node_info["attributes"][attr_name] = (
|
||||
f"<subgraphs: {', '.join(subgraph_names)}>"
|
||||
)
|
||||
else:
|
||||
node_info["attributes"][attr_name] = "<unknown>"
|
||||
|
||||
node_details.append(node_info)
|
||||
|
||||
# Process the main graph
|
||||
process_graph(model.graph, "main")
|
||||
|
||||
# Create summary
|
||||
summary = {
|
||||
"model_name": model.graph.name,
|
||||
"opset_version": model.opset_import[0].version
|
||||
if model.opset_import
|
||||
else "unknown",
|
||||
"total_nodes": len(node_details),
|
||||
"node_type_counts": dict(sorted(node_types.items())),
|
||||
"nodes": node_details,
|
||||
}
|
||||
|
||||
# Save to JSON file
|
||||
output_path = artifacts_dir / "node_info.json"
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
print(f" Node information extracted to {output_path}")
|
||||
print(f" Total nodes: {summary['total_nodes']}")
|
||||
print(f" Unique node types: {len(node_types)}")
|
||||
print(f" Node type distribution:")
|
||||
for op_type, count in sorted(node_types.items(), key=lambda x: x[1], reverse=True)[
|
||||
:15
|
||||
]:
|
||||
print(f" - {op_type}: {count}")
|
||||
if len(node_types) > 15:
|
||||
print(f" ... and {len(node_types) - 15} more types")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def generate_test_data(model_path, output_path):
|
||||
"""Generate test input/output data and save as PyTorch tensors."""
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
print("Generating test data...")
|
||||
|
||||
# Load model to get input shape
|
||||
model = onnx.load(str(model_path))
|
||||
input_shape = get_input_shape(model)
|
||||
print(f" Input shape: {input_shape}")
|
||||
|
||||
# Create reproducible test input
|
||||
np.random.seed(42)
|
||||
test_input = np.random.rand(*input_shape).astype(np.float32)
|
||||
|
||||
# Run inference to get output
|
||||
session = ort.InferenceSession(str(model_path))
|
||||
input_name = session.get_inputs()[0].name
|
||||
outputs = session.run(None, {input_name: test_input})
|
||||
|
||||
# RF-DETR has two outputs: dets (boxes) and labels (class scores)
|
||||
# Save as PyTorch tensors
|
||||
test_data = {
|
||||
"input": torch.from_numpy(test_input),
|
||||
"output_dets": torch.from_numpy(outputs[0]),
|
||||
"output_labels": torch.from_numpy(outputs[1]),
|
||||
}
|
||||
|
||||
torch.save(test_data, output_path)
|
||||
|
||||
print(f" Test data saved to: {output_path}")
|
||||
print(f" Input shape: {test_input.shape}")
|
||||
print(f" Output dets shape: {outputs[0].shape}")
|
||||
print(f" Output labels shape: {outputs[1].shape}")
|
||||
|
||||
|
||||
def download_and_export_model():
|
||||
"""Download RF-DETR model and export to ONNX format."""
|
||||
from rfdetr import RFDETRSmall
|
||||
|
||||
# Create artifacts directory
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
model_path = artifacts_dir / "rf_detr_small.onnx"
|
||||
test_data_path = artifacts_dir / "rf_detr_small_test_data.pt"
|
||||
|
||||
# Check if we already have everything
|
||||
if model_path.exists() and test_data_path.exists():
|
||||
print(f"All files already exist:")
|
||||
print(f" Model: {model_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print("\nTo re-download, delete the artifacts directory and run again.")
|
||||
return
|
||||
|
||||
print("=" * 60)
|
||||
print("RF-DETR Small Model Preparation Tool")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Download and export if model doesn't exist
|
||||
if not model_path.exists():
|
||||
# Download and initialize the model
|
||||
print("Step 1: Downloading RF-DETR Small model...")
|
||||
model = RFDETRSmall()
|
||||
print(" Model downloaded and initialized")
|
||||
|
||||
# Export to ONNX using RF-DETR's built-in export method
|
||||
print()
|
||||
print("Step 2: Exporting model to ONNX format...")
|
||||
|
||||
# RF-DETR exports to output/inference_model.onnx by default
|
||||
# Note: model.export() returns None, but exports to output/inference_model.onnx
|
||||
model.export()
|
||||
default_export_path = Path("output/inference_model.onnx")
|
||||
print(f" Exported to: {default_export_path}")
|
||||
|
||||
# Move the exported file to artifacts directory
|
||||
exported_file = default_export_path
|
||||
if exported_file.exists():
|
||||
shutil.move(str(exported_file), str(model_path))
|
||||
# Clean up the output directory if empty
|
||||
output_dir = exported_file.parent
|
||||
if output_dir.exists() and not any(output_dir.iterdir()):
|
||||
output_dir.rmdir()
|
||||
|
||||
# Clean up any downloaded weights file
|
||||
pth_files = list(Path(".").glob("*.pth"))
|
||||
for pth_file in pth_files:
|
||||
pth_file.unlink()
|
||||
print(f" Cleaned up: {pth_file}")
|
||||
|
||||
print(f" Model saved to {model_path}")
|
||||
print(f" File size: {model_path.stat().st_size / 1024 / 1024:.1f} MB")
|
||||
|
||||
# Extract node information
|
||||
print()
|
||||
print("Step 3: Analyzing ONNX model structure...")
|
||||
extract_node_info(model_path, artifacts_dir)
|
||||
|
||||
# Generate test data if needed
|
||||
if not test_data_path.exists():
|
||||
print()
|
||||
print("Step 4: Generating test data...")
|
||||
generate_test_data(model_path, test_data_path)
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("RF-DETR model preparation completed!")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print("The RF-DETR model is a transformer-based object detector that uses:")
|
||||
print(" - Multi-head self-attention layers")
|
||||
print(" - Cross-attention for object queries")
|
||||
print(" - Deformable attention mechanisms")
|
||||
print()
|
||||
print("Generated files:")
|
||||
print(f" - {model_path} (ONNX model)")
|
||||
print(f" - {test_data_path} (test input/output data)")
|
||||
print(f" - {artifacts_dir / 'node_info.json'} (node analysis)")
|
||||
print()
|
||||
print("Next steps:")
|
||||
print(" 1. Build the model: cargo build")
|
||||
print(" 2. Run the test: cargo run")
|
||||
print()
|
||||
print("Note: This model is used to test burn-onnx's handling of")
|
||||
print("transformer-based architectures. Related issue:")
|
||||
print(" https://github.com/tracel-ai/burn/issues/4052")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
download_and_export_model()
|
||||
except KeyboardInterrupt:
|
||||
print("\nOperation cancelled by user.")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\nError: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
@@ -1,238 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::module::{Initializer, Param};
|
||||
use burn::prelude::*;
|
||||
|
||||
use burn_store::PytorchStore;
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Include the generated model
|
||||
include!(concat!(env!("OUT_DIR"), "/model/rf_detr_small.rs"));
|
||||
|
||||
/// Test data structure matching the PyTorch saved format
|
||||
/// RF-DETR has two outputs: dets (bounding boxes) and labels (class scores)
|
||||
#[derive(Debug, Module)]
|
||||
struct TestData<B: Backend> {
|
||||
input: Param<Tensor<B, 4>>,
|
||||
output_dets: Param<Tensor<B, 3>>,
|
||||
output_labels: Param<Tensor<B, 3>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TestData<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
// RF-DETR Small: input 512x512, 300 queries, 4 bbox coords, 91 classes (COCO)
|
||||
Self {
|
||||
input: Initializer::Zeros.init([1, 3, 512, 512], device),
|
||||
output_dets: Initializer::Zeros.init([1, 300, 4], device),
|
||||
output_labels: Initializer::Zeros.init([1, 300, 91], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("========================================");
|
||||
println!("RF-DETR Small Model Test");
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model.");
|
||||
eprintln!("Example: uv run --python 3.11 get_model.py");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Check if model file exists
|
||||
let model_file = artifacts_dir.join("rf_detr_small.onnx");
|
||||
let test_data_file = artifacts_dir.join("rf_detr_small_test_data.pt");
|
||||
|
||||
if !model_file.exists() {
|
||||
eprintln!("Error: Model file not found!");
|
||||
eprintln!("Please run: uv run --python 3.11 get_model.py");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
if !test_data_file.exists() {
|
||||
eprintln!("Error: Test data file not found!");
|
||||
eprintln!("Please run: uv run --python 3.11 get_model.py");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize the model with weights
|
||||
println!("Initializing RF-DETR Small model...");
|
||||
let start = Instant::now();
|
||||
let device = Default::default();
|
||||
|
||||
// The model weights are generated at build time and stored in the OUT_DIR
|
||||
// We need to load them from the embedded burnpack file
|
||||
let weights_path = concat!(env!("OUT_DIR"), "/model/rf_detr_small.bpk");
|
||||
let model: Model<MyBackend> = Model::from_file(weights_path, &device);
|
||||
let init_time = start.elapsed();
|
||||
println!(" Model initialized in {:.2?}", init_time);
|
||||
|
||||
// Save model structure to file
|
||||
let model_txt_path = artifacts_dir.join("rf_detr_small_model.txt");
|
||||
println!(
|
||||
"\nSaving model structure to {}...",
|
||||
model_txt_path.display()
|
||||
);
|
||||
let model_str = format!("{}", model);
|
||||
std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file");
|
||||
println!(" Model structure saved");
|
||||
|
||||
// Load test data from PyTorch file
|
||||
println!("\nLoading test data from {}...", test_data_file.display());
|
||||
let start = Instant::now();
|
||||
let mut test_data = TestData::<MyBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file(&test_data_file);
|
||||
test_data.load_from(&mut store).expect("Failed to load test data");
|
||||
let load_time = start.elapsed();
|
||||
println!(" Data loaded in {:.2?}", load_time);
|
||||
|
||||
// Get the input tensor from test data
|
||||
let input = test_data.input.val();
|
||||
let input_shape = input.shape();
|
||||
println!(" Loaded input tensor with shape: {:?}", input_shape.dims);
|
||||
|
||||
// Get the reference outputs from test data
|
||||
let reference_dets = test_data.output_dets.val();
|
||||
let reference_labels = test_data.output_labels.val();
|
||||
println!(
|
||||
" Loaded reference dets with shape: {:?}",
|
||||
reference_dets.shape().dims
|
||||
);
|
||||
println!(
|
||||
" Loaded reference labels with shape: {:?}",
|
||||
reference_labels.shape().dims
|
||||
);
|
||||
|
||||
// Run inference with the loaded input
|
||||
println!("\nRunning model inference with test input...");
|
||||
let start = Instant::now();
|
||||
let (output_dets, output_labels) = model.forward(input);
|
||||
let inference_time = start.elapsed();
|
||||
println!(" Inference completed in {:.2?}", inference_time);
|
||||
|
||||
// Display output shapes
|
||||
println!("\nModel outputs:");
|
||||
println!(" Dets shape: {:?}", output_dets.shape().dims);
|
||||
println!(" Labels shape: {:?}", output_labels.shape().dims);
|
||||
|
||||
// Compare outputs
|
||||
println!("\nComparing model outputs with reference data...");
|
||||
|
||||
let mut dets_passed = false;
|
||||
let mut labels_passed = false;
|
||||
|
||||
// Check if dets are close
|
||||
println!("\n Checking dets (bounding boxes):");
|
||||
if output_dets
|
||||
.clone()
|
||||
.all_close(reference_dets.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ dets matches reference data within tolerance (1e-4)!");
|
||||
dets_passed = true;
|
||||
} else {
|
||||
println!(" ⚠ dets differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = output_dets.clone() - reference_dets.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = output_dets.clone().flatten::<1>(0, 2);
|
||||
let reference_flat = reference_dets.clone().flatten::<1>(0, 2);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if labels are close
|
||||
println!("\n Checking labels (class scores):");
|
||||
if output_labels
|
||||
.clone()
|
||||
.all_close(reference_labels.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ labels matches reference data within tolerance (1e-4)!");
|
||||
labels_passed = true;
|
||||
} else {
|
||||
println!(" ⚠ labels differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = output_labels.clone() - reference_labels.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = output_labels.clone().flatten::<1>(0, 2);
|
||||
let reference_flat = reference_labels.clone().flatten::<1>(0, 2);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n========================================");
|
||||
println!("Summary:");
|
||||
println!(" - Model initialization: {:.2?}", init_time);
|
||||
println!(" - Data loading: {:.2?}", load_time);
|
||||
println!(" - Inference time: {:.2?}", inference_time);
|
||||
if dets_passed && labels_passed {
|
||||
println!(" - Output validation: ✓ All outputs match!");
|
||||
} else {
|
||||
println!(
|
||||
" - Output validation: {} dets, {} labels",
|
||||
if dets_passed { "✓" } else { "✗" },
|
||||
if labels_passed { "✓" } else { "✗" }
|
||||
);
|
||||
}
|
||||
println!("========================================");
|
||||
if dets_passed && labels_passed {
|
||||
println!("Model test completed successfully!");
|
||||
} else {
|
||||
println!("Model test completed with differences.");
|
||||
}
|
||||
println!("========================================");
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-silero-vad"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["ndarray"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,102 +0,0 @@
|
||||
# Silero VAD Model Check
|
||||
|
||||
This model check verifies that burn-onnx can correctly handle the Silero VAD (Voice Activity
|
||||
Detection) model and produces outputs matching ONNX Runtime.
|
||||
|
||||
## Current Status
|
||||
|
||||
**Working**: The model is successfully imported and produces outputs matching ONNX Runtime.
|
||||
|
||||
## Model Information
|
||||
|
||||
- **Model**: Silero VAD (opset 18, if-less version)
|
||||
- **Source**: https://github.com/snakers4/silero-vad
|
||||
- **Purpose**: Voice Activity Detection
|
||||
- **Key Features**: Uses Conv, Gemm, and a single If node for sample rate selection
|
||||
|
||||
The if-less version has only 1 If node (for 16kHz vs 8kHz sample rate selection), making it
|
||||
much simpler than the full model which has 25 If nodes.
|
||||
|
||||
See: https://github.com/snakers4/silero-vad/issues/728 for compatibility discussion.
|
||||
|
||||
## Setup
|
||||
|
||||
### Step 1: Download the Model and Generate Reference Outputs
|
||||
|
||||
```bash
|
||||
# Using Python
|
||||
python get_model.py
|
||||
|
||||
# Or using uv
|
||||
uv run get_model.py
|
||||
```
|
||||
|
||||
This downloads:
|
||||
- `artifacts/silero_vad.onnx` - The ONNX model file
|
||||
- `artifacts/node_info.json` - Detailed analysis of all nodes, operators, and configurations
|
||||
- `artifacts/test.wav` - Test audio file from silero-vad repository
|
||||
- `artifacts/reference_outputs.json` - Reference outputs from ONNX Runtime for validation
|
||||
|
||||
### Step 2: Build and Run Tests
|
||||
|
||||
```bash
|
||||
cargo build
|
||||
cargo run
|
||||
```
|
||||
|
||||
The test suite runs 12 test cases:
|
||||
- 10 audio chunks from the test.wav file
|
||||
- 1 random input test (reproducible with seed 42)
|
||||
- 1 silence test (all zeros)
|
||||
|
||||
Each test compares the Burn model output against ONNX Runtime reference outputs with a 1% tolerance.
|
||||
|
||||
## Backend Support
|
||||
|
||||
This model check supports multiple backends:
|
||||
|
||||
```bash
|
||||
# NdArray backend (default, CPU)
|
||||
cargo run
|
||||
|
||||
# LibTorch backend (CPU/CUDA)
|
||||
cargo run --features tch --no-default-features
|
||||
|
||||
# WGPU backend (GPU via WebGPU)
|
||||
cargo run --features wgpu --no-default-features
|
||||
|
||||
# Metal backend (Apple Silicon GPU)
|
||||
cargo run --features metal --no-default-features
|
||||
```
|
||||
|
||||
## Test Output Example
|
||||
|
||||
```
|
||||
========================================
|
||||
Silero VAD Model Test Suite
|
||||
========================================
|
||||
|
||||
Loading reference outputs...
|
||||
Loaded 12 test cases (sample rate: 16000 Hz)
|
||||
|
||||
Initializing Silero VAD model...
|
||||
Model initialized
|
||||
|
||||
Running test cases...
|
||||
------------------------------------------------------------
|
||||
[PASS] chunk_0: output=0.000589 (expected=0.000589)
|
||||
[PASS] chunk_1: output=0.000589 (expected=0.000528)
|
||||
...
|
||||
[PASS] silence: output=0.000592 (expected=0.000592)
|
||||
------------------------------------------------------------
|
||||
|
||||
========================================
|
||||
Test Summary
|
||||
========================================
|
||||
Total tests: 12
|
||||
Passed: 12
|
||||
Failed: 0
|
||||
|
||||
All tests passed!
|
||||
The Burn model produces outputs matching ONNX Runtime.
|
||||
```
|
||||
@@ -1,30 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
let onnx_path = "artifacts/silero_vad.onnx";
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!("Please run the following command to download the model:");
|
||||
eprintln!(" python get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("Or if you prefer using uv:");
|
||||
eprintln!(" uv run get_model.py");
|
||||
eprintln!();
|
||||
eprintln!("This will download the Silero VAD model.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
ModelGen::new()
|
||||
.input(onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
}
|
||||
@@ -1,399 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Download and prepare the Silero VAD model for testing.
|
||||
|
||||
This script downloads the Silero VAD ONNX model (opset 18, if-less version) and prepares
|
||||
it for use with burn-onnx. This version has only 1 If node (for sample rate selection)
|
||||
making it compatible with static type inference.
|
||||
|
||||
See: https://github.com/snakers4/silero-vad/issues/728 for compatibility discussion.
|
||||
"""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import urllib.request
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import onnx
|
||||
except ImportError:
|
||||
print("Error: onnx package not found. Please install it with:")
|
||||
print(" pip install onnx")
|
||||
exit(1)
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
print("Error: onnxruntime package not found. Please install it with:")
|
||||
print(" pip install onnxruntime")
|
||||
exit(1)
|
||||
|
||||
|
||||
def extract_node_info(model_path, artifacts_dir):
|
||||
"""Extract node types and configurations from the ONNX model."""
|
||||
print("Extracting node information from ONNX model...")
|
||||
|
||||
# Load the ONNX model (without external data since we only need structure)
|
||||
model = onnx.load(str(model_path), load_external_data=False)
|
||||
|
||||
# Check for external data
|
||||
external_files = set()
|
||||
for init in model.graph.initializer:
|
||||
if init.data_location == onnx.TensorProto.EXTERNAL:
|
||||
for ext_data in init.external_data:
|
||||
if ext_data.key == 'location':
|
||||
external_files.add(ext_data.value)
|
||||
|
||||
if external_files:
|
||||
print(f"⚠️ Model requires external data files: {external_files}")
|
||||
print(" These files are missing from the repository!")
|
||||
|
||||
# Collect node information
|
||||
node_types = defaultdict(int)
|
||||
node_details = []
|
||||
|
||||
def process_graph(graph, graph_name="main"):
|
||||
"""Recursively process a graph and its subgraphs."""
|
||||
for idx, node in enumerate(graph.node):
|
||||
node_types[node.op_type] += 1
|
||||
|
||||
# Extract node details
|
||||
node_info = {
|
||||
"graph": graph_name,
|
||||
"index": idx,
|
||||
"op_type": node.op_type,
|
||||
"name": node.name if node.name else f"{node.op_type}_{idx}",
|
||||
"inputs": list(node.input),
|
||||
"outputs": list(node.output),
|
||||
"attributes": {}
|
||||
}
|
||||
|
||||
# Extract attributes
|
||||
for attr in node.attribute:
|
||||
attr_name = attr.name
|
||||
# Get attribute value based on type
|
||||
if attr.HasField('f'):
|
||||
node_info["attributes"][attr_name] = float(attr.f)
|
||||
elif attr.HasField('i'):
|
||||
node_info["attributes"][attr_name] = int(attr.i)
|
||||
elif attr.HasField('s'):
|
||||
node_info["attributes"][attr_name] = attr.s.decode('utf-8') if attr.s else ""
|
||||
elif attr.HasField('t'):
|
||||
node_info["attributes"][attr_name] = "<tensor>"
|
||||
elif attr.floats:
|
||||
node_info["attributes"][attr_name] = list(attr.floats)
|
||||
elif attr.ints:
|
||||
node_info["attributes"][attr_name] = list(attr.ints)
|
||||
elif attr.strings:
|
||||
node_info["attributes"][attr_name] = [s.decode('utf-8') for s in attr.strings]
|
||||
elif attr.HasField('g'):
|
||||
# Subgraph - recursively process it
|
||||
subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}"
|
||||
node_info["attributes"][attr_name] = f"<subgraph: {subgraph_name}>"
|
||||
process_graph(attr.g, subgraph_name)
|
||||
elif attr.graphs:
|
||||
subgraph_names = []
|
||||
for g_idx, subgraph in enumerate(attr.graphs):
|
||||
subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}_{g_idx}"
|
||||
subgraph_names.append(subgraph_name)
|
||||
process_graph(subgraph, subgraph_name)
|
||||
node_info["attributes"][attr_name] = f"<subgraphs: {', '.join(subgraph_names)}>"
|
||||
else:
|
||||
node_info["attributes"][attr_name] = "<unknown>"
|
||||
|
||||
node_details.append(node_info)
|
||||
|
||||
# Process the main graph
|
||||
process_graph(model.graph, "main")
|
||||
|
||||
# Create summary
|
||||
summary = {
|
||||
"model_name": model.graph.name,
|
||||
"opset_version": model.opset_import[0].version if model.opset_import else "unknown",
|
||||
"total_nodes": len(node_details),
|
||||
"node_type_counts": dict(sorted(node_types.items())),
|
||||
"external_data_files": list(external_files),
|
||||
"nodes": node_details
|
||||
}
|
||||
|
||||
# Save to JSON file
|
||||
output_path = artifacts_dir / "node_info.json"
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
|
||||
print(f"✓ Node information extracted to {output_path}")
|
||||
print(f" Opset version: {summary['opset_version']}")
|
||||
print(f" Total nodes: {summary['total_nodes']}")
|
||||
print(f" Unique node types: {len(node_types)}")
|
||||
print(f" Node type distribution:")
|
||||
for op_type, count in sorted(node_types.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" - {op_type}: {count}")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def download_test_data():
|
||||
"""Download test audio file from silero-vad repository."""
|
||||
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
test_wav_path = artifacts_dir / "test.wav"
|
||||
|
||||
if test_wav_path.exists():
|
||||
print(f"✓ Test audio already exists at {test_wav_path}")
|
||||
return test_wav_path
|
||||
|
||||
# Download the test.wav file from silero-vad tests
|
||||
test_url = "https://github.com/snakers4/silero-vad/raw/refs/heads/master/tests/data/test.wav"
|
||||
|
||||
print(f"Downloading test audio from:")
|
||||
print(f" {test_url}")
|
||||
print(f"Saving to: {test_wav_path}")
|
||||
|
||||
try:
|
||||
urllib.request.urlretrieve(test_url, test_wav_path)
|
||||
file_size = test_wav_path.stat().st_size / 1024
|
||||
print(f"✓ Download complete! File size: {file_size:.1f} KB")
|
||||
except Exception as e:
|
||||
print(f"✗ Error downloading test audio: {e}")
|
||||
raise
|
||||
|
||||
return test_wav_path
|
||||
|
||||
|
||||
def load_wav(wav_path):
|
||||
"""Load a WAV file and return audio samples as float32 array normalized to [-1, 1]."""
|
||||
with wave.open(str(wav_path), 'rb') as wav_file:
|
||||
n_channels = wav_file.getnchannels()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
frame_rate = wav_file.getframerate()
|
||||
n_frames = wav_file.getnframes()
|
||||
|
||||
# Read raw bytes
|
||||
raw_data = wav_file.readframes(n_frames)
|
||||
|
||||
# Convert to numpy array based on sample width
|
||||
if sample_width == 2: # 16-bit
|
||||
samples = np.frombuffer(raw_data, dtype=np.int16)
|
||||
# Normalize to [-1, 1]
|
||||
samples = samples.astype(np.float32) / 32768.0
|
||||
elif sample_width == 4: # 32-bit
|
||||
samples = np.frombuffer(raw_data, dtype=np.int32)
|
||||
samples = samples.astype(np.float32) / 2147483648.0
|
||||
else:
|
||||
raise ValueError(f"Unsupported sample width: {sample_width}")
|
||||
|
||||
# If stereo, convert to mono by averaging channels
|
||||
if n_channels == 2:
|
||||
samples = samples.reshape(-1, 2).mean(axis=1)
|
||||
|
||||
return samples, frame_rate
|
||||
|
||||
|
||||
def generate_reference_outputs(model_path, test_wav_path, artifacts_dir):
|
||||
"""Generate reference outputs using ONNX Runtime for testing."""
|
||||
|
||||
print(f" Loading test audio from {test_wav_path}...")
|
||||
audio_samples, sample_rate = load_wav(test_wav_path)
|
||||
print(f" Sample rate: {sample_rate} Hz")
|
||||
print(f" Audio length: {len(audio_samples)} samples ({len(audio_samples)/sample_rate:.2f} seconds)")
|
||||
|
||||
# Create ONNX Runtime session
|
||||
print(f" Creating ONNX Runtime session...")
|
||||
session = ort.InferenceSession(str(model_path), providers=['CPUExecutionProvider'])
|
||||
|
||||
# Silero VAD parameters
|
||||
# For 16kHz: chunk_size = 512 (32ms)
|
||||
# For 8kHz: chunk_size = 256 (32ms)
|
||||
chunk_size = 512 if sample_rate == 16000 else 256
|
||||
batch_size = 1
|
||||
|
||||
# Initialize state
|
||||
state = np.zeros((2, batch_size, 128), dtype=np.float32)
|
||||
|
||||
# Process audio in chunks and collect outputs
|
||||
results = {
|
||||
"sample_rate": sample_rate,
|
||||
"chunk_size": chunk_size,
|
||||
"audio_length_samples": len(audio_samples),
|
||||
"test_cases": []
|
||||
}
|
||||
|
||||
# Test Case 1: First few chunks from the beginning of the audio
|
||||
print(f" Running inference on test chunks...")
|
||||
num_test_chunks = min(10, len(audio_samples) // chunk_size)
|
||||
|
||||
state = np.zeros((2, batch_size, 128), dtype=np.float32)
|
||||
for i in range(num_test_chunks):
|
||||
start_idx = i * chunk_size
|
||||
end_idx = start_idx + chunk_size
|
||||
|
||||
chunk = audio_samples[start_idx:end_idx]
|
||||
if len(chunk) < chunk_size:
|
||||
# Pad with zeros if needed
|
||||
chunk = np.pad(chunk, (0, chunk_size - len(chunk)), mode='constant')
|
||||
|
||||
# Prepare inputs
|
||||
input_tensor = chunk.reshape(1, -1).astype(np.float32)
|
||||
sr_tensor = np.array(sample_rate, dtype=np.int64)
|
||||
|
||||
# Run inference
|
||||
outputs = session.run(None, {
|
||||
'input': input_tensor,
|
||||
'sr': sr_tensor,
|
||||
'state': state
|
||||
})
|
||||
|
||||
output_prob = float(outputs[0].flatten()[0])
|
||||
state = outputs[1]
|
||||
|
||||
results["test_cases"].append({
|
||||
"test_name": f"chunk_{i}",
|
||||
"chunk_index": i,
|
||||
"start_sample": start_idx,
|
||||
"input_samples": chunk.tolist(),
|
||||
"expected_output": output_prob,
|
||||
"state_after": state.tolist()
|
||||
})
|
||||
|
||||
# Test Case 2: Random input (for reproducibility test)
|
||||
print(f" Generating random input test case...")
|
||||
np.random.seed(42)
|
||||
random_input = np.random.randn(chunk_size).astype(np.float32) * 0.1
|
||||
state = np.zeros((2, batch_size, 128), dtype=np.float32)
|
||||
|
||||
outputs = session.run(None, {
|
||||
'input': random_input.reshape(1, -1),
|
||||
'sr': np.array(sample_rate, dtype=np.int64),
|
||||
'state': state
|
||||
})
|
||||
|
||||
results["test_cases"].append({
|
||||
"test_name": "random_seed_42",
|
||||
"chunk_index": -1,
|
||||
"start_sample": -1,
|
||||
"input_samples": random_input.tolist(),
|
||||
"expected_output": float(outputs[0].flatten()[0]),
|
||||
"state_after": outputs[1].tolist()
|
||||
})
|
||||
|
||||
# Test Case 3: Zero input (silence)
|
||||
print(f" Generating silence test case...")
|
||||
zero_input = np.zeros(chunk_size, dtype=np.float32)
|
||||
state = np.zeros((2, batch_size, 128), dtype=np.float32)
|
||||
|
||||
outputs = session.run(None, {
|
||||
'input': zero_input.reshape(1, -1),
|
||||
'sr': np.array(sample_rate, dtype=np.int64),
|
||||
'state': state
|
||||
})
|
||||
|
||||
results["test_cases"].append({
|
||||
"test_name": "silence",
|
||||
"chunk_index": -1,
|
||||
"start_sample": -1,
|
||||
"input_samples": zero_input.tolist(),
|
||||
"expected_output": float(outputs[0].flatten()[0]),
|
||||
"state_after": outputs[1].tolist()
|
||||
})
|
||||
|
||||
# Save results
|
||||
output_path = artifacts_dir / "reference_outputs.json"
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
print(f" ✓ Reference outputs saved to {output_path}")
|
||||
print(f" Total test cases: {len(results['test_cases'])}")
|
||||
|
||||
|
||||
def download_model():
|
||||
"""Download the Silero VAD ONNX model (opset 18, if-less version)."""
|
||||
|
||||
# Create artifacts directory if it doesn't exist
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
model_path = artifacts_dir / "silero_vad.onnx"
|
||||
|
||||
model_existed = model_path.exists()
|
||||
|
||||
# Skip download if model already exists
|
||||
if model_existed:
|
||||
print(f"✓ Model already exists at {model_path}")
|
||||
print(f" File size: {model_path.stat().st_size / 1024:.1f} KB")
|
||||
print()
|
||||
else:
|
||||
# Download the opset 18 if-less model
|
||||
# Note: This model has external data that is missing from the repo
|
||||
model_url = "https://github.com/snakers4/silero-vad/raw/refs/heads/master/src/silero_vad/data/silero_vad_op18_ifless.onnx"
|
||||
|
||||
print(f"Downloading Silero VAD model (opset 18, if-less) from:")
|
||||
print(f" {model_url}")
|
||||
print(f"Saving to: {model_path}")
|
||||
print()
|
||||
|
||||
try:
|
||||
urllib.request.urlretrieve(model_url, model_path)
|
||||
file_size = model_path.stat().st_size / 1024
|
||||
print(f"✓ Download complete! File size: {file_size:.1f} KB")
|
||||
print()
|
||||
except Exception as e:
|
||||
print(f"✗ Error downloading model: {e}")
|
||||
raise
|
||||
|
||||
# Extract node information from the ONNX model
|
||||
try:
|
||||
extract_node_info(model_path, artifacts_dir)
|
||||
except Exception as e:
|
||||
print(f"✗ Error extracting node information: {e}")
|
||||
raise
|
||||
|
||||
print()
|
||||
|
||||
# Download test data
|
||||
print("Downloading test data...")
|
||||
try:
|
||||
test_wav_path = download_test_data()
|
||||
except Exception as e:
|
||||
print(f"✗ Error downloading test data: {e}")
|
||||
raise
|
||||
|
||||
print()
|
||||
|
||||
# Generate reference outputs
|
||||
print("Generating reference outputs...")
|
||||
try:
|
||||
generate_reference_outputs(model_path, test_wav_path, artifacts_dir)
|
||||
except Exception as e:
|
||||
print(f"✗ Error generating reference outputs: {e}")
|
||||
raise
|
||||
|
||||
print()
|
||||
print("="*80)
|
||||
print("Model preparation complete!")
|
||||
print("="*80)
|
||||
print()
|
||||
print("Silero VAD (opset 18, if-less) has only 1 If node for sample rate selection.")
|
||||
print("This makes it compatible with burn-onnx's static type inference.")
|
||||
print()
|
||||
print("See: https://github.com/snakers4/silero-vad/issues/728")
|
||||
print()
|
||||
print("Generated files:")
|
||||
print(f" - {model_path} (ONNX model)")
|
||||
print(f" - {artifacts_dir / 'node_info.json'} (node analysis)")
|
||||
print(f" - {test_wav_path} (test audio)")
|
||||
print(f" - {artifacts_dir / 'reference_outputs.json'} (reference test outputs)")
|
||||
print()
|
||||
print("Next steps:")
|
||||
print(" 1. Build the model: cargo build")
|
||||
print(" 2. Run the test: cargo run")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_model()
|
||||
@@ -1,164 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::prelude::*;
|
||||
use serde::Deserialize;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Include the generated model
|
||||
include!(concat!(env!("OUT_DIR"), "/model/silero_vad.rs"));
|
||||
|
||||
/// Test case from reference outputs
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TestCase {
|
||||
test_name: String,
|
||||
#[allow(dead_code)]
|
||||
chunk_index: i32,
|
||||
#[allow(dead_code)]
|
||||
start_sample: i32,
|
||||
input_samples: Vec<f32>,
|
||||
expected_output: f32,
|
||||
#[allow(dead_code)]
|
||||
state_after: Vec<Vec<Vec<f32>>>,
|
||||
}
|
||||
|
||||
/// Reference outputs structure
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ReferenceOutputs {
|
||||
sample_rate: i64,
|
||||
#[allow(dead_code)]
|
||||
chunk_size: usize,
|
||||
#[allow(dead_code)]
|
||||
audio_length_samples: usize,
|
||||
test_cases: Vec<TestCase>,
|
||||
}
|
||||
|
||||
/// Run a single test case and return (passed, actual_output, expected_output)
|
||||
fn run_test_case(
|
||||
model: &Model<MyBackend>,
|
||||
device: &<MyBackend as Backend>::Device,
|
||||
test_case: &TestCase,
|
||||
sample_rate: i64,
|
||||
) -> (bool, f32, f32) {
|
||||
// Create input tensor from test case samples
|
||||
let input_data: Vec<f32> = test_case.input_samples.clone();
|
||||
let input = Tensor::<MyBackend, 1>::from_floats(input_data.as_slice(), device)
|
||||
.reshape([1, test_case.input_samples.len()]);
|
||||
|
||||
// Initialize state to zeros
|
||||
let state = Tensor::<MyBackend, 3>::zeros([2, 1, 128], device);
|
||||
|
||||
// Run inference
|
||||
let (output, _state_out) = model.forward(input, sample_rate, state);
|
||||
|
||||
// Get the output probability
|
||||
let actual_output: f32 = output.into_scalar();
|
||||
let expected_output = test_case.expected_output;
|
||||
|
||||
// Compare with tolerance (neural networks have small floating point differences)
|
||||
let tolerance = 0.01; // 1% tolerance
|
||||
let passed = (actual_output - expected_output).abs() < tolerance;
|
||||
|
||||
(passed, actual_output, expected_output)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("========================================");
|
||||
println!("Silero VAD Model Test Suite");
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model.");
|
||||
eprintln!("Example: uv run get_model.py");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Check if reference outputs exist
|
||||
let reference_path = artifacts_dir.join("reference_outputs.json");
|
||||
if !reference_path.exists() {
|
||||
eprintln!("Error: reference_outputs.json not found!");
|
||||
eprintln!("Please run: uv run get_model.py");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Load reference outputs
|
||||
println!("Loading reference outputs...");
|
||||
let reference_json = fs::read_to_string(&reference_path).expect("Failed to read reference outputs");
|
||||
let reference: ReferenceOutputs =
|
||||
serde_json::from_str(&reference_json).expect("Failed to parse reference outputs");
|
||||
println!(
|
||||
" Loaded {} test cases (sample rate: {} Hz)\n",
|
||||
reference.test_cases.len(),
|
||||
reference.sample_rate
|
||||
);
|
||||
|
||||
// Initialize the model
|
||||
println!("Initializing Silero VAD model...");
|
||||
let device = Default::default();
|
||||
let model: Model<MyBackend> = Model::default();
|
||||
println!(" Model initialized\n");
|
||||
|
||||
// Run tests
|
||||
println!("Running test cases...");
|
||||
println!("{:-<60}", "");
|
||||
|
||||
let mut passed_count = 0;
|
||||
let mut failed_count = 0;
|
||||
|
||||
for test_case in &reference.test_cases {
|
||||
let (passed, actual, expected) = run_test_case(&model, &device, test_case, reference.sample_rate);
|
||||
|
||||
if passed {
|
||||
println!(
|
||||
" [PASS] {}: output={:.6} (expected={:.6})",
|
||||
test_case.test_name, actual, expected
|
||||
);
|
||||
passed_count += 1;
|
||||
} else {
|
||||
println!(
|
||||
" [FAIL] {}: output={:.6} (expected={:.6}, diff={:.6})",
|
||||
test_case.test_name,
|
||||
actual,
|
||||
expected,
|
||||
(actual - expected).abs()
|
||||
);
|
||||
failed_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
println!("{:-<60}", "");
|
||||
println!();
|
||||
|
||||
// Summary
|
||||
println!("========================================");
|
||||
println!("Test Summary");
|
||||
println!("========================================");
|
||||
println!(" Total tests: {}", passed_count + failed_count);
|
||||
println!(" Passed: {}", passed_count);
|
||||
println!(" Failed: {}", failed_count);
|
||||
println!();
|
||||
|
||||
if failed_count == 0 {
|
||||
println!("All tests passed!");
|
||||
println!("The Burn model produces outputs matching ONNX Runtime.");
|
||||
} else {
|
||||
println!("Some tests failed!");
|
||||
println!("The Burn model outputs differ from ONNX Runtime.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
@@ -1,166 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Validate the Silero VAD ONNX model independently.
|
||||
|
||||
This script:
|
||||
1. Checks if the model is valid using onnx.checker
|
||||
2. Runs inference using ONNX Runtime to verify it works
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import onnx
|
||||
from onnx import checker
|
||||
except ImportError:
|
||||
print("Error: onnx package not found. Please install it with:")
|
||||
print(" pip install onnx")
|
||||
exit(1)
|
||||
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
print("Error: onnxruntime package not found. Please install it with:")
|
||||
print(" pip install onnxruntime")
|
||||
exit(1)
|
||||
|
||||
|
||||
def validate_model():
|
||||
"""Validate the Silero VAD ONNX model."""
|
||||
|
||||
model_path = Path("artifacts/silero_vad.onnx")
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model not found at {model_path}")
|
||||
print("Please run 'python get_model.py' first to download the model.")
|
||||
return False
|
||||
|
||||
print("=" * 80)
|
||||
print("Silero VAD ONNX Model Validation")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# Step 1: Load and check the model structure
|
||||
print("Step 1: Loading ONNX model...")
|
||||
try:
|
||||
model = onnx.load(str(model_path))
|
||||
print(f" ✓ Model loaded successfully")
|
||||
print(f" - Opset version: {model.opset_import[0].version}")
|
||||
print(f" - Graph name: {model.graph.name}")
|
||||
print(f" - Number of nodes: {len(model.graph.node)}")
|
||||
print(f" - Number of initializers: {len(model.graph.initializer)}")
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to load model: {e}")
|
||||
return False
|
||||
print()
|
||||
|
||||
# Step 2: Validate with ONNX checker
|
||||
print("Step 2: Validating model with onnx.checker...")
|
||||
try:
|
||||
checker.check_model(model, full_check=True)
|
||||
print(" ✓ Model passed ONNX validation (full check)")
|
||||
except Exception as e:
|
||||
print(f" ✗ Model validation failed: {e}")
|
||||
return False
|
||||
print()
|
||||
|
||||
# Step 3: Print model inputs/outputs
|
||||
print("Step 3: Model inputs and outputs...")
|
||||
print(" Inputs:")
|
||||
for inp in model.graph.input:
|
||||
shape = []
|
||||
for dim in inp.type.tensor_type.shape.dim:
|
||||
if dim.dim_param:
|
||||
shape.append(dim.dim_param)
|
||||
else:
|
||||
shape.append(dim.dim_value)
|
||||
dtype = onnx.TensorProto.DataType.Name(inp.type.tensor_type.elem_type)
|
||||
print(f" - {inp.name}: {dtype} {shape}")
|
||||
|
||||
print(" Outputs:")
|
||||
for out in model.graph.output:
|
||||
shape = []
|
||||
for dim in out.type.tensor_type.shape.dim:
|
||||
if dim.dim_param:
|
||||
shape.append(dim.dim_param)
|
||||
else:
|
||||
shape.append(dim.dim_value)
|
||||
dtype = onnx.TensorProto.DataType.Name(out.type.tensor_type.elem_type)
|
||||
print(f" - {out.name}: {dtype} {shape}")
|
||||
print()
|
||||
|
||||
# Step 4: Run inference with ONNX Runtime
|
||||
print("Step 4: Running inference with ONNX Runtime...")
|
||||
try:
|
||||
# Create inference session
|
||||
session = ort.InferenceSession(str(model_path), providers=['CPUExecutionProvider'])
|
||||
print(" ✓ ONNX Runtime session created successfully")
|
||||
|
||||
# Get input details
|
||||
input_details = session.get_inputs()
|
||||
output_details = session.get_outputs()
|
||||
|
||||
print(f" - Session inputs: {[i.name for i in input_details]}")
|
||||
print(f" - Session outputs: {[o.name for o in output_details]}")
|
||||
|
||||
# Prepare sample inputs based on model signature
|
||||
# Silero VAD expects:
|
||||
# - input: audio chunk [batch, samples] - typically 512 samples for 16kHz
|
||||
# - sr: sample rate (int64)
|
||||
# - h: hidden state [2, batch, 64]
|
||||
# - c: cell state [2, batch, 64]
|
||||
|
||||
batch_size = 1
|
||||
# Silero VAD chunk sizes:
|
||||
# - 16kHz: 512 samples (32ms)
|
||||
# - 8kHz: 256 samples (32ms)
|
||||
# The model also supports larger chunks: 768, 1024, 1536 for 16kHz
|
||||
chunk_size = 512 # 512 samples for 16kHz (32ms)
|
||||
|
||||
# Prepare inputs based on Silero VAD documentation
|
||||
# https://github.com/snakers4/silero-vad
|
||||
inputs = {
|
||||
'input': np.random.randn(batch_size, chunk_size).astype(np.float32),
|
||||
'sr': np.array(16000, dtype=np.int64), # 16kHz sample rate
|
||||
'state': np.zeros((2, batch_size, 128), dtype=np.float32), # LSTM hidden state
|
||||
}
|
||||
|
||||
for name, value in inputs.items():
|
||||
print(f" - Input '{name}': shape={value.shape}, dtype={value.dtype}")
|
||||
|
||||
# Run inference
|
||||
print()
|
||||
print(" Running inference...")
|
||||
outputs = session.run(None, inputs)
|
||||
|
||||
print(" ✓ Inference completed successfully!")
|
||||
print()
|
||||
print(" Output values:")
|
||||
for i, (out_detail, out_value) in enumerate(zip(output_details, outputs)):
|
||||
print(f" - {out_detail.name}: shape={out_value.shape}, dtype={out_value.dtype}")
|
||||
if out_value.size <= 10:
|
||||
print(f" values: {out_value}")
|
||||
else:
|
||||
print(f" sample: {out_value.flat[:5]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ✗ ONNX Runtime inference failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("✓ All validation checks passed!")
|
||||
print("=" * 80)
|
||||
print()
|
||||
print("The ONNX model is valid and can run inference successfully.")
|
||||
print("The issue with burn-onnx is in the code generation, not the model itself.")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = validate_model()
|
||||
exit(0 if success else 1)
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "burn-onnx-model-checks-yolo"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
publish = false
|
||||
|
||||
[workspace]
|
||||
|
||||
[features]
|
||||
default = ["tch"]
|
||||
ndarray = []
|
||||
tch = []
|
||||
wgpu = []
|
||||
metal = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../../../crates/burn", features = [
|
||||
"ndarray",
|
||||
"tch",
|
||||
"wgpu",
|
||||
"metal",
|
||||
] }
|
||||
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../../../burn-onnx" }
|
||||
@@ -1,56 +0,0 @@
|
||||
# YOLO Model Checks
|
||||
|
||||
This crate provides a unified interface for testing multiple YOLO model variants with Burn.
|
||||
|
||||
## Supported Models
|
||||
|
||||
- `yolov5s` - YOLOv5 small variant
|
||||
- `yolov8n` - YOLOv8 nano variant
|
||||
- `yolov8s` - YOLOv8 small variant
|
||||
- `yolov10n` - YOLOv10 nano variant (Note: Currently fails due to TopK operator issue)
|
||||
- `yolo11x` - YOLO11 extra-large variant
|
||||
- `yolo12x` - YOLO12 extra-large variant
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Download and prepare a model
|
||||
|
||||
```bash
|
||||
# Using Python directly
|
||||
python get_model.py --model yolov8n
|
||||
|
||||
# Or using uv
|
||||
uv run get_model.py --model yolov8n
|
||||
|
||||
# List available models
|
||||
uv run get_model.py --list
|
||||
```
|
||||
|
||||
### 2. Run the model test
|
||||
|
||||
After building, you can run the test. The model is already compiled in:
|
||||
|
||||
```bash
|
||||
YOLO_MODEL=yolov8s cargo run --release
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
yolo/
|
||||
├── artifacts/ # Downloaded ONNX models and test data
|
||||
│ ├── yolov8n_opset16.onnx
|
||||
│ ├── yolov8n_test_data.pt
|
||||
│ └── ...
|
||||
├── src/
|
||||
│ └── main.rs # Test runner
|
||||
├── build.rs # Build script that generates model code
|
||||
├── get_model.py # Model download and preparation script
|
||||
└── Cargo.toml
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- All YOLO models (except v10) output shape `[1, 84, 8400]` for standard object detection
|
||||
- YOLOv10n has a different architecture with output shape `[1, 300, 6]` and uses TopK operator
|
||||
- The crate requires explicit model selection at build time (no default model)
|
||||
@@ -1,87 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
fn main() {
|
||||
// Supported models
|
||||
let supported_models = vec![
|
||||
"yolov5s", "yolov8n", "yolov8s", "yolov10n", "yolo11x", "yolo12x",
|
||||
];
|
||||
|
||||
// Get the model name from environment variable (required)
|
||||
let model_name = env::var("YOLO_MODEL").unwrap_or_else(|_| {
|
||||
eprintln!("Error: YOLO_MODEL environment variable is not set.");
|
||||
eprintln!();
|
||||
eprintln!("Please specify which YOLO model to build:");
|
||||
eprintln!(" YOLO_MODEL=yolov8n cargo build");
|
||||
eprintln!();
|
||||
eprintln!("Available models: {}", supported_models.join(", "));
|
||||
std::process::exit(1);
|
||||
});
|
||||
|
||||
if !supported_models.contains(&model_name.as_str()) {
|
||||
eprintln!(
|
||||
"Error: Unsupported model '{}'. Supported models: {:?}",
|
||||
model_name, supported_models
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let onnx_path = format!("artifacts/{}_opset16.onnx", model_name);
|
||||
let test_data_path = format!("artifacts/{}_test_data.pt", model_name);
|
||||
|
||||
// Tell Cargo to only rebuild if these files change
|
||||
println!("cargo:rerun-if-changed={}", onnx_path);
|
||||
println!("cargo:rerun-if-changed={}", test_data_path);
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rerun-if-env-changed=YOLO_MODEL");
|
||||
|
||||
// Check if the ONNX model file exists
|
||||
if !Path::new(&onnx_path).exists() {
|
||||
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
|
||||
eprintln!();
|
||||
eprintln!(
|
||||
"Please run the following command to download and prepare the {} model:",
|
||||
model_name
|
||||
);
|
||||
eprintln!(" python get_model.py --model {}", model_name);
|
||||
eprintln!();
|
||||
eprintln!("Or if you prefer using uv:");
|
||||
eprintln!(" uv run get_model.py --model {}", model_name);
|
||||
eprintln!();
|
||||
eprintln!("Available models: {}", supported_models.join(", "));
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Generate the model code from the ONNX file
|
||||
ModelGen::new()
|
||||
.input(&onnx_path)
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
|
||||
// Write the model name to a file so main.rs can access it
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let model_info_path = Path::new(&out_dir).join("model_info.rs");
|
||||
|
||||
// Generate the include path for the model
|
||||
let model_include = format!(
|
||||
"include!(concat!(env!(\"OUT_DIR\"), \"/model/{}_opset16.rs\"));",
|
||||
model_name
|
||||
);
|
||||
|
||||
fs::write(
|
||||
model_info_path,
|
||||
format!(
|
||||
r#"pub const MODEL_NAME: &str = "{}";
|
||||
pub const TEST_DATA_FILE: &str = "{}_test_data.pt";
|
||||
|
||||
// Include the generated model
|
||||
pub mod yolo_model {{
|
||||
{}
|
||||
}}"#,
|
||||
model_name, model_name, model_include
|
||||
),
|
||||
)
|
||||
.expect("Failed to write model info");
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
#!/usr/bin/env -S uv run --script
|
||||
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "onnx>=1.17.0",
|
||||
# "onnxruntime>=1.18.0",
|
||||
# "ultralytics>=8.3.0",
|
||||
# "numpy",
|
||||
# "pillow",
|
||||
# "torch",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import os
|
||||
import sys
|
||||
import onnx
|
||||
from onnx import shape_inference, version_converter
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
|
||||
# Supported YOLO models configuration
|
||||
SUPPORTED_MODELS = {
|
||||
'yolov5s': {'download_name': 'yolov5s.pt', 'display_name': 'YOLOv5s'},
|
||||
'yolov8n': {'download_name': 'yolov8n.pt', 'display_name': 'YOLOv8n'},
|
||||
'yolov8s': {'download_name': 'yolov8s.pt', 'display_name': 'YOLOv8s'},
|
||||
'yolov10n': {'download_name': 'yolov10n.pt', 'display_name': 'YOLOv10n'},
|
||||
'yolo11x': {'download_name': 'yolo11x.pt', 'display_name': 'YOLO11x'},
|
||||
'yolo12x': {'download_name': 'yolo12x.pt', 'display_name': 'YOLO12x'},
|
||||
}
|
||||
|
||||
|
||||
def get_input_shape(model):
|
||||
"""Extract input shape from ONNX model."""
|
||||
input_info = model.graph.input[0]
|
||||
shape = []
|
||||
for dim in input_info.type.tensor_type.shape.dim:
|
||||
if dim.HasField('dim_value'):
|
||||
shape.append(dim.dim_value)
|
||||
else:
|
||||
shape.append(1) # Default to 1 for dynamic dimensions
|
||||
|
||||
# Ensure valid YOLO input shape
|
||||
if len(shape) != 4 or shape[2] == 0 or shape[2] > 2000:
|
||||
return [1, 3, 640, 640]
|
||||
return shape
|
||||
|
||||
|
||||
def download_and_convert_model(model_name, output_path):
|
||||
"""Download YOLO model and export to ONNX format."""
|
||||
from ultralytics import YOLO
|
||||
|
||||
model_config = SUPPORTED_MODELS[model_name]
|
||||
display_name = model_config['display_name']
|
||||
download_name = model_config['download_name']
|
||||
|
||||
print(f"Downloading {display_name} model...")
|
||||
model = YOLO(download_name)
|
||||
|
||||
print("Exporting to ONNX format...")
|
||||
model.export(format="onnx", simplify=True)
|
||||
|
||||
# Move exported file to artifacts
|
||||
base_name = download_name.replace('.pt', '')
|
||||
exported_file = Path(f"{base_name}.onnx")
|
||||
if exported_file.exists():
|
||||
exported_file.rename(output_path)
|
||||
|
||||
# Clean up PyTorch file
|
||||
pt_file = Path(download_name)
|
||||
if pt_file.exists():
|
||||
pt_file.unlink()
|
||||
|
||||
if not output_path.exists():
|
||||
raise FileNotFoundError(f"Failed to create ONNX file at {output_path}")
|
||||
|
||||
|
||||
def process_model(input_path, output_path, target_opset=16):
|
||||
"""Load, upgrade opset, and apply shape inference to model."""
|
||||
print(f"Loading model from {input_path}...")
|
||||
model = onnx.load(input_path)
|
||||
|
||||
# Check and upgrade opset if needed
|
||||
current_opset = model.opset_import[0].version
|
||||
if current_opset < target_opset:
|
||||
print(f"Upgrading opset from {current_opset} to {target_opset}...")
|
||||
model = version_converter.convert_version(model, target_opset)
|
||||
|
||||
# Apply shape inference
|
||||
print("Applying shape inference...")
|
||||
model = shape_inference.infer_shapes(model)
|
||||
|
||||
# Save processed model
|
||||
onnx.save(model, output_path)
|
||||
print(f"✓ Processed model saved to: {output_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def generate_test_data(model_path, output_path, model_name):
|
||||
"""Generate test input/output data and save as PyTorch tensors."""
|
||||
import torch
|
||||
import onnxruntime as ort
|
||||
|
||||
print("\nGenerating test data...")
|
||||
|
||||
# Load model to get input shape
|
||||
model = onnx.load(model_path)
|
||||
input_shape = get_input_shape(model)
|
||||
print(f" Input shape: {input_shape}")
|
||||
|
||||
# Create reproducible test input
|
||||
np.random.seed(42)
|
||||
test_input = np.random.rand(*input_shape).astype(np.float32)
|
||||
|
||||
# Run inference to get output
|
||||
session = ort.InferenceSession(model_path)
|
||||
input_name = session.get_inputs()[0].name
|
||||
outputs = session.run(None, {input_name: test_input})
|
||||
|
||||
# Save as PyTorch tensors
|
||||
test_data = {
|
||||
'input': torch.from_numpy(test_input),
|
||||
'output': torch.from_numpy(outputs[0])
|
||||
}
|
||||
|
||||
torch.save(test_data, output_path)
|
||||
|
||||
print(f" ✓ Test data saved to: {output_path}")
|
||||
print(f" Input shape: {test_input.shape}, Output shape: {outputs[0].shape}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='YOLO Model Preparation Tool')
|
||||
parser.add_argument('--model', type=str, default='yolov8n',
|
||||
choices=list(SUPPORTED_MODELS.keys()),
|
||||
help=f'YOLO model to download and prepare (default: yolov8n). Choices: {", ".join(SUPPORTED_MODELS.keys())}')
|
||||
parser.add_argument('--list', action='store_true',
|
||||
help='List all supported models')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list:
|
||||
print("Supported YOLO models:")
|
||||
for model_id, config in SUPPORTED_MODELS.items():
|
||||
print(f" - {model_id:10s} ({config['display_name']})")
|
||||
return
|
||||
|
||||
model_name = args.model
|
||||
display_name = SUPPORTED_MODELS[model_name]['display_name']
|
||||
|
||||
print("=" * 60)
|
||||
print(f"{display_name} Model Preparation Tool")
|
||||
print("=" * 60)
|
||||
|
||||
# Setup paths
|
||||
artifacts_dir = Path("artifacts")
|
||||
artifacts_dir.mkdir(exist_ok=True)
|
||||
|
||||
original_path = artifacts_dir / f"{model_name}.onnx"
|
||||
processed_path = artifacts_dir / f"{model_name}_opset16.onnx"
|
||||
test_data_path = artifacts_dir / f"{model_name}_test_data.pt"
|
||||
|
||||
# Check if we already have everything
|
||||
if processed_path.exists() and test_data_path.exists():
|
||||
print(f"\n✓ All files already exist for {display_name}:")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print("\nNothing to do!")
|
||||
return
|
||||
|
||||
# Download and convert if needed
|
||||
if not original_path.exists() and not processed_path.exists():
|
||||
print(f"\nStep 1: Downloading and converting {display_name} model...")
|
||||
download_and_convert_model(model_name, original_path)
|
||||
|
||||
# Process model if needed
|
||||
if not processed_path.exists():
|
||||
print("\nStep 2: Processing model...")
|
||||
process_model(original_path, processed_path, target_opset=16)
|
||||
|
||||
# Clean up original if we have the processed version
|
||||
if original_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Generate test data if needed
|
||||
if not test_data_path.exists():
|
||||
print("\nStep 3: Generating test data...")
|
||||
generate_test_data(processed_path, test_data_path, model_name)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"✓ {display_name} model preparation completed!")
|
||||
print(f" Model: {processed_path}")
|
||||
print(f" Test data: {test_data_path}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠ Operation cancelled by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
@@ -1,196 +0,0 @@
|
||||
extern crate alloc;
|
||||
|
||||
use burn::module::{Initializer, Param};
|
||||
use burn::prelude::*;
|
||||
|
||||
use burn_store::{ModuleSnapshot, PytorchStore};
|
||||
use std::path::Path;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type MyBackend = burn::backend::Wgpu;
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub type MyBackend = burn::backend::NdArray<f32>;
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub type MyBackend = burn::backend::LibTorch<f32>;
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub type MyBackend = burn::backend::Metal;
|
||||
|
||||
// Import model info generated by build.rs (includes the yolo_model module)
|
||||
include!(concat!(env!("OUT_DIR"), "/model_info.rs"));
|
||||
|
||||
// Use the yolo_model module from model_info.rs
|
||||
use yolo_model::Model;
|
||||
|
||||
#[derive(Debug, Module)]
|
||||
struct TestData<B: Backend> {
|
||||
input: Param<Tensor<B, 4>>,
|
||||
output: Param<Tensor<B, 3>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TestData<B> {
|
||||
fn new(device: &B::Device) -> Self {
|
||||
// YOLO: input 640x640, output [1, 84, 8400]
|
||||
Self {
|
||||
input: Initializer::Zeros.init([1, 3, 640, 640], device),
|
||||
output: Initializer::Zeros.init([1, 84, 8400], device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_model_display_name(model_name: &str) -> &str {
|
||||
match model_name {
|
||||
"yolov5s" => "YOLOv5s",
|
||||
"yolov8n" => "YOLOv8n",
|
||||
"yolov8s" => "YOLOv8s",
|
||||
"yolov10n" => "YOLOv10n",
|
||||
"yolo11x" => "YOLO11x",
|
||||
"yolo12x" => "YOLO12x",
|
||||
_ => model_name,
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// MODEL_NAME is set at build time from YOLO_MODEL env var
|
||||
let model_name = MODEL_NAME;
|
||||
let display_name = get_model_display_name(model_name);
|
||||
|
||||
println!("========================================");
|
||||
println!("{} Burn Model Test", display_name);
|
||||
println!("========================================\n");
|
||||
|
||||
// Check if artifacts exist
|
||||
let artifacts_dir = Path::new("artifacts");
|
||||
if !artifacts_dir.exists() {
|
||||
eprintln!("Error: artifacts directory not found!");
|
||||
eprintln!("Please run get_model.py first to download the model and test data.");
|
||||
eprintln!("Example: uv run get_model.py --model {}", model_name);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Check if model files exist for this specific model
|
||||
let model_file = artifacts_dir.join(format!("{}_opset16.onnx", model_name));
|
||||
let test_data_file = artifacts_dir.join(format!("{}_test_data.pt", model_name));
|
||||
|
||||
if !model_file.exists() || !test_data_file.exists() {
|
||||
eprintln!("Error: Model files not found for {}!", display_name);
|
||||
eprintln!("Please run: uv run get_model.py --model {}", model_name);
|
||||
eprintln!();
|
||||
eprintln!("Available models:");
|
||||
eprintln!(" - yolov5s");
|
||||
eprintln!(" - yolov8n");
|
||||
eprintln!(" - yolov8s");
|
||||
eprintln!(" - yolov10n");
|
||||
eprintln!(" - yolo11x");
|
||||
eprintln!(" - yolo12x");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Initialize the model (without weights for now)
|
||||
println!("Initializing {} model...", display_name);
|
||||
let start = Instant::now();
|
||||
let device = Default::default();
|
||||
let model: Model<MyBackend> = Model::default();
|
||||
let init_time = start.elapsed();
|
||||
println!(" Model initialized in {:.2?}", init_time);
|
||||
|
||||
// Save model structure to file
|
||||
let model_txt_path = artifacts_dir.join(format!("{}_model.txt", model_name));
|
||||
println!(
|
||||
"\nSaving model structure to {}...",
|
||||
model_txt_path.display()
|
||||
);
|
||||
let model_str = format!("{}", model);
|
||||
std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file");
|
||||
println!(" Model structure saved");
|
||||
|
||||
// Load test data from PyTorch file
|
||||
println!("\nLoading test data from {}...", test_data_file.display());
|
||||
let start = Instant::now();
|
||||
let mut test_data = TestData::<MyBackend>::new(&device);
|
||||
let mut store = PytorchStore::from_file(&test_data_file);
|
||||
test_data.load_from(&mut store).expect("Failed to load test data");
|
||||
let load_time = start.elapsed();
|
||||
println!(" Data loaded in {:.2?}", load_time);
|
||||
|
||||
// Get the input tensor from test data
|
||||
let input = test_data.input.val();
|
||||
let input_shape = input.shape();
|
||||
println!(" Loaded input tensor with shape: {:?}", input_shape.dims);
|
||||
|
||||
// Get the reference output from test data
|
||||
let reference_output = test_data.output.val();
|
||||
let reference_shape = reference_output.shape();
|
||||
println!(
|
||||
" Loaded reference output with shape: {:?}",
|
||||
reference_shape.dims
|
||||
);
|
||||
|
||||
// Run inference with the loaded input
|
||||
println!("\nRunning model inference with test input...");
|
||||
let start = Instant::now();
|
||||
let output = model.forward(input);
|
||||
let inference_time = start.elapsed();
|
||||
println!(" Inference completed in {:.2?}", inference_time);
|
||||
|
||||
// Display output shape
|
||||
let shape = output.shape();
|
||||
println!("\n Model output shape: {:?}", shape.dims);
|
||||
|
||||
// Verify expected output shape (most YOLO models use [1, 84, 8400])
|
||||
let expected_shape = [1, 84, 8400];
|
||||
if shape.dims == expected_shape {
|
||||
println!(" ✓ Output shape matches expected: {:?}", expected_shape);
|
||||
} else {
|
||||
println!(
|
||||
" ⚠ Note: Shape is {:?} (expected {:?} for most YOLO models)",
|
||||
shape.dims, expected_shape
|
||||
);
|
||||
}
|
||||
|
||||
// Compare outputs
|
||||
println!("\nComparing model output with reference data...");
|
||||
|
||||
// Check if outputs are close
|
||||
if output
|
||||
.clone()
|
||||
.all_close(reference_output.clone(), Some(1e-4), Some(1e-4))
|
||||
{
|
||||
println!(" ✓ Model output matches reference data within tolerance (1e-4)!");
|
||||
} else {
|
||||
println!(" ⚠ Model output differs from reference data!");
|
||||
|
||||
// Calculate and display the difference statistics
|
||||
let diff = output.clone() - reference_output.clone();
|
||||
let abs_diff = diff.abs();
|
||||
let max_diff = abs_diff.clone().max().into_scalar();
|
||||
let mean_diff = abs_diff.mean().into_scalar();
|
||||
|
||||
println!(" Maximum absolute difference: {:.6}", max_diff);
|
||||
println!(" Mean absolute difference: {:.6}", mean_diff);
|
||||
|
||||
// Show some sample values for debugging
|
||||
println!("\n Sample values comparison (first 5 elements):");
|
||||
let output_flat = output.clone().flatten::<1>(0, 2);
|
||||
let reference_flat = reference_output.clone().flatten::<1>(0, 2);
|
||||
|
||||
for i in 0..5.min(output_flat.dims()[0]) {
|
||||
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
|
||||
println!(
|
||||
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
|
||||
i,
|
||||
model_val,
|
||||
ref_val,
|
||||
(model_val - ref_val).abs()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n========================================");
|
||||
println!("Model test completed!");
|
||||
println!("========================================");
|
||||
}
|
||||
10
crates/burn-onnx/onnx-tests/.gitignore
vendored
10
crates/burn-onnx/onnx-tests/.gitignore
vendored
@@ -1,10 +0,0 @@
|
||||
# python generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# venv
|
||||
.venv
|
||||
@@ -1,22 +0,0 @@
|
||||
[package]
|
||||
name = "onnx-tests"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["test-ndarray"]
|
||||
test-wgpu = ["burn/wgpu"]
|
||||
test-ndarray = ["burn/ndarray"]
|
||||
test-tch = ["burn/tch"]
|
||||
test-metal = ["burn/metal"]
|
||||
test-candle = ["burn/candle"]
|
||||
|
||||
[dev-dependencies]
|
||||
burn = { path = "../../burn" }
|
||||
burn-store = { path = "../../burn-store", features = ["burnpack"] }
|
||||
serde = { workspace = true }
|
||||
float-cmp = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
burn-onnx = { path = "../" }
|
||||
@@ -1,312 +0,0 @@
|
||||
# ONNX Tests
|
||||
|
||||
This crate contains ONNX models used for testing the conversion process from ONNX to Burn source
|
||||
code through the `burn-onnx` crate. These tests are designed as end-to-end tests, ensuring that
|
||||
ONNX models are accurately converted into Burn source code that compiles without errors and produces
|
||||
the same outputs as the original ONNX model.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
- `tests/<op_name>/`: Each operator or model has its own directory
|
||||
- `tests/<op_name>/<op_name>.py`: Python script that generates the ONNX model
|
||||
- `tests/<op_name>/<op_name>.onnx`: Generated ONNX model
|
||||
- `tests/<op_name>/mod.rs`: Test implementation for the specific operator
|
||||
- `tests/test_mod.rs`: Main test file that integrates all operator tests
|
||||
- `build.rs`: Build script that generates ONNX models before running tests
|
||||
|
||||
## Setting Up Your Python Environment
|
||||
|
||||
### Using uv (Recommended)
|
||||
|
||||
You can use [`uv`](https://docs.astral.sh/uv/) to set up a Python environment with the necessary
|
||||
dependencies:
|
||||
|
||||
```sh
|
||||
cd crates/burn-onnx/onnx-tests
|
||||
uv sync # or uv sync -f
|
||||
```
|
||||
|
||||
This will create a `.venv` directory with all the required dependencies.
|
||||
|
||||
### Manual Setup
|
||||
|
||||
If you prefer to set up manually, you need to install the following packages:
|
||||
|
||||
```sh
|
||||
pip install onnx==1.15.0 torch==2.1.1
|
||||
```
|
||||
|
||||
Additional dependencies are specified in `requirements.lock`.
|
||||
|
||||
## Creating a Test for a New Operator
|
||||
|
||||
There are two main approaches to generating ONNX files for testing:
|
||||
|
||||
1. **Exporting a model from PyTorch** (most common)
|
||||
2. **Constructing an ONNX graph directly** (for specific cases)
|
||||
|
||||
### 1. Create the Python Script
|
||||
|
||||
Create a new directory and Python script:
|
||||
|
||||
```sh
|
||||
mkdir -p tests/my_new_op
|
||||
touch tests/my_new_op/my_new_op.py
|
||||
```
|
||||
|
||||
#### Approach 1: Exporting a PyTorch Model to ONNX
|
||||
|
||||
Your script should:
|
||||
|
||||
- Import the necessary PyTorch modules
|
||||
- Define a model that uses your operator
|
||||
- Generate test inputs
|
||||
- Export the model to ONNX format
|
||||
- Run the model on test inputs and print the output
|
||||
|
||||
Example structure:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.onnx
|
||||
|
||||
# Define a simple model that uses your operator
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# ...
|
||||
|
||||
def forward(self, x):
|
||||
# Use your operator here
|
||||
return my_operation(x)
|
||||
|
||||
# Create an instance of the model
|
||||
model = MyModel()
|
||||
|
||||
# Generate test input
|
||||
input_tensor = torch.randn(1, 3, 224, 224)
|
||||
|
||||
# Export the model to ONNX
|
||||
torch.onnx.export(
|
||||
model,
|
||||
input_tensor,
|
||||
"tests/my_new_op/my_new_op.onnx",
|
||||
opset_version=16,
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
do_constant_folding=False # Set to False if you want to preserve specific operators
|
||||
)
|
||||
|
||||
# Run the model with the test input and print output for test verification
|
||||
output = model(input_tensor)
|
||||
print("Input:", input_tensor)
|
||||
print("Output:", output)
|
||||
```
|
||||
|
||||
#### Approach 2: Constructing an ONNX Graph Directly
|
||||
|
||||
For some test cases, you may want to construct the ONNX graph directly using the ONNX Python API.
|
||||
This is particularly useful when:
|
||||
|
||||
- You need precise control over operator attributes
|
||||
- You're testing operators that are difficult to trigger through PyTorch models
|
||||
- You want to test specific graph structures
|
||||
|
||||
Example (see `tests/gather/gather_1d_idx.py` for a complete example):
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
# Create inputs
|
||||
data = np.random.randn(5, 5, 5).astype(np.float32)
|
||||
indices = np.array([0, 2, 4], dtype=np.int64)
|
||||
|
||||
# Create node
|
||||
node = helper.make_node(
|
||||
"Gather",
|
||||
inputs=["data", "indices"],
|
||||
outputs=["output"],
|
||||
axis=1
|
||||
)
|
||||
|
||||
# Create input tensors
|
||||
data_tensor = helper.make_tensor_value_info("data", TensorProto.FLOAT, data.shape)
|
||||
indices_tensor = helper.make_tensor_value_info("indices", TensorProto.INT64, indices.shape)
|
||||
|
||||
# Create output tensor
|
||||
output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [5, 3, 5])
|
||||
|
||||
# Create graph and model
|
||||
graph = helper.make_graph(
|
||||
[node],
|
||||
"gather-model",
|
||||
[data_tensor, indices_tensor],
|
||||
[output_tensor],
|
||||
initializer=[]
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, "tests/my_new_op/my_new_op.onnx")
|
||||
|
||||
# For test verification, print input and expected output
|
||||
print("Data:", data)
|
||||
print("Indices:", indices)
|
||||
print("Expected output:", np.take(data, indices, axis=1))
|
||||
```
|
||||
|
||||
### 2. Add the Build Step
|
||||
|
||||
Update `build.rs` to include your new model.
|
||||
|
||||
### 3. Create a mod.rs Test File
|
||||
|
||||
Create a test module file in your operator directory:
|
||||
|
||||
```sh
|
||||
touch tests/my_new_op/mod.rs
|
||||
```
|
||||
|
||||
Implement the test for your operator in this file:
|
||||
|
||||
```rust
|
||||
use super::test_record_type::TestRecordType;
|
||||
use burn_onnx::OnnxModel;
|
||||
|
||||
#[test]
|
||||
fn test_my_new_op() {
|
||||
let model = OnnxModel::read("tests/my_new_op/my_new_op.onnx").unwrap();
|
||||
let record = model.into_record::<TestRecordType>();
|
||||
// Implement test logic and assertions here
|
||||
}
|
||||
```
|
||||
|
||||
Your test will be automatically included in the main test suite through `tests/test_mod.rs`.
|
||||
|
||||
## Best Practices for ONNX Testing
|
||||
|
||||
### Model Generation
|
||||
|
||||
1. **Keep Models Simple**: Focus on testing a single operator or a small group of related operators.
|
||||
|
||||
2. **Control Randomness**: Use fixed seeds in your Python scripts to ensure reproducible results:
|
||||
|
||||
```python
|
||||
torch.manual_seed(42)
|
||||
```
|
||||
|
||||
3. **Print Test Values**: Always print your input and output tensors in the Python script for
|
||||
reference.
|
||||
|
||||
4. **Verify Operators**: Use [Netron](https://github.com/lutzroeder/netron) to verify your ONNX
|
||||
model contains the expected operators.
|
||||
|
||||
5. **Handle Constant Folding**: If PyTorch is optimizing away your operators, use:
|
||||
```python
|
||||
torch.onnx.export(..., do_constant_folding=False)
|
||||
```
|
||||
|
||||
### Test Implementation
|
||||
|
||||
1. **Test Multiple Cases**: Include tests for different input shapes, data types, and parameter
|
||||
combinations.
|
||||
|
||||
2. **Edge Cases**: Test edge cases like empty tensors, single-element tensors, or very large
|
||||
tensors.
|
||||
|
||||
3. **Parameter Variations**: If your operator has configurable parameters, test different parameter
|
||||
values.
|
||||
|
||||
4. **Numerical Precision**: Use appropriate tolerance levels based on operation sensitivity.
|
||||
|
||||
5. **Error Cases**: Test that invalid inputs are properly handled and appropriate errors are raised.
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Default Backend
|
||||
|
||||
Run all tests with:
|
||||
|
||||
```sh
|
||||
cargo test
|
||||
```
|
||||
|
||||
This command runs all tests using the default backend: `burn::backend::NdArray<f32>`.
|
||||
|
||||
### Testing with Different Backends
|
||||
|
||||
You can test with different Burn backends by using feature flags:
|
||||
|
||||
#### WGPU Backend
|
||||
|
||||
```sh
|
||||
cargo test --features test-wgpu
|
||||
```
|
||||
|
||||
Uses `burn::backend::Wgpu` for GPU-accelerated computation.
|
||||
|
||||
#### LibTorch Backend
|
||||
|
||||
```sh
|
||||
cargo test --features test-tch
|
||||
```
|
||||
|
||||
Uses `burn::backend::LibTorch<f32>` for Torch backend integration.
|
||||
|
||||
#### NdArray Backend (Explicit)
|
||||
|
||||
```sh
|
||||
cargo test --features test-ndarray
|
||||
```
|
||||
|
||||
Explicitly uses `burn::backend::NdArray<f32>` (same as default).
|
||||
|
||||
### Running Specific Tests
|
||||
|
||||
Run tests for a specific operator:
|
||||
|
||||
```sh
|
||||
cargo test --test test_mod my_new_op::test_my_new_op
|
||||
```
|
||||
|
||||
Run a specific test with a selected backend:
|
||||
|
||||
```sh
|
||||
cargo test --test test_mod my_new_op::test_my_new_op --features test-wgpu
|
||||
```
|
||||
|
||||
### Supported Backends
|
||||
|
||||
- `burn::backend::NdArray<f32>` (default) - CPU-based computation using ndarray
|
||||
- `burn::backend::Wgpu` - GPU-accelerated computation using WebGPU
|
||||
- `burn::backend::LibTorch<f32>` - Torch backend integration
|
||||
|
||||
**Note:** Only one backend feature should be enabled at a time. The backend selection uses
|
||||
conditional compilation with the following priority:
|
||||
|
||||
1. `test-wgpu` (highest priority)
|
||||
2. `test-tch`
|
||||
3. `test-ndarray` (default when no other backend is selected)
|
||||
|
||||
## Debugging Failed Tests
|
||||
|
||||
If a test fails, you can:
|
||||
|
||||
1. **Inspect ONNX Model**: Use Netron to visualize the model structure.
|
||||
|
||||
2. **Check Intermediate Values**: Add print statements in your Python script to see intermediate
|
||||
tensor values.
|
||||
|
||||
3. **Generate Rust Code**: Use the `burn-onnx` CLI to generate Rust code and inspect it:
|
||||
|
||||
```sh
|
||||
cargo run -p burn-onnx -- tests/my_new_op/my_new_op.onnx ./out
|
||||
```
|
||||
|
||||
4. **Trace Through Conversion**: Add debug logging in your implementation to see where things might
|
||||
be going wrong.
|
||||
|
||||
5. **Numerical Issues**: If values are close but not equal, it might be a numerical precision issue.
|
||||
Try adjusting tolerance.
|
||||
@@ -1,387 +0,0 @@
|
||||
use burn_onnx::ModelGen;
|
||||
|
||||
fn main() {
|
||||
// Re-run this build script if the onnx-tests directory changes.
|
||||
println!("cargo:rerun-if-changed=tests");
|
||||
|
||||
// Add onnx models.
|
||||
// All models are now saved in burnpack format (.bpk files)
|
||||
ModelGen::new()
|
||||
.input("tests/abs/abs.onnx")
|
||||
.input("tests/add/add.onnx")
|
||||
.input("tests/add/add_shape.onnx")
|
||||
.input("tests/add/add_broadcast.onnx")
|
||||
.input("tests/initializer_to_const/initializer_to_const.onnx")
|
||||
.input("tests/and/and.onnx")
|
||||
.input("tests/and/and_broadcast.onnx")
|
||||
.input("tests/and/and_scalar.onnx")
|
||||
.input("tests/add/add_int.onnx")
|
||||
.input("tests/add/add_shape_tensor.onnx")
|
||||
.input("tests/add/add_argmax_with_shape.onnx")
|
||||
.input("tests/argmax/argmax.onnx")
|
||||
.input("tests/argmax/argmax_both_keepdims.onnx")
|
||||
.input("tests/argmax/argmax_1d.onnx")
|
||||
.input("tests/argmin/argmin.onnx")
|
||||
.input("tests/argmin/argmin_both_keepdims.onnx")
|
||||
.input("tests/argmin/argmin_1d.onnx")
|
||||
.input("tests/attention/attention_4d.onnx")
|
||||
.input("tests/attention/attention_3d.onnx")
|
||||
.input("tests/attention/attention_attn_mask_bool.onnx")
|
||||
.input("tests/attention/attention_attn_mask_int.onnx")
|
||||
.input("tests/attention/attention_attn_mask_float.onnx")
|
||||
.input("tests/attention/attention_softcap.onnx")
|
||||
.input("tests/attention/attention_cache.onnx")
|
||||
.input("tests/attention/attention_custom_scale.onnx")
|
||||
.input("tests/attention/attention_is_causal.onnx")
|
||||
.input("tests/attention/attention_qk_output_0.onnx")
|
||||
.input("tests/attention/attention_qk_output_1.onnx")
|
||||
.input("tests/attention/attention_qk_output_2.onnx")
|
||||
.input("tests/attention/attention_qk_output_3.onnx")
|
||||
.input("tests/avg_pool1d/avg_pool1d.onnx")
|
||||
.input("tests/avg_pool1d_ceil_mode/avg_pool1d_ceil_mode.onnx")
|
||||
.input("tests/avg_pool2d/avg_pool2d.onnx")
|
||||
.input("tests/avg_pool2d_ceil_mode/avg_pool2d_ceil_mode.onnx")
|
||||
.input("tests/batch_norm/batch_norm.onnx")
|
||||
.input("tests/bitshift/bitshift_left.onnx")
|
||||
.input("tests/bitshift/bitshift_left_scalar.onnx")
|
||||
.input("tests/bitshift/scalar_bitshift_left.onnx")
|
||||
.input("tests/bitshift/scalar_bitshift_left_scalar.onnx")
|
||||
.input("tests/bitshift/bitshift_right.onnx")
|
||||
.input("tests/bitshift/bitshift_right_scalar.onnx")
|
||||
.input("tests/bitshift/scalar_bitshift_right.onnx")
|
||||
.input("tests/bitshift/scalar_bitshift_right_scalar.onnx")
|
||||
.input("tests/bitwise_and/bitwise_and.onnx")
|
||||
.input("tests/bitwise_and/bitwise_and_scalar.onnx")
|
||||
.input("tests/bitwise_and/scalar_bitwise_and.onnx")
|
||||
.input("tests/bitwise_and/scalar_bitwise_and_scalar.onnx")
|
||||
.input("tests/bitwise_not/bitwise_not.onnx")
|
||||
.input("tests/bitwise_or/bitwise_or.onnx")
|
||||
.input("tests/bitwise_or/bitwise_or_scalar.onnx")
|
||||
.input("tests/bitwise_or/scalar_bitwise_or.onnx")
|
||||
.input("tests/bitwise_or/scalar_bitwise_or_scalar.onnx")
|
||||
.input("tests/bitwise_xor/bitwise_xor.onnx")
|
||||
.input("tests/bitwise_xor/bitwise_xor_scalar.onnx")
|
||||
.input("tests/bitwise_xor/scalar_bitwise_xor.onnx")
|
||||
.input("tests/bitwise_xor/scalar_bitwise_xor_scalar.onnx")
|
||||
.input("tests/bernoulli/bernoulli.onnx")
|
||||
.input("tests/cast/cast.onnx")
|
||||
.input("tests/cast/cast_shape.onnx")
|
||||
.input("tests/cast/cast_shape_to_float.onnx")
|
||||
.input("tests/cast/cast_shape_to_bool.onnx")
|
||||
.input("tests/ceil/ceil.onnx")
|
||||
.input("tests/clip/clip.onnx")
|
||||
.input("tests/concat/concat.onnx")
|
||||
.input("tests/concat/concat_shape.onnx")
|
||||
.input("tests/concat/concat_shape_with_constant.onnx")
|
||||
.input("tests/concat/concat_mixed_single_element.onnx")
|
||||
.input("tests/concat/concat_mixed_three_elements.onnx")
|
||||
.input("tests/concat/concat_multiple_mixed.onnx")
|
||||
.input("tests/concat/concat_with_constants.onnx")
|
||||
.input("tests/constant/constant_f32.onnx")
|
||||
.input("tests/constant/constant_f64.onnx")
|
||||
.input("tests/constant/constant_i32.onnx")
|
||||
.input("tests/constant/constant_i64.onnx")
|
||||
.input("tests/constant/constant_bool.onnx")
|
||||
.input("tests/constant/constant_shape.onnx")
|
||||
.input("tests/constant/constant_tensor_f32.onnx")
|
||||
.input("tests/constant/constant_tensor_i32.onnx")
|
||||
.input("tests/constant/constant_tensor_bool.onnx")
|
||||
.input("tests/constant/constant_empty_tensor_f32.onnx")
|
||||
.input("tests/constant/rank_inference_propagation.onnx")
|
||||
.input("tests/constant/shape_binary_ops_with_constant.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_scalar.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_scalar_custom_value.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_tensor.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_shape_optimization.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_with_constant_input.onnx")
|
||||
.input("tests/constant_lifting_multiple/constant_lifting_multiple.onnx")
|
||||
.input("tests/constant_lifting_multiple/constant_reused.onnx")
|
||||
.input("tests/conv1d/conv1d.onnx")
|
||||
.input("tests/conv2d/conv2d.onnx")
|
||||
.input("tests/conv3d/conv3d.onnx")
|
||||
.input("tests/conv_transpose1d/conv_transpose1d.onnx")
|
||||
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
|
||||
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
|
||||
.input("tests/cos/cos.onnx")
|
||||
.input("tests/cosh/cosh.onnx")
|
||||
.input("tests/cumsum/cumsum.onnx")
|
||||
.input("tests/cumsum/cumsum_exclusive.onnx")
|
||||
.input("tests/cumsum/cumsum_reverse.onnx")
|
||||
.input("tests/cumsum/cumsum_exclusive_reverse.onnx")
|
||||
.input("tests/cumsum/cumsum_2d.onnx")
|
||||
.input("tests/cumsum/cumsum_runtime_axis.onnx")
|
||||
.input("tests/cumsum/cumsum_single_element.onnx")
|
||||
.input("tests/cumsum/cumsum_exclusive_single.onnx")
|
||||
.input("tests/depth_to_space/depth_to_space_dcr.onnx")
|
||||
.input("tests/depth_to_space/depth_to_space_crd.onnx")
|
||||
.input("tests/div/div.onnx")
|
||||
.input("tests/div/div_shape.onnx")
|
||||
.input("tests/div/div_shape_tensor.onnx")
|
||||
.input("tests/div/div_broadcast.onnx")
|
||||
.input("tests/mod/modulo.onnx")
|
||||
.input("tests/mod/mod_scalar.onnx")
|
||||
.input("tests/mod/mod_remainder.onnx")
|
||||
.input("tests/mod/mod_fmod.onnx")
|
||||
.input("tests/mod/mod_broadcast_fixed.onnx")
|
||||
.input("tests/mod/mod_broadcast_remainder_fixed.onnx")
|
||||
.input("tests/dropout/dropout.onnx")
|
||||
.input("tests/empty_graph/empty_graph_scalar.onnx")
|
||||
.input("tests/empty_graph/empty_graph_scalar_int.onnx")
|
||||
.input("tests/empty_graph/empty_graph_shape.onnx")
|
||||
.input("tests/empty_graph/empty_graph_tensor.onnx")
|
||||
.input("tests/empty_graph/empty_graph_multiple.onnx")
|
||||
.input("tests/equal/equal.onnx")
|
||||
.input("tests/equal/equal_shape.onnx")
|
||||
.input("tests/equal/equal_two_shapes.onnx")
|
||||
.input("tests/erf/erf.onnx")
|
||||
.input("tests/exp/exp.onnx")
|
||||
.input("tests/expand/expand.onnx")
|
||||
.input("tests/expand/expand_tensor.onnx")
|
||||
.input("tests/expand/expand_shape.onnx")
|
||||
.input("tests/expand/expand_with_where_shape.onnx")
|
||||
.input("tests/expand/expand_max_semantics.onnx")
|
||||
.input("tests/eye_like/eye_like.onnx")
|
||||
.input("tests/eye_like/eye_like_k1.onnx")
|
||||
.input("tests/eye_like/eye_like_int.onnx")
|
||||
.input("tests/eye_like/eye_like_k_minus1.onnx")
|
||||
.input("tests/eye_like/eye_like_float64.onnx")
|
||||
.input("tests/eye_like/eye_like_int32.onnx")
|
||||
.input("tests/eye_like/eye_like_bool.onnx")
|
||||
.input("tests/eye_like/eye_like_large_k.onnx")
|
||||
.input("tests/eye_like/eye_like_1x1.onnx")
|
||||
.input("tests/eye_like/eye_like_wide.onnx")
|
||||
.input("tests/eye_like/eye_like_neg_large_k.onnx")
|
||||
.input("tests/nonzero/nonzero_float32.onnx")
|
||||
.input("tests/nonzero/nonzero_int64.onnx")
|
||||
.input("tests/nonzero/nonzero_bool.onnx")
|
||||
.input("tests/nonzero/nonzero_1d.onnx")
|
||||
.input("tests/nonzero/nonzero_3d.onnx")
|
||||
.input("tests/nonzero/nonzero_empty.onnx")
|
||||
.input("tests/flatten/flatten.onnx")
|
||||
.input("tests/flatten/flatten_2d.onnx")
|
||||
.input("tests/floor/floor.onnx")
|
||||
.input("tests/gather/gather_1d_idx.onnx")
|
||||
.input("tests/gather/gather_2d_idx.onnx")
|
||||
.input("tests/gather/gather_scalar.onnx")
|
||||
.input("tests/gather/gather_constant_2d_indices.onnx")
|
||||
.input("tests/gather/gather_static_shape_indices.onnx")
|
||||
.input("tests/gather/gather_shape.onnx")
|
||||
.input("tests/gather/gather_with_shape_indices.onnx")
|
||||
.input("tests/gather/gather_scalar_out.onnx")
|
||||
.input("tests/gather/gather_scalar_input.onnx")
|
||||
.input("tests/gather/gather_negative_idx.onnx")
|
||||
.input("tests/gather_elements/gather_elements.onnx")
|
||||
.input("tests/gelu/gelu.onnx")
|
||||
.input("tests/gemm/gemm.onnx")
|
||||
.input("tests/gemm/gemm_non_unit_alpha_beta.onnx")
|
||||
.input("tests/gemm/gemm_no_c.onnx")
|
||||
.input("tests/global_avr_pool/global_avr_pool.onnx")
|
||||
.input("tests/graph_multiple_output_tracking/graph_multiple_output_tracking.onnx")
|
||||
.input("tests/greater/greater.onnx")
|
||||
.input("tests/greater/greater_scalar.onnx")
|
||||
.input("tests/greater/greater_broadcast.onnx")
|
||||
.input("tests/greater_or_equal/greater_or_equal.onnx")
|
||||
.input("tests/greater_or_equal/greater_or_equal_scalar.onnx")
|
||||
.input("tests/greater_or_equal/greater_or_equal_broadcast.onnx")
|
||||
.input("tests/grid_sample/grid_sample.onnx")
|
||||
.input("tests/grid_sample/grid_sample_nearest.onnx")
|
||||
.input("tests/group_norm/group_norm.onnx")
|
||||
.input("tests/hard_sigmoid/hard_sigmoid.onnx")
|
||||
.input("tests/hard_swish/hard_swish.onnx")
|
||||
.input("tests/identity/identity_constant.onnx")
|
||||
.input("tests/identity/identity_passthrough.onnx")
|
||||
.input("tests/identity/identity_chain.onnx")
|
||||
.input("tests/identity/identity_only.onnx")
|
||||
.input("tests/instance_norm1d/instance_norm1d.onnx")
|
||||
.input("tests/instance_norm2d/instance_norm2d.onnx")
|
||||
.input("tests/instance_norm3d/instance_norm3d.onnx")
|
||||
.input("tests/is_inf/is_inf.onnx")
|
||||
.input("tests/is_inf/is_inf_scalar.onnx")
|
||||
.input("tests/is_inf/is_inf_neg_only.onnx")
|
||||
.input("tests/is_inf/is_inf_pos_only.onnx")
|
||||
.input("tests/is_inf/is_inf_none.onnx")
|
||||
.input("tests/is_nan/is_nan.onnx")
|
||||
.input("tests/is_nan/is_nan_scalar.onnx")
|
||||
.input("tests/layer_norm/layer_norm.onnx")
|
||||
.input("tests/leaky_relu/leaky_relu.onnx")
|
||||
.input("tests/less/less.onnx")
|
||||
.input("tests/less/less_scalar.onnx")
|
||||
.input("tests/less/less_broadcast.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal_scalar.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal_broadcast.onnx")
|
||||
.input("tests/linear/linear.onnx")
|
||||
.input("tests/log/log.onnx")
|
||||
.input("tests/lstm/lstm.onnx")
|
||||
.input("tests/lstm/lstm_bidirectional.onnx")
|
||||
.input("tests/lstm/lstm_reverse.onnx")
|
||||
.input("tests/lstm/lstm_with_initial_state.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
.input("tests/where_op/where_op.onnx")
|
||||
.input("tests/where_op/where_op_broadcast.onnx")
|
||||
.input("tests/where_op/where_op_scalar_x.onnx")
|
||||
.input("tests/where_op/where_op_scalar_y.onnx")
|
||||
.input("tests/where_op/where_op_all_scalar.onnx")
|
||||
.input("tests/where_op/where_shape_all_shapes.onnx")
|
||||
.input("tests/where_op/where_shape_scalar_cond.onnx")
|
||||
.input("tests/where_op/where_shapes_from_inputs.onnx")
|
||||
.input("tests/where_op/where_static_shape.onnx")
|
||||
.input("tests/matmul/matmul.onnx")
|
||||
.input("tests/matmulinteger/matmulinteger.onnx")
|
||||
.input("tests/matmulinteger/matmulinteger_ranks.onnx")
|
||||
.input("tests/matmul/matmul_ranks.onnx")
|
||||
.input("tests/max/max.onnx")
|
||||
.input("tests/maxpool1d/maxpool1d.onnx")
|
||||
.input("tests/maxpool1d_ceil_mode/maxpool1d_ceil_mode.onnx")
|
||||
.input("tests/maxpool2d/maxpool2d.onnx")
|
||||
.input("tests/maxpool2d_ceil_mode/maxpool2d_ceil_mode.onnx")
|
||||
.input("tests/min/min.onnx")
|
||||
.input("tests/mean/mean.onnx")
|
||||
.input("tests/mul/mul.onnx")
|
||||
.input("tests/mul/mul_shape.onnx")
|
||||
.input("tests/mul/mul_shape_tensor.onnx")
|
||||
.input("tests/mul/mul_broadcast.onnx")
|
||||
.input("tests/neg/neg.onnx")
|
||||
.input("tests/not/not.onnx")
|
||||
.input("tests/one_hot/one_hot.onnx")
|
||||
.input("tests/or/or.onnx")
|
||||
.input("tests/or/or_scalar.onnx")
|
||||
.input("tests/or/or_broadcast.onnx")
|
||||
.input("tests/pad/pad.onnx")
|
||||
.input("tests/pad/pad_reflect.onnx")
|
||||
.input("tests/pad/pad_edge.onnx")
|
||||
.input("tests/pow/pow.onnx")
|
||||
.input("tests/pow/pow_int.onnx")
|
||||
.input("tests/prelu/prelu.onnx")
|
||||
.input("tests/prelu/prelu_with_channel_slope.onnx")
|
||||
.input("tests/random_normal/random_normal.onnx")
|
||||
.input("tests/random_normal_like/random_normal_like.onnx")
|
||||
.input("tests/random_uniform/random_uniform.onnx")
|
||||
.input("tests/random_uniform_like/random_uniform_like.onnx")
|
||||
.input("tests/range/range.onnx")
|
||||
.input("tests/range/range_static.onnx")
|
||||
.input("tests/range/range_mixed.onnx")
|
||||
.input("tests/range/range_runtime.onnx")
|
||||
.input("tests/recip/recip.onnx")
|
||||
.input("tests/reduce/reduce_max.onnx")
|
||||
.input("tests/reduce/reduce_max_bool.onnx")
|
||||
.input("tests/reduce/reduce_mean.onnx")
|
||||
.input("tests/reduce/reduce_mean_partial_shape.onnx")
|
||||
.input("tests/reduce/reduce_min.onnx")
|
||||
.input("tests/reduce/reduce_min_bool.onnx")
|
||||
.input("tests/reduce/reduce_prod.onnx")
|
||||
.input("tests/reduce/reduce_sum.onnx")
|
||||
.input("tests/reduce/reduce_sum_square.onnx")
|
||||
.input("tests/reduce/reduce_l1.onnx")
|
||||
.input("tests/reduce/reduce_l2.onnx")
|
||||
.input("tests/reduce/reduce_log_sum.onnx")
|
||||
.input("tests/reduce/reduce_log_sum_exp.onnx")
|
||||
.input("tests/relu/relu.onnx")
|
||||
.input("tests/reshape/reshape.onnx")
|
||||
.input("tests/reshape/reshape_with_1d_tensor.onnx")
|
||||
.input("tests/reshape/reshape_with_shape.onnx")
|
||||
.input("tests/reshape/reshape_to_scalar.onnx")
|
||||
.input("tests/reshape/reshape_3d_to_scalar.onnx")
|
||||
.input("tests/reshape/reshape_shape_to_shape.onnx")
|
||||
.input("tests/reshape/reshape_shape_with_neg.onnx")
|
||||
.input("tests/reshape/reshape_shape_partial.onnx")
|
||||
.input("tests/reshape/reshape_scalar_to_scalar.onnx")
|
||||
.input("tests/resize/resize_with_sizes.onnx")
|
||||
.input("tests/resize/resize_1d_linear_scale.onnx")
|
||||
.input("tests/resize/resize_1d_nearest_scale.onnx")
|
||||
.input("tests/resize/resize_2d_bicubic_scale.onnx")
|
||||
.input("tests/resize/resize_2d_bilinear_scale.onnx")
|
||||
.input("tests/resize/resize_2d_nearest_scale.onnx")
|
||||
.input("tests/resize/resize_with_shape.onnx")
|
||||
.input("tests/resize/resize_with_sizes_tensor.onnx")
|
||||
.input("tests/round/round.onnx")
|
||||
.input("tests/shape/shape.onnx")
|
||||
.input("tests/shape/shape_of_shape.onnx")
|
||||
.input("tests/shape/shape_slice.onnx")
|
||||
.input("tests/shape/shape_chain.onnx")
|
||||
.input("tests/sigmoid/sigmoid.onnx")
|
||||
.input("tests/sign/sign.onnx")
|
||||
.input("tests/sin/sin.onnx")
|
||||
.input("tests/sinh/sinh.onnx")
|
||||
.input("tests/size/size.onnx")
|
||||
.input("tests/slice/slice.onnx")
|
||||
.input("tests/slice/slice_shape.onnx")
|
||||
.input("tests/slice/slice_scalar.onnx")
|
||||
.input("tests/slice/slice_mixed.onnx")
|
||||
.input("tests/slice/slice_shape_gather.onnx")
|
||||
.input("tests/slice/slice_shape_runtime.onnx")
|
||||
.input("tests/slice/slice_shape_multi.onnx")
|
||||
.input("tests/slice/slice_shape_negative.onnx")
|
||||
.input("tests/slice/slice_shape_negative_range.onnx")
|
||||
.input("tests/slice/slice_1d_tensor.onnx")
|
||||
.input("tests/slice/slice_shape_start_tensor_end.onnx")
|
||||
.input("tests/slice/slice_tensor_start_shape_end.onnx")
|
||||
.input("tests/slice/slice_axes.onnx")
|
||||
.input("tests/slice/slice_with_steps.onnx")
|
||||
.input("tests/slice/slice_shape_with_steps.onnx")
|
||||
.input("tests/slice/slice_empty.onnx")
|
||||
.input("tests/softmax/softmax.onnx")
|
||||
.input("tests/space_to_depth/space_to_depth.onnx")
|
||||
.input("tests/sqrt/sqrt.onnx")
|
||||
.input("tests/squeeze/squeeze_multiple.onnx")
|
||||
.input("tests/squeeze/squeeze.onnx")
|
||||
.input("tests/squeeze/squeeze_shape.onnx")
|
||||
.input("tests/squeeze/squeeze_shape_noop.onnx")
|
||||
.input("tests/squeeze/squeeze_scalar.onnx")
|
||||
.input("tests/squeeze/squeeze_float.onnx")
|
||||
.input("tests/squeeze/squeeze_tensor_to_scalar.onnx")
|
||||
.input("tests/squeeze/squeeze_opset13_axes_input.onnx")
|
||||
.input("tests/squeeze/squeeze_no_axes.onnx")
|
||||
.input("tests/sub/sub.onnx")
|
||||
.input("tests/sub/sub_shape.onnx")
|
||||
.input("tests/sub/sub_broadcast.onnx")
|
||||
.input("tests/sub/sub_int.onnx")
|
||||
.input("tests/sub/sub_shape_tensor.onnx")
|
||||
.input("tests/sum/sum.onnx")
|
||||
.input("tests/sum/sum_int.onnx")
|
||||
.input("tests/tan/tan.onnx")
|
||||
.input("tests/tanh/tanh.onnx")
|
||||
.input("tests/tile/tile.onnx")
|
||||
.input("tests/topk/topk.onnx")
|
||||
.input("tests/trilu/trilu_upper.onnx")
|
||||
.input("tests/trilu/trilu_lower.onnx")
|
||||
.input("tests/transpose/transpose.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze_runtime_axes.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze_like.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze_int_to_shape.onnx")
|
||||
.input("tests/unsqueeze/squeeze_unsqueeze_roundtrip.onnx")
|
||||
.input("tests/split/split.onnx")
|
||||
.input("tests/xor/xor.onnx")
|
||||
.input("tests/xor/xor_scalar.onnx")
|
||||
.input("tests/xor/xor_broadcast.onnx")
|
||||
// If operator tests
|
||||
.input("tests/if_op/if_conv2d.onnx")
|
||||
.input("tests/if_op/if_linear.onnx")
|
||||
.input("tests/if_op/nested_if.onnx")
|
||||
// Loop operator tests
|
||||
.input("tests/loop/loop_simple.onnx")
|
||||
.input("tests/loop/loop_dynamic_cond.onnx")
|
||||
.input("tests/loop/loop_multi_deps.onnx")
|
||||
.input("tests/loop/loop_nested.onnx")
|
||||
.input("tests/loop/loop_scan_outputs.onnx")
|
||||
// Scan operator tests
|
||||
.input("tests/scan/scan_cumsum.onnx")
|
||||
.input("tests/scan/scan_reverse.onnx")
|
||||
.input("tests/scan/scan_multi_state.onnx")
|
||||
.input("tests/scan/scan_axis1.onnx")
|
||||
// Subgraph tests: nested control flow and outer-scope references
|
||||
.input("tests/subgraph/nested_if_loop_if.onnx")
|
||||
.input("tests/subgraph/nested_if_loop_if_scan.onnx")
|
||||
.input("tests/subgraph/outer_scope_ref.onnx")
|
||||
.input("tests/subgraph/outer_scope_multi_var.onnx")
|
||||
.input("tests/subgraph/outer_scope_loop.onnx")
|
||||
.input("tests/subgraph/outer_scope_scan.onnx")
|
||||
.input("tests/subgraph/outer_scope_constant.onnx")
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
|
||||
// Note: Previous record type variants (NamedMpk, PrettyJson, Bincode, etc.)
|
||||
// have been removed. All models now use burnpack format exclusively.
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
[project]
|
||||
name = "onnx-tests"
|
||||
version = "0.1.0"
|
||||
description = "project for testing ONNX support"
|
||||
authors = []
|
||||
dependencies = [
|
||||
"torch>=2.3.1",
|
||||
"onnx>=1.16.1",
|
||||
"onnxruntime>=1.18.0",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.8"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = []
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/onnx_tests"]
|
||||
@@ -1,75 +0,0 @@
|
||||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
|
||||
-e file:.
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
filelock==3.15.1
|
||||
# via torch
|
||||
flatbuffers==24.3.25
|
||||
# via onnxruntime
|
||||
fsspec==2024.6.0
|
||||
# via torch
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
jinja2==3.1.4
|
||||
# via torch
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
networkx==3.3
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.5.40
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
onnx==1.16.1
|
||||
# via onnx-tests
|
||||
onnxruntime==1.18.0
|
||||
# via onnx-tests
|
||||
packaging==24.1
|
||||
# via onnxruntime
|
||||
protobuf==5.27.1
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
sympy==1.12.1
|
||||
# via onnxruntime
|
||||
# via torch
|
||||
torch==2.3.1
|
||||
# via onnx-tests
|
||||
typing-extensions==4.12.2
|
||||
# via torch
|
||||
@@ -1,75 +0,0 @@
|
||||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
|
||||
-e file:.
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
filelock==3.15.1
|
||||
# via torch
|
||||
flatbuffers==24.3.25
|
||||
# via onnxruntime
|
||||
fsspec==2024.6.0
|
||||
# via torch
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
jinja2==3.1.4
|
||||
# via torch
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
networkx==3.3
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.5.40
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
onnx==1.16.1
|
||||
# via onnx-tests
|
||||
onnxruntime==1.18.0
|
||||
# via onnx-tests
|
||||
packaging==24.1
|
||||
# via onnxruntime
|
||||
protobuf==5.27.1
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
sympy==1.12.1
|
||||
# via onnxruntime
|
||||
# via torch
|
||||
torch==2.3.1
|
||||
# via onnx-tests
|
||||
typing-extensions==4.12.2
|
||||
# via torch
|
||||
@@ -1 +0,0 @@
|
||||
#![no_std]
|
||||
@@ -1,16 +0,0 @@
|
||||
pytorch2.8.0:m
|
||||
|
||||
onnx::Abs_01/Abs"Abs
|
||||
main_graphZ%
|
||||
onnx::Abs_0
|
||||
|
||||
|
||||
|
||||
|
||||
b
|
||||
1
|
||||
|
||||
|
||||
|
||||
|
||||
B
|
||||
@@ -1,40 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/abs/abs.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.abs(x)
|
||||
|
||||
|
||||
def main():
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "abs.onnx"
|
||||
test_input = torch.tensor([[[[-1.0, -4.0, 9.0, -25.0]]]], device=device)
|
||||
|
||||
torch.onnx.export(model, (test_input), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
print("Test input data: {}".format(test_input))
|
||||
output = model.forward(test_input)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,27 +0,0 @@
|
||||
// Import the shared macro
|
||||
use crate::include_models;
|
||||
include_models!(abs);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn::tensor::{Tensor, TensorData, Tolerance, ops::FloatElem};
|
||||
|
||||
use crate::backend::TestBackend;
|
||||
type FT = FloatElem<TestBackend>;
|
||||
|
||||
#[test]
|
||||
fn abs() {
|
||||
let device = Default::default();
|
||||
let model: abs::Model<TestBackend> = abs::Model::new(&device);
|
||||
|
||||
let input = Tensor::<TestBackend, 4>::from_floats([[[[-1.0, -4.0, 9.0, -25.0]]]], &device);
|
||||
|
||||
let output = model.forward(input);
|
||||
let expected = TensorData::from([[[[1.0f32, 4.0, 9.0, 25.0]]]]);
|
||||
|
||||
output
|
||||
.to_data()
|
||||
.assert_approx_eq::<FT>(&expected, Tolerance::default());
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@@ -1,57 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/add/add.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
# Declare a constant float tensor with ones
|
||||
self.a = torch.ones(1, 1, 1, 4)
|
||||
|
||||
# Declare a scalar
|
||||
self.b = 5.0
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, k):
|
||||
|
||||
# Add a tensor input and a constant tensor
|
||||
x = x + self.a
|
||||
|
||||
# Add a scalar constant and a scalar input
|
||||
d = self.b + k
|
||||
|
||||
# Add a tensor and a scalar
|
||||
x = x + d
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "add.onnx"
|
||||
dummy_input = torch.randn(1, 2, 3, 4, device=device)
|
||||
|
||||
scalar = 2.0
|
||||
|
||||
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
|
||||
|
||||
print("Test input data: {}, {}".format(test_input, scalar))
|
||||
output = model.forward(test_input, scalar)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user