Move ONNX crates to burn-onnx repository (#4393)

* Move ONNX inference MNIST example to burn-onnx

Moving to https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference

* Move Raspberry Pi Pico example project to burn-onnx

Moving to https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico

* Move image-classification-web example to burn-onnx

Moving to https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web

* Refactor import-model-weights example to use burn-store

Replace burn-import with burn-store for loading PyTorch and Safetensors
model weights. Convert from NamedMpk (.mpk) format to Burnpack (.bpk)
format for native Burn model storage.

- Use PytorchStore and SafetensorsStore for weight loading
- Use BurnpackStore for saving/loading converted models
- Rename namedmpk binary to burnpack
- Add PyTorchToBurnAdapter for Safetensors files exported from PyTorch
- Update inference to accept Model directly instead of ModelRecord

* Remove ONNX to Burn development guide from contributor-book

The guide has been moved to the burn-onnx repository:
https://github.com/tracel-ai/burn-onnx/blob/main/DEVELOPMENT-GUIDE.md

* Update Cargo.lock

* Move pytorch-tests and safetensors-tests from burn-import to burn-store

Migrate test directories from crates/burn-import/ to crates/burn-store/
and update all tests to use the new burn-store API.

Changes:
- Update Cargo.toml dependencies from burn-import to burn-store
- Replace PyTorchFileRecorder/SafetensorsFileRecorder with PytorchStore/SafetensorsStore
- Convert record-based loading to direct model loading via ModuleSnapshot::load_from()
- Convert LoadArgs options to fluent builder pattern (.with_key_remapping(), etc.)
- Fix model configurations to match actual PyTorch weight file dimensions
- Add init() methods for models that previously only had new_with(record)

All 37 pytorch tests and 1 safetensors test pass.
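
As a rough illustration of what key remapping means conceptually, the sketch below rewrites PyTorch-style tensor keys using plain prefix rules. It is not burn-store's implementation: `remap_key`, the rule table, and the key names are hypothetical and use only the standard library. The real `.with_key_remapping()` builder is known here only by name from the list above.

```rust
/// Hypothetical helper (not part of burn-store): rewrites a tensor key by
/// applying the first rule whose old prefix matches the key.
fn remap_key(key: &str, rules: &[(&str, &str)]) -> String {
    for (old_prefix, new_prefix) in rules {
        // `strip_prefix` returns the remainder only when the prefix matches.
        if let Some(rest) = key.strip_prefix(*old_prefix) {
            return format!("{new_prefix}{rest}");
        }
    }
    // No rule matched: keep the key unchanged.
    key.to_string()
}

fn main() {
    // Illustrative rules: source-side names on the left, target names on the right.
    let rules = [("features.0.", "conv1."), ("classifier.", "fc.")];
    for key in ["features.0.weight", "classifier.bias", "norm.gamma"] {
        println!("{key} -> {}", remap_key(key, &rules));
    }
}
```

Keys that match no rule pass through untouched, so a partial rule table is safe in this sketch.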

* Fix test file paths after moving test directories to burn-store

Update paths in burn-store/src tests to reference the new locations
of pytorch-tests and safetensors-tests directories.

* Add migration guide for burn-import to burn-store

Document migration path from deprecated PyTorchFileRecorder and
SafetensorsFileRecorder to the new PytorchStore and SafetensorsStore APIs.
Cover API mapping, code examples, and common migration issues.

Include details on printing LoadResult for debugging and helpful suggestions.

* Remove burn-import crate (moved to burn-onnx repo)

The burn-import crate has been moved to a separate repository:
https://github.com/tracel-ai/burn-onnx/tree/main/crates/burn-import

For loading PyTorch and SafeTensors model weights, use burn-store instead
with PytorchStore and SafetensorsStore.

* Replace pytorch/safetensors docs with unified model-weights page

Consolidate PyTorch and SafeTensors model import documentation into a
single comprehensive page covering burn-store usage. The new page covers:

- All supported formats (Burnpack, SafeTensors, PyTorch)
- Loading and saving workflows
- Advanced features (filtering, remapping, partial loading, zero-copy)
- API reference and troubleshooting

* Simplify burn-store README and point to Burn Book

* Consolidate model weights docs into saving-and-loading page

Merge the model-weights documentation into the main saving-and-loading
page, providing a single comprehensive guide for all model persistence
operations using burn-store.

- Remove separate import/model-weights.md page
- Update saving-and-loading.md with full burn-store documentation
- Remove unused SVG images from import folder
- Update navigation in SUMMARY.md and import/README.md

* Move ONNX import to standalone section, point to burn-onnx repo

- Create new onnx-import.md as standalone top-level section
- Update links to point to burn-onnx repo (github.com/tracel-ai/burn-onnx)
- Remove import/ folder (no longer needed)
- Streamline documentation with quick start focus

* Restore full ONNX import documentation content

Restored detailed content including:
- Understanding ONNX section
- Burn's ONNX Support advantages
- ONNX compatibility and opset version guidance
- Step-by-step guide with code examples
- Advanced configuration options
- Troubleshooting section
- Examples and resources with links to burn-onnx repo

* Update ONNX examples links to burn-onnx repo

* Update ONNX-related example links in examples.md to burn-onnx repo

* Fix Burnpack extension: .burnpack -> .bpk

* Remove ONNX Tests README reference from docs

* Improve table formatting in ONNX import and saving docs

Reformats markdown tables in onnx-import.md and saving-and-loading.md for better readability and consistency. No content changes, only improved alignment and line breaks.

* Remove onnx-ir and onnx-ir-derive crates

Moved to https://github.com/tracel-ai/burn-onnx

* Remove burn-onnx crate

Moved to https://github.com/tracel-ai/burn-onnx

* Remove ONNX references from workspace config and README

- Remove burn-onnx entries from Cargo.toml workspace members/exclude
- Remove RUSTSEC-2024-0437 (protobuf) from audit.toml
- Update README ONNX section to point to burn-onnx repo
- Update README model loading section to point to new docs
- Remove ONNX example from README (moved to burn-onnx)

* Clean up remaining ONNX references

- Remove ONNX publish jobs from publish.yml
- Update semver-checks exclude list
- Update burn-book overview.md links
- Update burn-book no-std.md links to burn-onnx repo
- Remove ONNX proto license from NOTICES.md
- Update .gitignore comment

* Remove unused protobuf and rust-format dependencies

* Remove onnx-tests example from contributor-book

* Update Cargo.lock

* Format long line in test for readability

Split a long line in the should_fail_if_struct_field_is_missing test to improve code readability. No functional changes were made.

* Restore Record-based saving/loading sections per PR feedback

* Explain burn-store motivation in saving-and-loading docs

* Remove em dash from burn-store intro

* Format long line in test for readability

Author: Dilshod Tadjibaev
Date: 2026-01-28 14:38:11 -06:00
Committed by GitHub
Parent: 1cda8e14f0
Commit: 933fdf4f69
1428 changed files with 1256 additions and 124595 deletions

@@ -8,7 +8,6 @@
 [advisories]
 ignore = [
-"RUSTSEC-2024-0437", # Protobuf used in ONNX graph parsing.
 "RUSTSEC-2024-0436", # Paste used to generate macro, should be removed at some point.
 "RUSTSEC-2025-0119", # `number_prefix` used by `tokenizers`, only in the examples.
 "RUSTSEC-2025-0141", # `bincode` is no longer maintained.

@@ -386,49 +386,6 @@ jobs:
 secrets:
 CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
-publish-burn-onnx:
-needs:
-- publish-burn
-- publish-onnx-ir
-- publish-burn-store
-uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
-with:
-crate: burn-onnx
-dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
-secrets:
-CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
-publish-burn-import:
-needs:
-- publish-burn
-- publish-burn-store
-- publish-burn-onnx
-uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
-with:
-crate: burn-import
-dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
-secrets:
-CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
-publish-onnx-ir-derive:
-uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
-with:
-crate: onnx-ir-derive
-dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
-secrets:
-CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
-publish-onnx-ir:
-needs:
-- publish-burn-tensor
-- publish-onnx-ir-derive
-uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v7
-with:
-crate: onnx-ir
-dry-run-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run-only || false }}
-secrets:
-CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
 publish-burn-store:
 needs:
 - publish-burn-core

@@ -26,4 +26,4 @@ jobs:
 # publishes on crates.io with `default-features`
 feature-group: default-features
 # Exclude crates which are not published on crates.io
-exclude: burn-no-std-tests,onnx-tests,pytorch-tests
+exclude: burn-no-std-tests,pytorch-tests,safetensors-tests

2
.gitignore vendored

@@ -10,7 +10,7 @@ target
 .fleet
 .ipynb_checkpoints/
-# Generated IR and Burn Graph from ONNX
+# Build output directory
 out
 # Virtual Environment of Python

300
Cargo.lock generated

@@ -463,7 +463,7 @@ dependencies = [
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease 0.2.37",
"prettyplease",
"proc-macro2",
"quote",
"regex",
@@ -610,7 +610,7 @@ checksum = "89ec27229c38ed0eb3c0feee3d2c1d6a4379ae44f418a29a658890e062d8f365"
dependencies = [
"darling 0.23.0",
"ident_case",
"prettyplease 0.2.37",
"prettyplease",
"proc-macro2",
"quote",
"rustversion",
@@ -952,23 +952,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "burn-import"
version = "0.21.0"
dependencies = [
"burn",
"burn-ndarray",
"burn-onnx",
"burn-store",
"candle-core",
"derive-new",
"regex",
"serde",
"serde_json",
"thiserror 2.0.17",
"zip 7.2.0",
]
[[package]]
name = "burn-ir"
version = "0.21.0"
@@ -1033,25 +1016,6 @@ dependencies = [
"burn-store",
]
[[package]]
name = "burn-onnx"
version = "0.21.0"
dependencies = [
"burn",
"burn-ndarray",
"burn-store",
"derive-new",
"insta",
"log",
"onnx-ir",
"proc-macro2",
"quote",
"rust-format",
"syn 2.0.114",
"tracing-core",
"tracing-subscriber",
]
[[package]]
name = "burn-optim"
version = "0.21.0"
@@ -1974,7 +1938,7 @@ dependencies = [
"document-features",
"mio",
"parking_lot",
"rustix 1.1.3",
"rustix",
"signal-hook 0.3.18",
"signal-hook-mio",
"winapi",
@@ -2240,7 +2204,7 @@ dependencies = [
"darling 0.21.3",
"derive-new",
"ident_case",
"prettyplease 0.2.37",
"prettyplease",
"proc-macro2",
"quote",
"syn 2.0.114",
@@ -2811,12 +2775,6 @@ version = "1.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abd57806937c9cc163efc8ea3910e00a62e2aeb0b8119f1793a978088f8f6b04"
[[package]]
name = "diff"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
[[package]]
name = "digest"
version = "0.10.7"
@@ -3475,7 +3433,7 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4"
dependencies = [
"rustix 1.1.3",
"rustix",
"windows-sys 0.59.0",
]
@@ -4513,27 +4471,6 @@ dependencies = [
"zune-jpeg 0.5.8",
]
[[package]]
name = "image-classification-web"
version = "0.21.0"
dependencies = [
"burn",
"burn-candle",
"burn-onnx",
"burn-store",
"console_error_panic_hook",
"getrandom 0.3.4",
"js-sys",
"log",
"serde",
"serde-wasm-bindgen",
"serde_json",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-logger",
"web-time",
]
[[package]]
name = "image-webp"
version = "0.2.4"
@@ -4555,7 +4492,7 @@ name = "import-model-weights"
version = "0.21.0"
dependencies = [
"burn",
"burn-import",
"burn-store",
]
[[package]]
@@ -4614,18 +4551,6 @@ dependencies = [
"generic-array",
]
[[package]]
name = "insta"
version = "1.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b66886d14d18d420ab5052cbff544fc5d34d0b2cdd35eb5976aaa10a4a472e5"
dependencies = [
"console 0.15.11",
"once_cell",
"similar",
"tempfile",
]
[[package]]
name = "instability"
version = "0.3.11"
@@ -4895,12 +4820,6 @@ dependencies = [
"bitflags 2.10.0",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
@@ -5803,59 +5722,6 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "onnx-inference"
version = "0.21.0"
dependencies = [
"burn",
"burn-onnx",
"burn-store",
"serde",
]
[[package]]
name = "onnx-ir"
version = "0.21.0"
dependencies = [
"burn-tensor",
"bytemuck",
"bytes",
"derive-new",
"divan",
"half",
"log",
"memmap2",
"onnx-ir-derive",
"pretty_assertions",
"protobuf",
"protobuf-codegen",
"regex",
"rstest",
"serde",
"strum",
"tempfile",
]
[[package]]
name = "onnx-ir-derive"
version = "0.21.0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "onnx-tests"
version = "0.21.0"
dependencies = [
"burn",
"burn-onnx",
"burn-store",
"float-cmp",
"serde",
]
[[package]]
name = "openblas-build"
version = "0.10.14"
@@ -6884,26 +6750,6 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa"
[[package]]
name = "pretty_assertions"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
dependencies = [
"diff",
"yansi",
]
[[package]]
name = "prettyplease"
version = "0.1.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86"
dependencies = [
"proc-macro2",
"syn 1.0.109",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
@@ -6974,58 +6820,6 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "protobuf"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d65a1d4ddae7d8b5de68153b48f6aa3bba8cb002b243dbdbc55a5afbc98f99f4"
dependencies = [
"bytes",
"once_cell",
"protobuf-support",
"thiserror 1.0.69",
]
[[package]]
name = "protobuf-codegen"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d3976825c0014bbd2f3b34f0001876604fe87e0c86cd8fa54251530f1544ace"
dependencies = [
"anyhow",
"once_cell",
"protobuf",
"protobuf-parse",
"regex",
"tempfile",
"thiserror 1.0.69",
]
[[package]]
name = "protobuf-parse"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4aeaa1f2460f1d348eeaeed86aea999ce98c1bded6f089ff8514c9d9dbdc973"
dependencies = [
"anyhow",
"indexmap",
"log",
"protobuf",
"protobuf-support",
"tempfile",
"thiserror 1.0.69",
"which",
]
[[package]]
name = "protobuf-support"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e36c2f31e0a47f9280fb347ef5e461ffcd2c52dd520d8e216b52f93b0b0d7d6"
dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "psm"
version = "0.1.28"
@@ -7077,8 +6871,8 @@ version = "0.21.0"
dependencies = [
"burn",
"burn-autodiff",
"burn-import",
"burn-ndarray",
"burn-store",
"float-cmp",
"serde",
]
@@ -7708,17 +7502,6 @@ dependencies = [
"smallvec",
]
[[package]]
name = "rust-format"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60e7c00b6c3bf5e38a880eec01d7e829d12ca682079f8238a464def3c4b31627"
dependencies = [
"prettyplease 0.1.25",
"proc-macro2",
"syn 1.0.109",
]
[[package]]
name = "rustc-demangle"
version = "0.1.26"
@@ -7746,19 +7529,6 @@ dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
]
[[package]]
name = "rustix"
version = "1.1.3"
@@ -7768,7 +7538,7 @@ dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys 0.11.0",
"linux-raw-sys",
"windows-sys 0.61.2",
]
@@ -7869,8 +7639,8 @@ version = "0.21.0"
dependencies = [
"burn",
"burn-autodiff",
"burn-import",
"burn-ndarray",
"burn-store",
"float-cmp",
"serde",
]
@@ -7990,17 +7760,6 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-wasm-bindgen"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b"
dependencies = [
"js-sys",
"serde",
"wasm-bindgen",
]
[[package]]
name = "serde_bytes"
version = "0.11.19"
@@ -8231,12 +7990,6 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
[[package]]
name = "similar"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
[[package]]
name = "simple-regression"
version = "0.21.0"
@@ -8579,7 +8332,7 @@ dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix 1.1.3",
"rustix",
"windows-sys 0.61.2",
]
@@ -8598,7 +8351,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b8cb979cb11c32ce1603f8137b22262a9d131aaa5c37b5678025f22b8becd0"
dependencies = [
"rustix 1.1.3",
"rustix",
"windows-sys 0.60.2",
]
@@ -9818,17 +9571,6 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-logger"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "074649a66bb306c8f2068c9016395fa65d8e08d2affcbf95acf3c24c3ab19718"
dependencies = [
"log",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
@@ -10125,18 +9867,6 @@ dependencies = [
"web-sys",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.44",
]
[[package]]
name = "widestring"
version = "1.2.1"
@@ -10489,7 +10219,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156"
dependencies = [
"libc",
"rustix 1.1.3",
"rustix",
]
[[package]]
@@ -10520,12 +10250,6 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5a4b21e1a62b67a2970e6831bc091d7b87e119e7f9791aef9702e3bef04448"
[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "yoke"
version = "0.7.5"

@@ -6,9 +6,8 @@ resolver = "2"
 members = [
 "crates/*",
-"crates/burn-import/pytorch-tests",
-"crates/burn-import/safetensors-tests",
-"crates/burn-onnx/onnx-tests",
+"crates/burn-store/pytorch-tests",
+"crates/burn-store/safetensors-tests",
 "crates/burn-collective/multinode-tests",
 "examples/*",
 "xtask",
@@ -17,7 +16,6 @@ members = [
 exclude = [
 "examples/notebook",
 "examples/raspberry-pi-pico",
-"crates/burn-onnx/model-checks/*", # exclude model checking crates from workspace
 ]
 [workspace.package]
@@ -75,8 +73,6 @@ planus = { version = "=1.1" }
 polars = { version = "0.51.0", features = ["lazy"] }
 pretty_assertions = "1.4.1"
 proc-macro2 = "1.0.106"
-protobuf = "3.7.2"
-protobuf-codegen = "3.7.2"
 quote = "1.0.42"
 r2d2 = "0.8.10"
 r2d2_sqlite = "0.31.0"
@@ -91,7 +87,6 @@ reqwest = { version = "0.12.23", default-features = false, features = [
 rmp-serde = { version = "1.3.1", default-features = false }
 rstest = "0.26.1"
 rusqlite = "0.37.0"
-rust-format = "0.3.4"
 sanitize-filename = "0.6.0"
 serde_bytes = { version = "0.11.18", default-features = false, features = [
 "alloc",

@@ -38,214 +38,6 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
## ONNX
**Source**: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
License: Apache License 2.0
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
## wgpu
**Source:** https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml

@@ -251,16 +251,17 @@ ONNX Support 🐫
 </summary>
 <br />
-Burn supports importing ONNX (Open Neural Network Exchange) models, allowing you to easily port
-models from TensorFlow or PyTorch to Burn. The ONNX model is converted into Rust code that uses
-Burn's native APIs, enabling the imported model to run on any Burn backend (CPU, GPU, WebAssembly)
-and benefit from all of Burn's optimizations like automatic kernel fusion.
+Burn supports importing ONNX (Open Neural Network Exchange) models through the
+[burn-onnx](https://github.com/tracel-ai/burn-onnx) crate, allowing you to easily port models from
+TensorFlow or PyTorch to Burn. The ONNX model is converted into Rust code that uses Burn's native
+APIs, enabling the imported model to run on any Burn backend (CPU, GPU, WebAssembly) and benefit
+from all of Burn's optimizations like automatic kernel fusion.
 Our ONNX support is further described in
-[this section of the Burn Book 🔥](https://burn.dev/books/burn/import/onnx-model.html).
+[this section of the Burn Book 🔥](https://burn.dev/books/burn/onnx-import.html).
 > **Note**: This crate is in active development and currently supports a
-> [limited set of ONNX operators](./crates/burn-onnx/SUPPORTED-ONNX-OPS.md).
+> [limited set of ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).
 </details>
@@ -274,10 +275,8 @@ You can load weights from PyTorch or Safetensors formats directly into your Burn
 This makes it easy to reuse existing models while benefiting from Burn's performance and deployment
 features.
-Learn more:
-- [Import pre-trained PyTorch models into Burn](https://burn.dev/books/burn/import/pytorch-model.html)
-- [Load models from Safetensors format](https://burn.dev/books/burn/import/safetensors-model.html)
+Learn more in the [Saving & Loading Models](https://burn.dev/books/burn/saving-and-loading.html)
+section of the Burn Book.
 </details>
@@ -423,8 +422,6 @@ Additional examples:
 `Learner` configured to log metrics and keep training checkpoints.
 - [Named Tensor](./examples/named-tensor) : Performs operations with the experimental `NamedTensor`
 feature.
-- [ONNX Import Inference](./examples/onnx-inference) : Imports an ONNX model pre-trained on MNIST to
-perform inference on a sample image with Burn.
 - [PyTorch Import Inference](./examples/import-model-weights) : Imports a PyTorch model pre-trained
 on MNIST to perform inference on a sample image with Burn.
 - [Text Classification](./examples/text-classification) : Trains a text classification transformer

@@ -27,10 +27,7 @@
- [Distributed Computing](./performance/distributed-computing.md)
- [Custom Training Loop](./custom-training-loop.md)
- [Saving & Loading Models](./saving-and-loading.md)
- [Importing Models](./import/README.md)
- [ONNX Model](./import/onnx-model.md)
- [PyTorch Model](./import/pytorch-model.md)
- [Safetensors Model](./import/safetensors-model.md)
- [ONNX Import](./onnx-import.md)
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md)
- [Advanced](./advanced/README.md)
- [Backend Extension](./advanced/backend-extension/README.md)

@@ -1,7 +1,9 @@
# No Standard Library
In this section, you will learn how to run an onnx inference model on an embedded system, with no standard library support on a Raspberry Pi Pico 2. This should be universally applicable to other platforms. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/raspberry-pi-pico).
In this section, you will learn how to run an ONNX inference model on an embedded system, with no
standard library support on a Raspberry Pi Pico 2. This should be universally applicable to other
platforms. All the code can be found in the
[burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico).
## Step-by-Step Guide
@@ -31,7 +33,7 @@ burn-onnx = { version = "0.21" } # Used to auto generate the rust code to import
```
### Import the Model
Follow the directions to [import models](../import/README.md).
Follow the directions in [ONNX Import](../onnx-import.md).
Use the following `ModelGen` config:
```rs

@@ -77,11 +77,11 @@ The following additional examples are currently available if you want to check t
| [Regression](https://github.com/tracel-ai/burn/tree/main/examples/simple-regression) | Trains a simple MLP on the California Housing dataset to predict the median house value for a district. |
| [Custom Image Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-image-dataset) | Trains a simple CNN on custom image dataset following a simple folder structure. |
| [Custom Renderer](https://github.com/tracel-ai/burn/tree/main/examples/custom-renderer) | Implements a custom renderer to display the [`Learner`](./building-blocks/learner.md) progress. |
| [Image Classification Web](https://github.com/tracel-ai/burn/tree/main/examples/image-classification-web) | Image classification web browser demo using Burn, WGPU and WebAssembly. |
| [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) | Image classification web browser demo using Burn, WGPU and WebAssembly. |
| [MNIST Inference on Web](https://github.com/tracel-ai/burn/tree/main/examples/mnist-inference-web) | An interactive MNIST inference demo in the browser. The demo is available [online](https://burn.dev/demo/). |
| [MNIST Training](https://github.com/tracel-ai/burn/tree/main/examples/mnist) | Demonstrates how to train a custom [`Module`](./building-blocks/module.md) (MLP) with the [`Learner`](./building-blocks/learner.md) configured to log metrics and keep training checkpoints. |
| [Named Tensor](https://github.com/tracel-ai/burn/tree/main/examples/named-tensor) | Performs operations with the experimental `NamedTensor` feature. |
| [ONNX Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference) | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn. |
| [ONNX Import Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference) | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn. |
| [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/import-model-weights) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. |
| [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
| [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. |

@@ -1,13 +0,0 @@
# Importing Models
Burn supports importing models from other frameworks and file formats, enabling you to use pre-trained weights in your Burn applications.
## Supported Formats
Burn currently supports three primary model import formats:
| Format | Description | Use Case |
|--------|-------------|----------|
| [**ONNX**](./onnx-model.md) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
| [**PyTorch**](./pytorch-model.md) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
| [**Safetensors**](./safetensors-model.md) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" id="export" class="canvas" preserveAspectRatio="xMidYMid meet" style="" width="112" height="124"><rect id="background" fill="#fff" pointer-events="all" width="112" height="124"/><g id="origin" transform="translate(5.077343750000001, 5.077343750000001) scale(1)"><g id="clusters" class="clusters"/><g id="edge-paths" class="edge-paths"><defs><marker id="arrowhead" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(0, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-select" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(238, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-hover" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker></defs></g><g id="edge-labels" class="edge-labels"/><g id="nodes" class="nodes"><g id="node-name-conv1" class="node graph-node" transform="translate(0,60)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv1</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v27a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-27a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: rgb(255, 
255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><g class="node-attribute"><text xml:space="preserve" x="6" y="26" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2]</title><tspan style="font-weight: bold;">bias</tspan><tspan>〈2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v43a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-43a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><g id="node-name-conv2" class="node graph-node" transform="translate(0,0)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv2</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v14a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 
-5,-5v-14a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: rgb(255, 255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v30a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-30a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g></g></g></svg>

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" id="export" class="canvas" preserveAspectRatio="xMidYMid meet" style="" width="112" height="124"><rect id="background" fill="#fff" pointer-events="all" width="112" height="124"/><g id="origin" transform="translate(5.077343750000001, 5.077343750000001) scale(1)"><g id="clusters" class="clusters"/><g id="edge-paths" class="edge-paths"><defs><marker id="arrowhead" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(0, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-select" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto" style="fill: rgb(238, 0, 0);"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker><marker id="arrowhead-hover" viewBox="0 0 10 10" refX="9" refY="5" markerUnits="strokeWidth" markerWidth="8" markerHeight="6" orient="auto"><path d="M 0 0 L 10 5 L 0 10 L 4 5 z" style="stroke-width: 1;"/></marker></defs></g><g id="edge-labels" class="edge-labels"/><g id="nodes" class="nodes"><g id="node-name-conv.conv1" class="node graph-node" transform="translate(0,60)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv.conv1</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v27a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-27a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: 
rgb(255, 255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><g class="node-attribute"><text xml:space="preserve" x="6" y="26" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2]</title><tspan style="font-weight: bold;">bias</tspan><tspan>〈2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v43a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-43a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><g id="node-name-conv.conv2" class="node graph-node" transform="translate(0,0)" style=""><g class="node-item node-item-type" transform="translate(0,0)"><path d="M5,0h91.546875a5,5 0 0 1 5,5v16a0,0 0 0 1 0,0h-101.546875a0,0 0 0 1 0,0v-16a5,5 0 0 1 5,-5z" style="stroke: rgb(0, 0, 0); fill: rgb(0, 0, 0); stroke-width: 0;"/><text x="6" y="15" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 11px; text-rendering: geometricprecision; user-select: none; fill: rgb(255, 255, 255);">conv.conv2</text><title>?</title></g><g class="node-attribute-list" transform="translate(0,21)"><path d="M0,0h101.546875a0,0 0 0 1 0,0v14a5,5 0 0 1 
-5,5h-91.546875a5,5 0 0 1 -5,-5v-14a0,0 0 0 1 0,0z" style="stroke: rgb(0, 0, 0); fill: rgb(255, 255, 255); stroke-width: 0;"/><g class="node-attribute"><text xml:space="preserve" x="6" y="13" style="font-family: -apple-system, BlinkMacSystemFont, &quot;Segoe WPC&quot;, &quot;Segoe UI&quot;, Ubuntu, &quot;Droid Sans&quot;, sans-serif, &quot;PingFang SC&quot;; font-size: 9px; font-weight: normal; text-rendering: geometricprecision; user-select: none;"><title>float32[2,2,2,2]</title><tspan style="font-weight: bold;">weight</tspan><tspan>〈2×2×2×2〉</tspan></text></g><line class="node" x1="0" x2="101.546875" y1="0" y2="0" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g><path class="node node-border" d="M5,0h91.546875a5,5 0 0 1 5,5v30a5,5 0 0 1 -5,5h-91.546875a5,5 0 0 1 -5,-5v-30a5,5 0 0 1 5,-5z" style="stroke: rgb(51, 51, 51); fill: none; stroke-width: 1px;"/></g></g></g></svg>

@@ -1,344 +0,0 @@
# PyTorch Model
## Introduction
Burn supports importing model weights from PyTorch, whether you've trained your model in PyTorch or
want to use a pre-trained model. Burn supports importing PyTorch model weights with `.pt` and
`.safetensors` file extensions. Compared to ONNX models, these files only contain the weights of the
model, so you will need to reconstruct the model architecture in Burn.
This guide demonstrates the complete workflow for exporting models from PyTorch and importing them
into Burn. You can also refer to this
[Transitioning From PyTorch to Burn](https://dev.to/laggui/transitioning-from-pytorch-to-burn-45m)
tutorial for importing a more complex model.
## Exporting Models to PyTorch Format
To export a PyTorch model correctly, you need to save only the model weights (state_dict) using the
`torch.save` function, not the entire model.
### Example: Exporting a PyTorch Model
```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2,2))
        self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


if __name__ == "__main__":
    # Set seed for reproducibility
    torch.manual_seed(42)

    # Initialize model and ensure it's on CPU
    model = Net().to(torch.device("cpu"))

    # Extract model weights dictionary
    model_weights = model.state_dict()

    # Save only the weights, not the entire model
    torch.save(model_weights, "conv2d.pt")
```
If you accidentally save the entire model instead of just the weights, you may encounter errors
during import like:
```
Failed to decode foobar: DeserializeError("Serde error: other error:
Missing source values for the 'foo1' field of type 'BarRecordItem'.
Please verify the source data and ensure the field name is correct")
```
### Verifying the Export
You can verify your exported model by viewing the `.pt` file in
[Netron](https://github.com/lutzroeder/netron), a neural network visualization tool. A properly
exported weights file will show a flat structure of tensors, while an incorrectly exported file will
display nested blocks representing the entire model architecture.
When viewing the exported model in Netron, you should see something like this:
![conv2d.svg](./conv2d.svg)
## Importing PyTorch Models into Burn
Importing a PyTorch model into Burn involves two main steps:
1. Defining the model architecture in Burn
2. Loading the weights from the exported PyTorch model
### Step 1: Define the Model in Burn
First, you need to create a Burn model that matches the architecture of the model you exported:
```rust
use burn::{
    nn::conv::{Conv2d, Conv2dConfig},
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model.
    pub fn init(device: &B::Device) -> Self {
        let conv1 = Conv2dConfig::new([2, 2], [2, 2])
            .init(device);
        let conv2 = Conv2dConfig::new([2, 2], [2, 2])
            .with_bias(false)
            .init(device);
        Self { conv1, conv2 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv1.forward(x);
        self.conv2.forward(x)
    }
}
```
### Step 2: Load the Weights
You have two options for loading the weights:
#### Option A: Load Dynamically at Runtime
This approach loads the PyTorch file directly at runtime, requiring the `burn-import` dependency:
```rust
use crate::model;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

type Backend = burn_ndarray::NdArray<f32>;

fn main() {
    let device = Default::default();

    // Load weights from PyTorch file
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("./conv2d.pt".into(), &device)
        .expect("Should decode state successfully");

    // Initialize model and load weights
    let model = model::Net::<Backend>::init(&device).load_record(record);
}
```
#### Option B: Pre-convert to Burn's Binary Format
This approach converts the PyTorch file to Burn's optimized binary format during build time,
removing the runtime dependency on `burn-import`:
```rust
// This code would go in build.rs or a separate tool
use crate::model::{Net, NetRecord};
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

type Backend = burn_ndarray::NdArray<f32>;

fn convert_model() {
    let device = Default::default();

    // Load from PyTorch. The record type is annotated because nothing else in
    // this function constrains it; `NetRecord` is generated by
    // `#[derive(Module)]` on `Net`.
    let recorder = PyTorchFileRecorder::<FullPrecisionSettings>::default();
    let record: NetRecord<Backend> = recorder
        .load("./conv2d.pt".into(), &device)
        .expect("Should decode state successfully");

    // Save to Burn's binary format
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
    recorder
        .record(record, "model.mpk".into())
        .expect("Failed to save model record");
}

// In your application code
fn load_model() -> Net<Backend> {
    let device = Default::default();

    // Load from Burn's binary format
    let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
        .load("./model.mpk".into(), &device)
        .expect("Should decode state successfully");

    Net::<Backend>::init(&device).load_record(record)
}
```
> **Note**: For examples of pre-converting models, see the `examples/import-model-weights` directory
> in the Burn repository.
## Extract Configuration
In some cases, models may require additional configuration settings, which are often included in a
`.pt` file during export. The `config_from_file` function from the `burn-import` cargo package
allows for the extraction of these configurations directly from the `.pt` file.
```rust
use std::collections::HashMap;

use burn::config::Config;
use burn_import::pytorch::config_from_file;

#[derive(Debug, Config)]
struct NetConfig {
    n_head: usize,
    n_layer: usize,
    d_model: usize,
    some_float: f64,
    some_int: i32,
    some_bool: bool,
    some_str: String,
    some_list_int: Vec<i32>,
    some_list_str: Vec<String>,
    some_list_float: Vec<f64>,
    some_dict: HashMap<String, String>,
}

fn main() {
    let path = "weights_with_config.pt";
    let top_level_key = Some("my_config");
    let config: NetConfig = config_from_file(path, top_level_key).unwrap();
    println!("{:#?}", config);

    // After extracting, it's recommended you save it as a json file.
    config.save("my_config.json").unwrap();
}
```
## Troubleshooting and Advanced Features
### Key Remapping for Different Model Architectures
If your Burn model structure doesn't match the parameter names in the PyTorch file, you can remap
keys using regular expressions:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
    .with_key_remap("conv\\.(.*)", "$1");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::init(&device).load_record(record);
```
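
To make the effect of this particular remap concrete, here is a minimal standard-library sketch. It is not how `LoadArgs` applies remaps: it only approximates the pattern `conv\.(.*)` for keys that *start* with the prefix, and the helper name `strip_key_prefix` is hypothetical.

```rust
/// Drops a leading `<prefix>.` from each key and leaves other keys untouched.
/// Hypothetical helper: approximates the regex remap shown above for keys
/// that begin with the prefix; it does not use a regex engine.
fn strip_key_prefix(keys: &[&str], prefix: &str) -> Vec<String> {
    let dotted = format!("{prefix}.");
    keys.iter()
        .map(|&key| key.strip_prefix(dotted.as_str()).unwrap_or(key).to_string())
        .collect()
}
```

Calling `strip_key_prefix(&["conv.conv1.weight", "fc.bias"], "conv")` yields `conv1.weight` and leaves `fc.bias` unchanged. A real regex is unanchored and would also rewrite a `conv.` segment in the middle of a key, which this sketch deliberately ignores.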
### Debugging with Key Inspection
To help with troubleshooting import issues, you can enable debugging to print the original and
remapped keys:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
    .with_key_remap("conv\\.(.*)", "$1")
    .with_debug_print(); // Print the keys and remapped keys

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::init(&device).load_record(record);
```
Here is an example of the output:
```text
Debug information of keys and tensor shapes:
---
Original Key: conv.conv1.bias
Remapped Key: conv1.bias
Shape: [2]
Dtype: F32
---
Original Key: conv.conv1.weight
Remapped Key: conv1.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---
Original Key: conv.conv2.weight
Remapped Key: conv2.weight
Shape: [2, 2, 2, 2]
Dtype: F32
---
```
### Automatic Handling of Non-Contiguous Indices
The `PyTorchFileRecorder` automatically handles non-contiguous indices in model layer names. For
example, if the source model contains indices with gaps:
```
"model.layers.0.weight"
"model.layers.0.bias"
"model.layers.2.weight" // Note the gap (no index 1)
"model.layers.2.bias"
"model.layers.4.weight"
"model.layers.4.bias"
```
The recorder will automatically reindex these to be contiguous while preserving their order:
```
"model.layers.0.weight"
"model.layers.0.bias"
"model.layers.1.weight" // Reindexed from 2
"model.layers.1.bias"
"model.layers.2.weight" // Reindexed from 4
"model.layers.2.bias"
```
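
The reindexing described above can be sketched with the standard library alone. This is an illustration of the idea, not the recorder's actual implementation: the function names are hypothetical, and the sketch only renumbers the first numeric segment of each dot-separated key.

```rust
use std::collections::{BTreeSet, HashMap};

/// Position and value of the first numeric segment in a dot-separated key.
fn first_index(key: &str) -> Option<(usize, usize)> {
    key.split('.')
        .enumerate()
        .find_map(|(pos, part)| part.parse::<usize>().ok().map(|value| (pos, value)))
}

/// Everything before the segment at `pos`, e.g. "model.layers".
fn prefix_of(key: &str, pos: usize) -> String {
    key.split('.').take(pos).collect::<Vec<_>>().join(".")
}

/// Renumbers indices so that, per prefix, they become 0, 1, 2, ...
/// while keeping their relative order.
fn reindex_contiguous(keys: &[&str]) -> Vec<String> {
    // First pass: collect the sorted set of indices seen under each prefix.
    let mut seen: HashMap<String, BTreeSet<usize>> = HashMap::new();
    for &key in keys {
        if let Some((pos, value)) = first_index(key) {
            seen.entry(prefix_of(key, pos)).or_default().insert(value);
        }
    }

    // Second pass: replace each index by its rank within its prefix.
    keys.iter()
        .map(|&key| {
            let Some((pos, value)) = first_index(key) else {
                return key.to_string();
            };
            let rank = seen[&prefix_of(key, pos)]
                .iter()
                .position(|v| *v == value)
                .expect("index was recorded in the first pass");
            let mut parts: Vec<String> = key.split('.').map(String::from).collect();
            parts[pos] = rank.to_string();
            parts.join(".")
        })
        .collect()
}
```

Using a `BTreeSet` keeps the indices sorted, so the rank of an index within its prefix is exactly its new contiguous position.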
### Partial Model Loading
You can selectively load weights into a partial model, which is useful for:
- Loading only the encoder from an encoder-decoder architecture
- Fine-tuning specific layers while initializing others randomly
- Creating hybrid models combining parts from different sources
### Specifying the Top-Level Key for state_dict
Sometimes the
[`state_dict`](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
is nested under a top-level key along with other metadata. In this case, you can specify the
top-level key in `LoadArgs`:
```rust
let device = Default::default();
let load_args = LoadArgs::new("tiny.en.pt".into())
    .with_top_level_key("my_state_dict");

let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");
```
### Support for Enum Modules
The `PyTorchFileRecorder` supports models containing enum modules with new-type variants. The
variant to load is selected automatically based on the type each variant wraps, allowing for
flexible model architectures.
## Current Known Issues
1. [Candle's pickle does not currently unpack boolean tensors](https://github.com/tracel-ai/burn/issues/1179).

@@ -1,278 +0,0 @@
# Safetensors Model
## Introduction
Burn supports importing model weights from the Safetensors format, a secure and efficient
alternative to pickle-based formats. Whether you've trained your model in PyTorch or you want to use
a pre-trained model that provides weights in Safetensors format, you can easily import them into
Burn.
This guide demonstrates the complete workflow for exporting models to Safetensors format and
importing them into Burn.
## Exporting Models to Safetensors Format
To export a PyTorch model to Safetensors format, you'll need the `safetensors` Python library. This
library provides a simple API for saving model weights in the Safetensors format.
### Example: Exporting a PyTorch Model
```python
import torch
import torch.nn as nn
from safetensors.torch import save_file


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(2, 2, (2,2))
        self.conv2 = nn.Conv2d(2, 2, (2,2), bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


if __name__ == "__main__":
    # Set seed for reproducibility
    torch.manual_seed(42)

    # Initialize model and ensure it's on CPU
    model = Net().to(torch.device("cpu"))

    # Extract model weights dictionary
    model_weights = model.state_dict()

    # Save to Safetensors format
    save_file(model_weights, "conv2d.safetensors")
```
### Verifying the Export
You can verify your exported model by viewing the `.safetensors` file in
[Netron](https://github.com/lutzroeder/netron), a neural network visualization tool. A correctly
exported file will display a flat structure of tensors, similar to a PyTorch `.pt` weights file.
## Importing Safetensors Models into Burn
Importing a Safetensors model into Burn involves two main steps:
1. Defining the model architecture in Burn
2. Loading the weights from the Safetensors file
### Step 1: Define the Model in Burn
First, you need to create a Burn model that matches the architecture of the model you exported:
```rust
use burn::{
    nn::conv::{Conv2d, Conv2dConfig},
    prelude::*,
};

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    conv2: Conv2d<B>,
}

impl<B: Backend> Net<B> {
    /// Create a new model.
    pub fn init(device: &B::Device) -> Self {
        let conv1 = Conv2dConfig::new([2, 2], [2, 2])
            .init(device);
        let conv2 = Conv2dConfig::new([2, 2], [2, 2])
            .with_bias(false)
            .init(device);
        Self { conv1, conv2 }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv1.forward(x);
        self.conv2.forward(x)
    }
}
```
### Step 2: Load the Weights
You have two options for loading the weights:
#### Option A: Load Dynamically at Runtime
This approach loads the Safetensors file directly at runtime, requiring the `burn-import`
dependency:
```rust
use crate::model;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::safetensors::SafetensorsFileRecorder;

type Backend = burn_ndarray::NdArray<f32>;

fn main() {
    let device = Default::default();

    // Load weights from Safetensors file
    let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
        .load("./conv2d.safetensors".into(), &device)
        .expect("Should decode state successfully");

    // Initialize model and load weights
    let model = model::Net::<Backend>::init(&device).load_record(record);
}
```
#### Option B: Pre-convert to Burn's Binary Format
This approach converts the Safetensors file to Burn's optimized binary format during build time,
removing the runtime dependency on `burn-import`:
```rust
// This code would go in build.rs or a separate tool
use crate::model::{Net, NetRecord};
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use burn_import::safetensors::SafetensorsFileRecorder;

type Backend = burn_ndarray::NdArray<f32>;

fn convert_model() {
    let device = Default::default();

    // Load from Safetensors. The record type is annotated because nothing else
    // in this function constrains it; `NetRecord` is generated by
    // `#[derive(Module)]` on `Net`.
    let recorder = SafetensorsFileRecorder::<FullPrecisionSettings>::default();
    let record: NetRecord<Backend> = recorder
        .load("./conv2d.safetensors".into(), &device)
        .expect("Should decode state successfully");

    // Save to Burn's binary format
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
    recorder
        .record(record, "model.mpk".into())
        .expect("Failed to save model record");
}

// In your application code
fn load_model() -> Net<Backend> {
    let device = Default::default();

    // Load from Burn's binary format
    let record = NamedMpkFileRecorder::<FullPrecisionSettings>::default()
        .load("./model.mpk".into(), &device)
        .expect("Should decode state successfully");

    Net::<Backend>::init(&device).load_record(record)
}
```
> **Note**: For examples of pre-converting models, see the `examples/import-model-weights` directory
> in the Burn repository.
## Advanced Configuration Options
### Framework-Specific Adapters
When importing Safetensors models, you can specify an adapter type to handle framework-specific
tensor transformations. This is crucial when importing models from different ML frameworks, as
tensor layouts and naming conventions can vary:
```rust
let device = Default::default();

// Create load arguments with framework-specific adapter
let load_args = LoadArgs::new("model.safetensors".into())
    .with_adapter_type(AdapterType::PyTorch); // Default adapter

// Load with the specified adapter
let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");
```
#### Available Adapter Types
| Adapter Type | Description |
| --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **PyTorch** (default) | Automatically applies PyTorch-specific transformations:<br>- Transposes weights for linear layers<br>- Renames normalization parameters (weight→gamma, bias→beta) |
| **NoAdapter** | Loads tensors directly without any transformations<br>- Useful when importing from frameworks that already match Burn's tensor layout |
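
To illustrate the linear-layer transposition that the PyTorch adapter performs, the following standard-library sketch transposes a row-major weight buffer. It is not the adapter's code; the function name and the flat `f32` buffer representation are assumptions made for this example.

```rust
/// Transposes a row-major `rows x cols` matrix into a row-major
/// `cols x rows` matrix. Hypothetical helper for illustration only.
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols, "shape does not match buffer length");
    let mut out = vec![0.0; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the input becomes element (c, r) of the output.
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}
```

PyTorch stores a linear layer's weight as `[out_features, in_features]`, while Burn's `Linear` expects `[d_input, d_output]`, which is why the swap is needed when loading PyTorch-exported weights.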
## Troubleshooting and Advanced Features
### Key Remapping for Different Model Architectures
If your Burn model structure doesn't match the parameter names in the Safetensors file, you can
remap keys using regular expressions:
```rust
let device = Default::default();

// Create load arguments with key remapping
let load_args = LoadArgs::new("model.safetensors".into())
    // Remove "conv" prefix, e.g. "conv.conv1" -> "conv1"
    .with_key_remap("conv\\.(.*)", "$1");

let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");

let model = Net::<Backend>::init(&device).load_record(record);
```
### Debugging with Key Inspection
To help with troubleshooting import issues, you can enable debugging to print the original and
remapped keys:
```rust
let device = Default::default();

// Enable debug printing of keys
let load_args = LoadArgs::new("model.safetensors".into())
    .with_key_remap("conv\\.(.*)", "$1")
    .with_debug_print(); // Print original and remapped keys

let record = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Should decode state successfully");
```
### Automatic Handling of Non-Contiguous Indices
The SafetensorsFileRecorder automatically handles non-contiguous indices in model layer names. For
example, if the source model contains indices with gaps:
```
"model.layers.0.weight"
"model.layers.0.bias"
"model.layers.2.weight" // Note the gap (no index 1)
"model.layers.2.bias"
"model.layers.4.weight"
"model.layers.4.bias"
```
The recorder will automatically reindex these to be contiguous while preserving their order:
```
"model.layers.0.weight"
"model.layers.0.bias"
"model.layers.1.weight" // Reindexed from 2
"model.layers.1.bias"
"model.layers.2.weight" // Reindexed from 4
"model.layers.2.bias"
```
### Partial Model Loading
You can selectively load weights into a partial model, which is useful for:
- Loading only the encoder from an encoder-decoder architecture
- Fine-tuning specific layers while initializing others randomly
- Creating hybrid models combining parts from different sources
### Support for Enum Modules
The SafetensorsFileRecorder supports models containing enum modules with new-type variants. The enum
variant is automatically selected based on the enum variant type, allowing for flexible model
architectures.

@@ -1,12 +1,11 @@
# Importing ONNX Models in Burn
# ONNX Import
## Introduction
As deep learning evolves, interoperability between frameworks becomes crucial. Burn, a modern deep
learning framework in Rust, provides robust support for importing models from other popular
frameworks. This section focuses on importing
[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
enabling you to leverage pre-trained models in your Rust-based deep learning projects.
As deep learning evolves, interoperability between frameworks becomes crucial. Burn provides robust
support for importing [ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html)
models through the [`burn-onnx`](https://github.com/tracel-ai/burn-onnx) crate, enabling you to
leverage pre-trained models in your Rust-based deep learning projects.
## Why Import Models?
@@ -61,7 +60,7 @@ There are two simple ways to upgrade your ONNX models to the recommended opset v
Option 1: Use the provided utility script:
```
uv run --script https://raw.githubusercontent.com/tracel-ai/burn/refs/heads/main/crates/burn-onnx/onnx_opset_upgrade.py
uv run --script https://raw.githubusercontent.com/tracel-ai/burn-onnx/refs/heads/main/onnx_opset_upgrade.py
```
Option 2: Use a custom Python script:
@@ -94,19 +93,16 @@ First, add the required dependencies to your `Cargo.toml`:
```toml
[dependencies]
burn = { version = "~0.21", features = ["ndarray"] }
burn-store = { version = "~0.21", features = ["burnpack"] }
[build-dependencies]
burn-onnx = "~0.21"
```
The `burn-store` crate with the `burnpack` feature is required to load model weights at runtime.
### Step 2: Update `build.rs`
In your `build.rs` file:
```rust
```rust, ignore
use burn_onnx::ModelGen;
fn main() {
@@ -117,13 +113,13 @@ fn main() {
}
```
This generates Rust code and a `.burnpack` weights file from your ONNX model during the build process.
This generates Rust code and a `.bpk` weights file from your ONNX model during the build process.
### Step 3: Modify `mod.rs`
In your `src/model/mod.rs` file, include the generated code:
```rust
```rust, ignore
pub mod my_model {
include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}
@@ -133,7 +129,7 @@ pub mod my_model {
Now you can use the imported model in your code:
```rust
```rust, ignore
use burn::tensor;
use burn_ndarray::{NdArray, NdArrayDevice};
use model::my_model::Model;
@@ -158,7 +154,7 @@ fn main() {
The `ModelGen` struct provides configuration options:
```rust
```rust, ignore
ModelGen::new()
.input("path/to/model.onnx")
.out_dir("model/")
@@ -170,26 +166,27 @@ ModelGen::new()
- `input`: Path to the ONNX model file
- `out_dir`: Output directory for generated code and weights
- `development`: When enabled, generates additional debug files (`.onnx.txt`, `.graph.txt`)
- `embed_states`: When enabled, embeds model weights in the binary using `include_bytes!`.
Useful for WebAssembly or single-binary deployments. Not recommended for large models.
- `embed_states`: When enabled, embeds model weights in the binary using `include_bytes!`. Useful
for WebAssembly or single-binary deployments. Not recommended for large models.
Model weights are stored in `.burnpack` format, which provides efficient serialization and loading.
Model weights are stored in Burnpack format (`.bpk`), which provides efficient serialization and
loading.
## Loading and Using Models
You can load models in several ways:
```rust
```rust, ignore
// Load from the output directory with default device (recommended for most use cases)
// This automatically loads weights from the .burnpack file
// This automatically loads weights from the .bpk file
let model = Model::<Backend>::default();
// Create a new model instance with a specific device
// (initializes weights randomly; load weights via `load_from` afterward)
let model = Model::<Backend>::new(&device);
// Load from a specific .burnpack file
let model = Model::<Backend>::from_file("path/to/weights.burnpack", &device);
// Load from a specific .bpk file
let model = Model::<Backend>::from_file("path/to/weights.bpk", &device);
// Load from embedded weights (if embed_states was true)
let model = Model::<Backend>::from_embedded(&device);
@@ -200,7 +197,7 @@ let model = Model::<Backend>::from_embedded(&device);
Common issues and solutions:
1. **Unsupported ONNX operator**: Check the
[list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/SUPPORTED-ONNX-OPS.md).
[list of supported ONNX operators](https://github.com/tracel-ai/burn-onnx/blob/main/SUPPORTED-ONNX-OPS.md).
You may need to simplify your model or wait for support.
2. **Build errors**: Ensure your `burn-onnx` version matches your Burn version and verify the ONNX
@@ -217,13 +214,23 @@ Common issues and solutions:
## Examples and Resources
For practical examples, check out:
For practical examples, check out the
[burn-onnx examples](https://github.com/tracel-ai/burn-onnx/tree/main/examples):
1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)
1. [ONNX Inference](https://github.com/tracel-ai/burn-onnx/tree/main/examples/onnx-inference) -
MNIST inference example
2. [Image Classification Web](https://github.com/tracel-ai/burn-onnx/tree/main/examples/image-classification-web) -
SqueezeNet running in the browser via WebAssembly
3. [Raspberry Pi Pico](https://github.com/tracel-ai/burn-onnx/tree/main/examples/raspberry-pi-pico) -
Embedded deployment example
These demonstrate real-world usage of ONNX import in Burn projects.
For contributors looking to add support for new ONNX operators:
- [Development Guide](https://github.com/tracel-ai/burn-onnx/blob/main/DEVELOPMENT-GUIDE.md) -
Step-by-step guide for implementing new operators
## Conclusion
Importing ONNX models into Burn combines the vast ecosystem of pre-trained models with Burn's
@@ -231,10 +238,5 @@ performance and Rust's safety features. Following this guide, you can seamlessly
models into your Burn projects for inference, fine-tuning, or further development.
The `burn-onnx` crate is actively developed, with ongoing work to support more ONNX operators and
improve performance. Stay tuned to the Burn repository for updates!
---
> 🚨**Note**: The `burn-onnx` crate is in active development. For the most up-to-date information
> on supported ONNX operators, please refer to the
> [official documentation](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/SUPPORTED-ONNX-OPS.md).
improve performance. Visit the [burn-onnx repository](https://github.com/tracel-ai/burn-onnx) for
updates and to contribute!

@@ -20,11 +20,11 @@ advanced user or a beginner. We have crafted some sections for you:
loops, fine-tuning your models to meet your specific requirements. This section empowers you to
harness Burn's flexibility to its fullest.
- [Saving & Loading Models](./saving-and-loading.md): Learn how to easily save and load your trained
models.
- [Saving & Loading Models](./saving-and-loading.md): Learn how to save and load your trained
models, including importing weights from PyTorch and SafeTensors formats.
- [Importing Models](./import): Learn how to import ONNX and PyTorch models, expanding your
compatibility with other deep learning ecosystems.
- [ONNX Import](./onnx-import.md): Learn how to import ONNX models using the
[burn-onnx](https://github.com/tracel-ai/burn-onnx) crate.
- [Models & Pre-Trained Weights](./models-and-pretrained-weights.md): Get started quickly with
ready-to-use models and pre-trained weights.

@@ -56,51 +56,359 @@ Afterwards, the model can just as easily be loaded from the record saved on disk
```rust, ignore
// Load model record on the backend's default device
let record: ModelRecord<MyBackend> = NamedMpkFileRecorder::<FullPrecisionSettings>::new()
let record: ModelRecord<MyBackend> =
NamedMpkFileRecorder::<FullPrecisionSettings>::new()
.load(model_path.into(), &device)
.expect("Should be able to load the model weights from the provided file");
.expect("Could not load model weights");
// Initialize a new model with the loaded record/weights
let model = Model::init(&device).load_record(record);
```
## No Storage, No Problem!
## Model Weight Store
For applications where file storage may not be available (or desired) at runtime, you can use the
`BinBytesRecorder`.
While the Recorder API works well for basic saving and loading, `burn-store` was introduced to
address its limitations around memory efficiency and flexibility. It provides zero-copy
memory-mapped loading, cross-framework interoperability (PyTorch and SafeTensors), key remapping,
partial loading, and filtering. The `burn-store`
crate is intended to eventually replace the Recorder API, but since it was recently released, both
APIs are supported.
In the previous examples we used a `FileRecorder` based on the MessagePack format, which could be
replaced with [another file recorder](./building-blocks/record.md#recorder) of your choice. To embed
a model as part of your runtime application, first save the model to a binary file with
`BinFileRecorder`.
### Supported Formats
| Format | Extension | Description |
| --------------- | -------------- | ----------------------------------------------------------------------------------------- |
| **Burnpack** | `.bpk` | Burn's native format with fast loading, zero-copy support, and training state persistence |
| **SafeTensors** | `.safetensors` | Industry-standard format from Hugging Face for secure tensor serialization |
| **PyTorch** | `.pt`, `.pth` | Direct loading of PyTorch model weights (read-only) |
### Saving a Model
```rust, ignore
// Save model in binary format with full precision
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
model
.save_file(model_path, &recorder)
.expect("Should be able to save the model");
use burn_store::{ModuleSnapshot, BurnpackStore};
// Save to Burnpack (recommended)
let mut store = BurnpackStore::from_file("model.bpk");
model.save_into(&mut store)?;
// Or save to SafeTensors
use burn_store::SafetensorsStore;
let mut store = SafetensorsStore::from_file("model.safetensors");
model.save_into(&mut store)?;
```
Then, in your final application, include the model and use the `BinBytesRecorder` to load it.
Embedding the model as part of your application is especially useful for smaller models but not
recommended for very large models as it would significantly increase the binary size as well as
consume a lot more memory at runtime.
### Loading a Model
```rust, ignore
// Include the model file as a reference to a byte array
static MODEL_BYTES: &[u8] = include_bytes!("path/to/model.bin");
use burn_store::{ModuleSnapshot, BurnpackStore};
// Load model binary record in full precision
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
.load(MODEL_BYTES.to_vec(), device)
.expect("Should be able to load the model weights from bytes");
let device = Default::default();
let mut model = MyModel::init(&device);
// Load that record with the model
model.load_record(record);
// Load from Burnpack
let mut store = BurnpackStore::from_file("model.bpk");
model.load_from(&mut store)?;
```
This example assumes that the model was already created before loading the model record. If instead
you want to skip the random initialization and directly initialize the weights with the provided
record, you could adapt this like the [previous example](#initialization-from-recorded-weights).
### Loading from PyTorch
You can load weights directly from PyTorch `.pt` files:
```rust, ignore
use burn_store::{ModuleSnapshot, PytorchStore};
let mut model = MyModel::init(&device);
let mut store = PytorchStore::from_file("pytorch_model.pt");
model.load_from(&mut store)?;
```
#### Exporting from PyTorch
Save only the model weights (state_dict), not the entire model:
```python
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(2, 2, (2, 2))
self.conv2 = nn.Conv2d(2, 2, (2, 2), bias=False)
def forward(self, x):
return self.conv2(self.conv1(x))
model = Net()
torch.save(model.state_dict(), "model.pt") # Correct: save state_dict
# torch.save(model, "model.pt") # Wrong: saves entire model
```
#### Accessing Nested State Dicts
Some PyTorch checkpoints nest the state_dict under a key:
```rust, ignore
let mut store = PytorchStore::from_file("checkpoint.pt")
.with_top_level_key("state_dict");
model.load_from(&mut store)?;
```
### Loading from SafeTensors
For SafeTensors files exported from PyTorch, use the adapter for proper weight transformation:
```rust, ignore
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
let mut model = MyModel::init(&device);
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_from_adapter(PyTorchToBurnAdapter);
model.load_from(&mut store)?;
```
For SafeTensors files created by Burn, no adapter is needed:
```rust, ignore
let mut store = SafetensorsStore::from_file("model.safetensors");
model.load_from(&mut store)?;
```
#### Exporting from PyTorch to SafeTensors
```python
from safetensors.torch import save_file
model = Net()
save_file(model.state_dict(), "model.safetensors")
```
### Saving for PyTorch Compatibility
Use the adapter when saving for PyTorch consumption:
```rust, ignore
use burn_store::{BurnToPyTorchAdapter, SafetensorsStore};
let mut store = SafetensorsStore::from_file("for_pytorch.safetensors")
.with_to_adapter(BurnToPyTorchAdapter)
.skip_enum_variants(true);
model.save_into(&mut store)?;
```
### Handling Load Results
The `load_from` method returns detailed information about the loading process:
```rust, ignore
let result = model.load_from(&mut store)?;
// Print a formatted summary with suggestions
println!("{}", result);
// Or inspect individual fields
println!("Applied: {} tensors", result.applied.len());
println!("Missing: {:?}", result.missing);
println!("Errors: {:?}", result.errors);
if result.is_success() {
println!("All tensors loaded successfully");
}
```
### Adding Metadata
Burnpack and SafeTensors support custom metadata:
```rust, ignore
let mut store = BurnpackStore::from_file("model.bpk")
.metadata("version", "1.0")
.metadata("description", "My trained model")
.metadata("epochs", "100");
model.save_into(&mut store)?;
```
### Advanced Features
#### Key Remapping
Remap parameter names using regex patterns when model structures don't match:
```rust, ignore
let mut store = PytorchStore::from_file("model.pt")
// Remove prefix: "model.conv1.weight" -> "conv1.weight"
.with_key_remapping(r"^model\.", "")
// Rename: "layer1" -> "encoder.layer1"
.with_key_remapping(r"^layer", "encoder.layer");
model.load_from(&mut store)?;
```
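To make the effect of ordered remapping rules concrete, here is a minimal standard-library sketch. It is not `burn-store` code: `remap_key` is a hypothetical helper, plain prefix matching stands in for the regex patterns, and it assumes that each rule is applied to the output of the previous one.

```rust
/// Hypothetical helper: apply prefix-rewrite rules in order.
/// Plain prefixes stand in for the `^...` regex patterns shown above.
fn remap_key(key: &str, rules: &[(&str, &str)]) -> String {
    let mut current = key.to_string();
    for &(from_prefix, to_prefix) in rules {
        // Rewrite only when the key starts with the rule's prefix.
        if let Some(rest) = current.strip_prefix(from_prefix) {
            let replaced = format!("{}{}", to_prefix, rest);
            current = replaced;
        }
    }
    current
}

fn main() {
    // Assumed ordering: later rules see the result of earlier ones.
    let rules = [("model.", ""), ("layer", "encoder.layer")];
    for key in ["model.conv1.weight", "layer1.bias", "model.layer2.weight"] {
        println!("{} -> {}", key, remap_key(key, &rules));
    }
}
```

Under that chaining assumption, a key such as `model.layer2.weight` is rewritten twice, so rule order matters when patterns can overlap.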
For complex remapping:
```rust, ignore
use burn_store::KeyRemapper;
let remapper = KeyRemapper::new()
.add_pattern(r"^transformer\.h\.(\d+)\.", "transformer.layer$1.")?
.add_pattern(r"\.attn\.", ".attention.")?;
let mut store = SafetensorsStore::from_file("model.safetensors")
.remap(remapper);
```
#### Partial Loading
Load weights even when some tensors are missing:
```rust, ignore
let mut store = PytorchStore::from_file("pretrained.pt")
.allow_partial(true);
let result = model.load_from(&mut store)?;
println!("Missing (initialized randomly): {:?}", result.missing);
```
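The bookkeeping behind a partial load can be pictured with a small standard-library sketch. `LoadSummary` and `summarize_partial_load` are hypothetical names used only for illustration; only the `applied` and `missing` field names mirror the result fields shown above, and the real result type carries more information.

```rust
use std::collections::BTreeSet;

/// Hypothetical, simplified stand-in for the result returned by `load_from`.
#[derive(Debug)]
struct LoadSummary {
    applied: Vec<String>,
    missing: Vec<String>,
}

/// Split the model's parameter names into those present in the file and those absent.
fn summarize_partial_load(model_keys: &[&str], file_keys: &[&str]) -> LoadSummary {
    let available: BTreeSet<&str> = file_keys.iter().copied().collect();
    let mut summary = LoadSummary {
        applied: Vec::new(),
        missing: Vec::new(),
    };
    for key in model_keys {
        if available.contains(key) {
            summary.applied.push(key.to_string());
        } else {
            // With partial loading allowed, these keep their initial values.
            summary.missing.push(key.to_string());
        }
    }
    summary
}

fn main() {
    let model = ["encoder.weight", "encoder.bias", "head.weight"];
    let file = ["encoder.weight", "encoder.bias"];
    let summary = summarize_partial_load(&model, &file);
    println!("Applied: {:?}", summary.applied);
    println!("Missing: {:?}", summary.missing);
}
```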
#### Filtering Tensors
Load or save only specific layers:
```rust, ignore
// Load only encoder layers
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_regex(r"^encoder\..*")
.allow_partial(true);
// Save only encoder layers
let mut store = SafetensorsStore::from_file("encoder.safetensors")
.with_regex(r"^encoder\..*");
model.save_into(&mut store)?;
// Multiple patterns (OR logic)
let mut store = SafetensorsStore::from_file("model.safetensors")
.with_regex(r"^encoder\..*") // encoder tensors
.with_regex(r".*\.bias$") // OR any bias tensors
.with_full_path("decoder.scale"); // OR specific tensor
```
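The OR semantics of stacked filter rules can be sketched with the standard library alone. This is an illustration, not the `burn-store` implementation: `Filter` and `with_rule` are hypothetical names, and simple string predicates stand in for the regex patterns.

```rust
/// Hypothetical filter: a tensor path is kept if ANY rule accepts it.
struct Filter {
    rules: Vec<Box<dyn Fn(&str) -> bool>>,
}

impl Filter {
    fn new() -> Self {
        Filter { rules: Vec::new() }
    }

    /// Add one more rule; rules are combined with OR, not AND.
    fn with_rule(mut self, rule: impl Fn(&str) -> bool + 'static) -> Self {
        self.rules.push(Box::new(rule));
        self
    }

    fn matches(&self, path: &str) -> bool {
        self.rules.iter().any(|rule| rule(path))
    }
}

fn main() {
    let filter = Filter::new()
        .with_rule(|p| p.starts_with("encoder.")) // stands in for r"^encoder\..*"
        .with_rule(|p| p.ends_with(".bias")) // stands in for r".*\.bias$"
        .with_rule(|p| p == "decoder.scale"); // stands in for a full-path match

    for path in ["encoder.layer0.weight", "decoder.fc.bias", "decoder.fc.weight"] {
        println!("{} kept: {}", path, filter.matches(path));
    }
}
```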
#### Non-Contiguous Layer Indices
PyTorch `nn.Sequential` with mixed layers creates non-contiguous indices. `PytorchStore`
automatically remaps these:
```
PyTorch: fc.0.weight, fc.2.weight, fc.4.weight (gaps from ReLU layers)
Burn: fc.0.weight, fc.1.weight, fc.2.weight (contiguous)
```
This is enabled by default. Disable if needed:
```rust, ignore
let mut store = PytorchStore::from_file("model.pt")
.map_indices_contiguous(false);
```
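As a rough picture of what contiguous remapping does, the following standard-library sketch renumbers the index that directly follows a given prefix. It is an assumption-laden illustration, not the `PytorchStore` implementation: `reindex_contiguous` and `parse_index` are hypothetical helpers that handle a single prefix only.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Hypothetical helper: read the numeric index that follows `prefix`, if any.
fn parse_index(key: &str, prefix: &str) -> Option<usize> {
    let rest = key.strip_prefix(prefix)?;
    rest.split('.').next()?.parse().ok()
}

/// Hypothetical helper: renumber indices after `prefix` so they become 0, 1, 2, ...
/// while preserving their relative order.
fn reindex_contiguous(keys: &[&str], prefix: &str) -> Vec<String> {
    // Distinct source indices in ascending order, e.g. {0, 2, 4}.
    let indices: BTreeSet<usize> = keys
        .iter()
        .filter_map(|key| parse_index(key, prefix))
        .collect();
    // Map each old index to its rank, e.g. 0 -> 0, 2 -> 1, 4 -> 2.
    let mapping: BTreeMap<usize, usize> = indices
        .into_iter()
        .enumerate()
        .map(|(new, old)| (old, new))
        .collect();

    keys.iter()
        .map(|key| match parse_index(key, prefix) {
            Some(old) => {
                let rest = &key[prefix.len()..];
                // Everything after the index, such as ".weight".
                let tail = &rest[rest.find('.').unwrap_or(rest.len())..];
                format!("{}{}{}", prefix, mapping[&old], tail)
            }
            // Keys without a matching index are left untouched.
            None => key.to_string(),
        })
        .collect()
}

fn main() {
    let keys = ["fc.0.weight", "fc.2.weight", "fc.4.weight"];
    for key in reindex_contiguous(&keys, "fc.") {
        println!("{}", key);
    }
}
```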
#### Zero-Copy Loading
For embedded models or large files, use zero-copy loading to avoid memory copies:
```rust, ignore
// Embedded model (compile-time)
static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
let mut store = BurnpackStore::from_static(MODEL_DATA);
model.load_from(&mut store)?;
// Large file (memory-mapped)
let mut store = BurnpackStore::from_file("large_model.bpk")
.zero_copy(true);
model.load_from(&mut store)?;
```
#### Direct Tensor Access
Inspect tensors without loading into a model:
```rust, ignore
use burn_store::ModuleStore;
let mut store = PytorchStore::from_file("model.pt");
// List all tensor names
let names = store.keys()?;
// Get specific tensor
if let Some(snapshot) = store.get_snapshot("encoder.layer0.weight")? {
println!("Shape: {:?}, DType: {:?}", snapshot.shape, snapshot.dtype);
}
```
#### Model Surgery
Transfer weights between models:
```rust, ignore
use burn_store::{ModuleSnapshot, PathFilter};
// Transfer all weights
let snapshots = model1.collect(None, None, false);
model2.apply(snapshots, None, None, false);
// Transfer only encoder weights
let filter = PathFilter::new().with_regex(r"^encoder\..*");
let snapshots = model1.collect(Some(filter.clone()), None, false);
model2.apply(snapshots, Some(filter), None, false);
```
### API Reference
#### Builder Methods
| Category | Method | Description |
| ------------- | ------------------------------ | ---------------------------- |
| **Filtering** | `with_regex(pattern)` | Filter by regex pattern |
| | `with_full_path(path)` | Include specific tensor |
| | `with_predicate(fn)` | Custom filter logic |
| **Remapping** | `with_key_remapping(from, to)` | Regex-based renaming |
| | `remap(KeyRemapper)` | Complex remapping rules |
| **Adapters** | `with_from_adapter(adapter)` | Loading transformations |
| | `with_to_adapter(adapter)` | Saving transformations |
| **Config** | `allow_partial(bool)` | Continue on missing tensors |
| | `with_top_level_key(key)` | Access nested dict (PyTorch) |
| | `skip_enum_variants(bool)` | Skip enum variants in paths |
| | `map_indices_contiguous(bool)` | Remap non-contiguous indices |
| | `metadata(key, value)` | Add custom metadata |
| | `zero_copy(bool)` | Enable zero-copy loading |
#### Direct Access Methods
| Method | Description |
| --------------------- | -------------------------------- |
| `keys()` | Get ordered list of tensor names |
| `get_all_snapshots()` | Get all tensors as BTreeMap |
| `get_snapshot(name)` | Get specific tensor by name |
### Troubleshooting
#### Common Issues
1. **"Missing source values" error**: You saved the entire PyTorch model instead of the state_dict.
Re-export with `torch.save(model.state_dict(), "model.pt")`.
2. **Shape mismatch**: Your Burn model doesn't match the source architecture. Verify layer
configurations (channels, kernel sizes, bias settings).
3. **Key not found**: Parameter names don't match. Use `with_key_remapping()` or inspect keys:
```rust, ignore
let mut store = PytorchStore::from_file("model.pt");
println!("Available keys: {:?}", store.keys()?);
```
#### Inspecting Files
Use [Netron](https://github.com/lutzroeder/netron) to visualize `.pt` and `.safetensors` files.
For Burnpack files:
```bash
cargo run --example burnpack_inspect model.bpk
```

@@ -10,7 +10,6 @@
- [Tensor](./project-architecture/tensor.md)
- [Backend](./project-architecture/backend.md)
- [Guides for Contributors](./guides/README.md)
- [ONNX to Burn: Development Guide](./guides/onnx-to-burn-conversion-tool.md)
- [Adding a New Operation to Burn](./guides/adding-a-new-operation-to-burn.md)
- [Submitting Examples to Burn](./guides/submitting-examples.md)
- [Frequently Encountered Issues](./frequently-encountered-issues/README.md)

@@ -17,27 +17,3 @@ If you encounter this, swap out the `assert_eq!` in the failing test for
`tensor1.to_data().assert_approx_eq` with `3` as the second argument. The second argument specifies
the level of precision: `3` corresponds to a difference of less than 10<sup>-3</sup> (0.001) between
the elements of the two tensors.
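
To illustrate what a precision of `3` means, here is a minimal standard-library sketch of an element-wise tolerance check. It is not Burn's implementation of `assert_approx_eq`: `approx_eq` is a hypothetical helper, and the exact comparison Burn performs may differ.

```rust
/// Hypothetical helper: true when every pair of elements differs by
/// less than 10^-precision.
fn approx_eq(a: &[f32], b: &[f32], precision: i32) -> bool {
    let tolerance = 10f32.powi(-precision);
    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() < tolerance)
}

fn main() {
    let expected = [1.0_f32, 2.0, 3.0];
    let close = [1.0004_f32, 2.0, 2.9996];
    let far = [1.002_f32, 2.0, 3.0];

    println!("close matches at precision 3: {}", approx_eq(&expected, &close, 3));
    println!("far matches at precision 3: {}", approx_eq(&expected, &far, 3));
}
```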
## Mismatched types and missing functions
```sh
error[E0308]: mismatched types
  --> {burn_dir}/target/debug/build/onnx-tests-fed12aaf3671687f/out/model/pow.rs:48:45
   |
48 |         let pow1_out1 = input1.clone().powf(input1);
   |                                        ---- ^^^^^^ expected `f32`, found `Tensor<B, 4>`
   |                                        |
   |                                        arguments to this method are incorrect
   |
   = note: expected type `f32`
            found struct `Tensor<B, 4>`
note: method defined here
  --> {burn_dir}/burn-tensor/src/tensor/api/float.rs:65:12
   |
65 |     pub fn powf(self, value: f32) -> Self {
   |            ^^^^

error[E0599]: no method named `powf_scalar` found for struct `Tensor` in the current scope
  --> {burn_dir}/target/debug/build/onnx-tests-fed12aaf3671687f/out/model/pow.rs:50:35
   |
50 |         let pow2_out1 = pow1_out1.powf_scalar(cast1_out1);
   |                                   ^^^^^^^^^^^ method not found in `Tensor<B, 4>`

error[E0599]: no method named `powi` found for struct `Tensor` in the current scope
  --> {burn_dir}/target/debug/build/onnx-tests-fed12aaf3671687f/out/model/pow_int.rs:49:40
   |
49 |         let pow1_out1 = input1.clone().powi(input1);
   |                                        ^^^^ method not found in `Tensor<B, 4, Int>`

Some errors have detailed explanations: E0308, E0599.
For more information about an error, try `rustc --explain E0308`.
error: could not compile `onnx-tests` (test "onnx_tests") due to 3 previous errors
```
If you get this error, you probably did not implement your operator for the actual `Tensor` struct.
This issue was encountered when adding the Pow operator: the operation was added to the
`FloatTensorOps` and `IntTensorOps` traits, but not to the numeric trait (under
`burn-tensor/src/tensor/api/numeric.rs`). In addition, `powf` already existed before the PR, though
only for scalar values (it had been renamed, just not in the right place). Together, these made it
look as if the function was found but the type was wrong. If that is the case, make sure the
operation is implemented for the appropriate type, in this case `Float` under
[crates/burn-tensor/src/tensor/api/numeric.rs](https://github.com/tracel-ai/burn/blob/1235b06e25e39a6ee5a4ac59f7f1d3da2ddb9bc3/crates/burn-tensor/src/tensor/api/numeric.rs),
and that it calls the `TensorOp.foo_op` defined under
[crates/burn-tensor/src/tensor/ops/tensor.rs](https://github.com/tracel-ai/burn/blob/1235b06e25e39a6ee5a4ac59f7f1d3da2ddb9bc3/crates/burn-tensor/src/tensor/ops/tensor.rs).

@@ -1,3 +1,3 @@
# Guides for Contributors
The following guides are meant to help contributors accomplish specific tasks, such as adding new operations to Burn or generating test models for `burn-import`.
The following guides are meant to help contributors accomplish specific tasks, such as adding new operations to Burn.

@@ -1,480 +0,0 @@
# ONNX to Burn: Development Guide
This guide offers in-depth design insights and step-by-step procedures for developers working on the
ONNX to Burn conversion tool. This tool allows the importation of ONNX models into the Burn deep
learning framework written in Rust. It converts ONNX models to Rust source code and model
weights to `.burnpack` files.
For an introduction to ONNX import in Burn, see
[this section of the Burn book](https://burn.dev/books/burn/import/onnx-model.html).
## Design Overview
### Design Goals
- Perform best-effort conversion of ONNX models to Rust source code via Burn APIs.
- Convert ONNX model weights to Burn state files.
- Support ONNX models generated by PyTorch (ONNX Opset 16+ recommended for best compatibility).
- Produce easy-to-understand and modifiable models.
- Ensure the generated models are trainable using Burn APIs.
### Design Decisions
**Core Principles:**
- **Op/Node-Centric Design**: Built around individual operations and nodes for better scalability as
more operators are added
- **Opset-Aware Processing**: Processors accept opset parameters for flexible behavior across
different ONNX versions
- **Constants-First Approach**: All ONNX initializers are treated as constant nodes initially,
providing a uniform starting point
- **Native Type Integration**: Direct use of `burn_tensor::TensorData` and `Dtype` for efficiency,
consistency, and future mmap support
- **Multi-Phase Pipeline**: Explicit transformation phases (initialization → conversion → type
inference → post-processing → finalization) for better visibility and maintainability
- **Graph Input Name Preservation**: Sanitized ONNX names are preserved for easier development and
troubleshooting
**Separation of Concerns:**
- Limit interaction with ONNX to the Intermediate Representation (IR) stage to simplify the process
- Ensure operator behavior consistency across different OpSet versions
- Exclude any ONNX/Protobuf-specific logic from the Burn graph
- **Feature Support Validation**: The `onnx-ir` crate should extract and preserve all ONNX attributes
faithfully, even if Burn does not yet support them. Rejection of unsupported features should happen
in `burn-onnx` during code generation, not in `onnx-ir` during configuration extraction. This
allows `onnx-ir` to be reused by other projects that may have different feature support
The conversion process involves three main stages:
1. Convert ONNX model to Intermediate Representation (IR) via 5-phase pipeline.
2. Translate IR to a Burn graph.
3. Generate Rust source code from the Burn graph.
## Adding New Operators
To extend `burn-onnx` with support for new ONNX operators, follow these steps:
1. **Create PyTorch Script**: Place a PyTorch script using the new operator under
`crates/burn-onnx/onnx-tests/tests/<op>/<op>.py`. Make sure to print both input and output
tensors for end-to-end testing.
2. **Generate ONNX Model**: Run the PyTorch script to produce an ONNX model.
3. **Visualize ONNX Model**: Use [Netron](https://github.com/lutzroeder/netron) to verify the ONNX
model contains the expected operators.
4. **Generate IR and Burn Graph**: Navigate to
[crates/burn-onnx/](https://github.com/tracel-ai/burn/tree/main/crates/burn-onnx) and run:
```
cargo r -- ./onnx-tests/tests/<op>/<op>.onnx ./out
```
5. **Implement Missing Operators**: If you encounter an error stating that an operator is
unsupported, [implement it](#implementing-a-new-operator). The `./out/my-model.graph.txt` should
provide relevant information.
6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds
the Burn model in Rust code, and `my-model.burnpack` contains the model weights.
7. **Integration Test**: Include the test in the `tests/<op_name>/mod.rs` file in the
[crates/burn-onnx/onnx-tests/tests/](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/tests/)
directory. Further details can be found in the
[onnx-tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/README.md).
## Implementing a New Operator
To extend the capabilities of the Burn library by supporting new operations imported from ONNX
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths
are relative to the root of the burn repository.
### Step 1: Node Processor Implementation in onnx-ir
The `onnx-ir` crate handles the Intermediate Representation (IR) of ONNX models using a
processor-based architecture. For each operation:
1. **Create a node module** in `crates/onnx-ir/src/node/<operation_name>.rs`. This file should
contain:
- **Configuration struct**: Define operation-specific parameters (e.g., `SqueezeConfig`)
- **Processor struct**: Implement `NodeProcessor` trait (marked as `pub(crate)`)
- The processor handles:
- **Input/output specification**: Define expected inputs and outputs via `NodeSpec`
- **Type inference**: Infer output types from inputs and configuration
- **Configuration extraction**: Extract operation parameters from ONNX attributes
- **Node construction**: Build the final `Node` enum variant with config
2. **Make the module visible** in `crates/onnx-ir/src/node/mod.rs`:
```rust
pub mod squeeze;
```
3. **Create a node struct** in your module file (e.g., `squeeze.rs`) with the standard fields:
```rust
use onnx_ir_derive::NodeBuilder;
#[derive(Debug, Clone, NodeBuilder)]
pub struct SqueezeNode {
pub name: String,
pub inputs: Vec<Argument>,
pub outputs: Vec<Argument>,
pub config: SqueezeConfig,
}
```
The `NodeBuilder` derive macro generates a test builder (e.g., `SqueezeNodeBuilder`) with methods
for constructing nodes in tests.
4. **Add to the macro invocation** in `crates/onnx-ir/src/ir/node.rs` by adding a mapping to the
`define_node_enum!` macro:
```rust
define_node_enum! {
// ... other variants
Squeeze => squeeze::SqueezeNode,
// ... more variants
}
```
This single macro invocation generates both the `NodeType` enum (for parsing) and the `Node` enum
(with tuple variants wrapping node structs) from a single source of truth.
5. **Register your processor** in `crates/onnx-ir/src/registry.rs` by adding it to the
`with_standard_processors()` function:
```rust
registry.register("Squeeze", Box::new(squeeze::SqueezeProcessor));
```
For example, the squeeze operation in `crates/onnx-ir/src/node/squeeze.rs` contains:
- A `SqueezeConfig` struct with operation parameters (axes)
- A `SqueezeProcessor` struct (marked `pub(crate)`) that implements `NodeProcessor`
- The `node_spec()` method defines input/output requirements
- The `process()` method extracts config and constructs the `Node::Squeeze` variant
### Step 2: Code Generation in burn-onnx
1. Create a new file named `<operation_name>.rs` in the `crates/burn-onnx/src/burn/node/`
directory. This file implements code generation for your operation by implementing the
`NodeCodegen` trait directly on the onnx-ir node type.
2. Implement the `NodeCodegen<PS>` trait for the onnx-ir node type. This trait defines how the node
generates Rust code during the graph compilation process:
```rust
use super::prelude::*;

impl<PS: PrecisionSettings> NodeCodegen<PS> for onnx_ir::squeeze::SqueezeNode {
    fn inputs(&self) -> &[Argument] {
        &self.inputs
    }

    fn outputs(&self) -> &[Argument] {
        &self.outputs
    }

    fn forward(&self, scope: &mut ScopeAtPosition<'_>) -> TokenStream {
        let input_arg = self.inputs.first().unwrap();
        let output_arg = self.outputs.first().unwrap();

        // Use scope.arg() to handle Tensor/Scalar/Shape arguments automatically
        let input = scope.arg(input_arg);
        let output = arg_to_ident(output_arg);

        // Access node configuration
        match &self.config.axes {
            Some(axes) => {
                let axes_values: Vec<_> = axes.iter().map(|&i| {
                    proc_macro2::Literal::i64_suffixed(i)
                }).collect();
                quote! {
                    let #output = #input.squeeze_dims(&[#(#axes_values),*]);
                }
            }
            None => {
                // Get output rank from type inference
                let output_rank = match &output_arg.ty {
                    ArgType::Tensor(t) => t.rank,
                    _ => panic!("Expected tensor output"),
                };
                quote! {
                    let #output = #input.squeeze::<#output_rank>();
                }
            }
        }
    }
}
```
Key methods to implement:
- `inputs(&self)` - Returns references to input arguments (usually just `&self.inputs`)
- `outputs(&self)` - Returns references to output arguments (usually just `&self.outputs`)
- `forward(&self, scope)` - Generates Rust code for the operation using the `quote!` macro
- `field(&self)` - (Optional) Declares module fields for parameters like weights
- `collect_snapshots(&self, field_name)` - (Optional) Collects tensor snapshots for burnpack serialization
3. Use helper utilities from `argument_helpers.rs`:
- `scope.arg(argument)` - Automatically handles Tensor/Scalar/Shape with proper cloning
- `arg_to_ident(argument)` - Converts argument to identifier for code generation
4. Add unit tests using snapshot testing to verify the generated code. These tests typically use the
`insta` crate and test helper functions to validate the generated code:
```rust
#[cfg(test)]
mod tests {
    use super::super::test_helpers::*;
    use insta::assert_snapshot;
    use onnx_ir::squeeze::SqueezeNodeBuilder;

    #[test]
    fn test_squeeze_forward() {
        let node = SqueezeNodeBuilder::new("squeeze1")
            .input_tensor("input", 3, DType::F32)
            .output_tensor("output", 2, DType::F32)
            .axes(vec![1])
            .build();
        let code = codegen_forward_default(&node);
        assert_snapshot!(code, @"let output = input.squeeze_dims(&[1i64]);");
    }
}
```
### Step 3: Register in Module System
Add the module declaration to `crates/burn-onnx/src/burn/node/mod.rs`:
```rust
// ... other node modules
pub(crate) mod squeeze;
// ... more node modules
```
The modules are automatically made visible through re-exports in the same file.
### Step 4: Register in Code Generation Dispatch
Add your operation to the dispatch macro in `crates/burn-onnx/src/burn/node_codegen.rs`. The
`impl_node_codegen_dispatch!` macro generates the trait implementation that dispatches to your
node-specific code.
Add the node variant name (as defined in `onnx-ir`'s `Node` enum) to the macro invocation:
```rust
impl_node_codegen_dispatch! {
    // ... other operations
    Squeeze, // Add your operation here (matches Node::Squeeze variant)
    // ... more operations
}
```
The macro automatically generates:
- Dispatch implementation for `NodeCodegen<PS>` on `onnx_ir::Node`
- All required trait methods (`inputs`, `outputs`, `forward`, `field`, etc.)
- Pattern matching to route to your node-specific implementation
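To make the dispatch idea concrete, here is a hand-written sketch of roughly what such a macro could expand to. It is an illustration under assumptions: `Codegen` is a hypothetical, reduced stand-in for `NodeCodegen<PS>` that returns source text instead of a `TokenStream`, and the node structs are simplified.

```rust
/// Hypothetical, reduced stand-in for `NodeCodegen`: returns source text.
trait Codegen {
    fn forward(&self) -> String;
}

/// Simplified node structs (the real ones also carry inputs and outputs).
struct SqueezeNode {
    axes: Vec<i64>,
}

struct ReluNode;

impl Codegen for SqueezeNode {
    fn forward(&self) -> String {
        // Render each axis as a suffixed literal, e.g. `1i64`.
        let axes: Vec<String> = self.axes.iter().map(|a| format!("{a}i64")).collect();
        format!("let output = input.squeeze_dims(&[{}]);", axes.join(", "))
    }
}

impl Codegen for ReluNode {
    fn forward(&self) -> String {
        "let output = burn::tensor::activation::relu(input);".to_string()
    }
}

enum Node {
    Squeeze(SqueezeNode),
    Relu(ReluNode),
}

/// The part a dispatch macro would generate: one match arm per listed variant,
/// each forwarding to the node-specific implementation.
impl Codegen for Node {
    fn forward(&self) -> String {
        match self {
            Node::Squeeze(n) => n.forward(),
            Node::Relu(n) => n.forward(),
        }
    }
}
```

Because every arm only forwards the call, adding an operation to the macro invocation amounts to adding one more arm of this shape.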
### Step 5: Processor Implementation
The `NodeProcessor` trait defines how operations are processed in onnx-ir. Each processor must
implement:
1. **Associated type**: `type Config` - Define your configuration struct (use `()` if no config)
2. **`infer_types()`** - Infer output types from inputs and config (required)
3. **`build_node()`** - Construct the node struct and wrap it in the `Node` enum variant (required)
4. **`extract_config()`** - Extract config from attributes/inputs (override if Config != `()`)
5. **`spec()`** - Define opset and input/output requirements (optional)
6. **`lift_constants()`** - Request constant lifting for inputs (optional)
Example `build_node()` implementation:
```rust
fn build_node(&self, builder: RawNode, opset: usize) -> Node {
    let config = self.extract_config(&builder, opset).expect("Config extraction failed");
    Node::Squeeze(SqueezeNode {
        name: builder.name,
        inputs: builder.inputs,
        outputs: builder.outputs,
        config,
    })
}
```
Note: `RawNode` is the intermediate node representation used during processing. The `build_node()`
method converts it into the final typed `Node` enum variant.
For complete examples, see existing processors:
- **Simple operation**: `crates/onnx-ir/src/node/softmax.rs`
- **With constant inputs**: `crates/onnx-ir/src/node/squeeze.rs`
- **Complex operation**: `crates/onnx-ir/src/node/conv2d.rs`
See [NodeProcessor Trait](#nodeprocessor-trait) for the complete trait definition.
### Step 6: Add Newly Supported Op!
As a reward, check off the newly supported operator in `crates/burn-onnx/SUPPORTED-ONNX-OPS.md`!
### Constant Lifting
The onnx-ir pipeline automatically handles constant lifting during the post-processing phase.
"Lifting" constants means making constant values directly accessible on node inputs via
`Argument::value()`, instead of requiring a separate graph traversal to find a Constant node.
**When to use**: If your operation takes constant inputs (e.g., weights in Conv1d, shape tensors in
Reshape, axes in Squeeze), access them via `node.inputs[N].value()` in your `extract_config()`
method. See the [Configuration Extraction example](#example-configuration-extraction) in Step 5.
**Optional optimization**: Implement `lift_constants()` to explicitly request constant lifting for
specific inputs before `extract_config()` is called. If you do not implement it, the pipeline still
lifts constants automatically during post-processing.
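As a rough illustration of reading a lifted constant during config extraction, consider the sketch below. The `Argument` type here is a hypothetical simplification (the real one carries type information and tensor data), and returning an error for a non-constant `axes` input is an assumption made for this example only.

```rust
/// Hypothetical, simplified argument: `value` is `Some` once a constant was lifted.
struct Argument {
    name: String,
    value: Option<Vec<i64>>,
}

impl Argument {
    /// Mirrors the idea of `Argument::value()`: the constant data, if available.
    fn value(&self) -> Option<&[i64]> {
        self.value.as_deref()
    }
}

#[derive(Debug, PartialEq)]
struct SqueezeConfig {
    axes: Option<Vec<i64>>,
}

/// Reads the optional `axes` input (index 1) directly from the node inputs,
/// without walking the graph to find a Constant node.
fn extract_config(inputs: &[Argument]) -> Result<SqueezeConfig, String> {
    match inputs.get(1) {
        // No axes input at all: squeeze every size-1 dimension.
        None => Ok(SqueezeConfig { axes: None }),
        Some(arg) => match arg.value() {
            // The constant was lifted, so its values are directly accessible.
            Some(axes) => Ok(SqueezeConfig { axes: Some(axes.to_vec()) }),
            // Assumption for this sketch: a runtime-only axes input is rejected.
            None => Err(format!("input '{}' is not a constant", arg.name)),
        },
    }
}
```

The point of the sketch is the access pattern: the configuration code only inspects `inputs[N]`, and lifting is what makes the value show up there.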
## Architecture Overview
### ONNX-IR Pipeline
The `onnx-ir` crate converts ONNX models to an Intermediate Representation through a 5-phase
pipeline:
#### Phase 1: Initialization
- Creates `GraphState` from ONNX proto structures
- **Constants-first approach**: Converts all ONNX initializers into Constant nodes, providing a
uniform starting point for processing
- Sets up the value store for tensor data using `burn_tensor::TensorData`
- Preserves sanitized graph input names for debugging
#### Phase 2: Node Conversion
- Converts ONNX nodes to IR nodes using registered processors
- Creates `RawNode` instances from ONNX proto nodes (intermediate representation)
- Processors extract configuration and construct typed `Node` enum variants
- Handles constant nodes specially (extracting values from attributes into tensor store)
- Each processor is responsible for its own type inference and node construction
#### Phase 3: Type Inference
- Type inference happens within each processor's `infer_types()` method during Phase 2
- Processors infer output types based on input types and configuration
- Multi-pass processing handles dependencies between nodes
- The pipeline may need multiple iterations for complex type dependencies (e.g., control flow)
#### Phase 4: Post-processing
- Lifts constants: Makes constant values accessible on downstream node inputs
- Eliminates Identity nodes: Removes no-op nodes and rewires the graph
- Re-runs constant lifting after Identity elimination
#### Phase 5: Finalization
- Removes unreferenced constant nodes
- Constructs the final `OnnxGraph` with inputs, outputs, and nodes
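The Identity elimination step from Phase 4 can be sketched on a toy graph. This is a simplified model under assumptions: `SimpleNode` is a hypothetical type with string-named edges and a single output, the graph is assumed acyclic, and graph-level outputs are not handled.

```rust
use std::collections::HashMap;

/// Hypothetical toy node: an op name, input edge names, and one output edge name.
#[derive(Debug, Clone, PartialEq)]
struct SimpleNode {
    op: String,
    inputs: Vec<String>,
    output: String,
}

/// Follows chains of Identity aliases until a non-aliased name is reached.
fn resolve(alias: &HashMap<String, String>, name: &str) -> String {
    let mut current = name.to_string();
    while let Some(next) = alias.get(&current) {
        current = next.clone();
    }
    current
}

/// Removes Identity nodes and rewires their consumers to the original producer.
fn eliminate_identity(nodes: &[SimpleNode]) -> Vec<SimpleNode> {
    // Record which edge each Identity output merely forwards.
    let mut alias: HashMap<String, String> = HashMap::new();
    for node in nodes {
        if node.op == "Identity" {
            if let Some(source) = node.inputs.first() {
                alias.insert(node.output.clone(), source.clone());
            }
        }
    }

    // Drop the Identity nodes and rewrite the inputs of everything else.
    nodes
        .iter()
        .filter(|node| node.op != "Identity")
        .map(|node| SimpleNode {
            op: node.op.clone(),
            inputs: node.inputs.iter().map(|i| resolve(&alias, i.as_str())).collect(),
            output: node.output.clone(),
        })
        .collect()
}
```

After rewiring, a consumer may newly sit directly behind a Constant node, which is one plausible reason for re-running constant lifting after this step.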
### NodeProcessor Trait
The `NodeProcessor` trait (defined in `crates/onnx-ir/src/processor.rs`) is the core abstraction for
handling ONNX operations. Each processor implements:
**Required:**
- `type Config` - Associated type for configuration (use `()` if no config needed)
- `infer_types()` - Infer output types from inputs and configuration
- `build_node()` - Construct the final `Node` enum variant
**Optional (have defaults):**
- `spec()` - Define opset requirements and input/output count validation (`NodeSpec`, `InputSpec`,
`OutputSpec`)
- `extract_config()` - Extract configuration from attributes/inputs (default returns
`Default::default()`)
- `lift_constants()` - Request constant lifting for specific inputs (default does nothing)
- `input_preferences()` - Declare preferred input types from producers (default returns `None`)
Design principles: Each processor is self-contained, handling type inference, config extraction, and
node construction. Processors return strongly-typed `Node` enum variants, ensuring type safety
throughout the pipeline.
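The split between required and defaulted methods can be shown with a reduced trait. Everything below is a hypothetical simplification: the signatures use plain ranks, strings, and attribute pairs instead of the real `RawNode`, `Node`, and error types, and the `-1` fallback for `axis` is an assumption of this sketch.

```rust
/// Hypothetical, reduced version of the trait, for illustration only.
trait NodeProcessor {
    /// Configuration type; `()` when the operation has no parameters.
    type Config: Default;

    // Required: every processor must provide these.
    fn infer_types(&self, input_rank: usize, config: &Self::Config) -> usize;
    fn build_node(&self, name: &str, output_rank: usize) -> String;

    // Optional: defaults cover operations without configuration or constants.
    fn extract_config(&self, _attrs: &[(&str, i64)]) -> Self::Config {
        Self::Config::default()
    }
    fn lift_constants(&self) -> Vec<usize> {
        Vec::new()
    }
}

/// An element-wise op: no config, so the defaults are enough.
struct ReluProcessor;

impl NodeProcessor for ReluProcessor {
    type Config = ();

    fn infer_types(&self, input_rank: usize, _config: &()) -> usize {
        input_rank
    }
    fn build_node(&self, name: &str, output_rank: usize) -> String {
        format!("Relu({name}, rank={output_rank})")
    }
}

#[derive(Debug, Default, PartialEq)]
struct SoftmaxConfig {
    axis: i64,
}

/// An op with an attribute: overrides `extract_config()`.
struct SoftmaxProcessor;

impl NodeProcessor for SoftmaxProcessor {
    type Config = SoftmaxConfig;

    fn infer_types(&self, input_rank: usize, _config: &SoftmaxConfig) -> usize {
        input_rank
    }
    fn build_node(&self, name: &str, output_rank: usize) -> String {
        format!("Softmax({name}, rank={output_rank})")
    }
    fn extract_config(&self, attrs: &[(&str, i64)]) -> SoftmaxConfig {
        // Fall back to the last axis when the attribute is absent (sketch assumption).
        let axis = attrs
            .iter()
            .find(|(key, _)| *key == "axis")
            .map(|(_, value)| *value)
            .unwrap_or(-1);
        SoftmaxConfig { axis }
    }
}
```

The associated `Config` type is what keeps each processor self-contained: the compiler checks that the config produced by `extract_config()` is the one consumed by `infer_types()`.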
## Testing
When implementing a new operator, there are several levels of testing to consider:
### Unit Testing
- **Processor Methods**: Write unit tests in `crates/onnx-ir/src/node/<operation_name>.rs` to
verify:
- `extract_config()` - Correctly extracts configuration from attributes and inputs
- `infer_types()` - Correctly infers output types (element type, rank, static shapes)
- `build_node()` - Constructs correct `Node` enum variant
- `spec()` - Defines correct opset and input/output requirements
- Error handling for invalid inputs or configurations
See existing tests in `crates/onnx-ir/src/node/squeeze.rs` for examples.
- **Code Generation**: Test the burn-onnx `NodeCodegen` implementation to verify correct Rust code
generation. Each node file typically includes snapshot-based unit tests (using `insta`'s
`assert_snapshot!`, as shown in Step 2) to validate generated code against expected output.
### Integration Testing
- **Test Path**: Write integration tests in `crates/burn-onnx/onnx-tests/tests/<op_name>/mod.rs` where `<op_name>` is the name of the new operator.
- **What to Test**:
- Create ONNX models that use your operator and test the end-to-end conversion process
- Ensure the generated Rust code compiles
- Test with realistic ONNX models that use your operator in conjunction with others
- Include models that test edge cases (e.g., different input shapes, parameter combinations)
- Verify that inputs and outputs match between the original ONNX model and the converted Burn model
- Further details can be found in the
[onnx-tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/README.md).
Testing the processor implementation is particularly important as it directly affects the
correctness of the conversion process. Incorrect type inference can lead to mismatched tensor shapes
or wrong element types, while incorrect configuration extraction can cause runtime errors or produce
incorrect results.
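As an illustration of the kind of type-inference unit test described above, the following sketch tests a stand-alone rank computation for Squeeze with explicit axes. The function `infer_squeeze_rank` is hypothetical and much simpler than a real `infer_types()` implementation, which works on node arguments rather than bare numbers.

```rust
use std::collections::BTreeSet;

/// Hypothetical helper: output rank of a Squeeze with explicit axes.
/// Negative axes count from the end, as in ONNX.
fn infer_squeeze_rank(input_rank: usize, axes: &[i64]) -> Result<usize, String> {
    let rank = input_rank as i64;
    let mut seen = BTreeSet::new();
    for &axis in axes {
        // Normalize negative axes into the range 0..rank.
        let normalized = if axis < 0 { axis + rank } else { axis };
        if normalized < 0 || normalized >= rank {
            return Err(format!("axis {axis} is out of range for rank {input_rank}"));
        }
        if !seen.insert(normalized) {
            return Err(format!("axis {axis} is listed more than once"));
        }
    }
    // Each distinct axis removes exactly one dimension.
    Ok(input_rank - seen.len())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn removes_one_dimension_per_axis() {
        assert_eq!(infer_squeeze_rank(3, &[1]), Ok(2));
        assert_eq!(infer_squeeze_rank(4, &[-1, 0]), Ok(2));
    }

    #[test]
    fn rejects_invalid_axes() {
        assert!(infer_squeeze_rank(3, &[3]).is_err());
        // 0 and -3 refer to the same dimension of a rank-3 tensor.
        assert!(infer_squeeze_rank(3, &[0, -3]).is_err());
    }
}
```

Covering both the success path and the error path in this way is what catches the mismatched-rank problems mentioned above before they reach generated code.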
## Node Enum Architecture
The ONNX-IR uses an enum-based node representation where each ONNX operation is a variant of the
`Node` enum (defined in `crates/onnx-ir/src/ir/node.rs`). Each variant wraps an operation-specific
node struct (e.g., `SoftmaxNode`, `Conv2dNode`) that contains `name`, `inputs`, `outputs`, and
optionally a `config` field.
The `define_node_enum!` macro generates both enums from a single source using the syntax
`VariantName => module::NodeStructType`:
```rust
define_node_enum! {
    Softmax => softmax::SoftmaxNode,
    Conv2d => conv2d::Conv2dNode,
    Squeeze => squeeze::SqueezeNode,
    // ... 200+ more variants
}
```
This macro generates:
1. **`NodeType` enum**: Simple unit variants for ONNX parsing (`Softmax`, `Conv2d`, etc.)
2. **`Node` enum**: Tuple variants wrapping node structs (`Softmax(SoftmaxNode)`,
`Conv2d(Conv2dNode)`, etc.)
3. **Accessor methods**: `name()`, `inputs()`, `outputs()` automatically generated for the `Node`
enum
This design provides:
- **Type safety**: Each operation has its own struct type
- **Trait implementations**: Operations can implement specific traits on their node structs
- **Single source of truth**: Both enums are guaranteed to stay in sync
- **Pattern matching**: Easy to match on specific operations and access their configuration
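A much-reduced `macro_rules!` sketch shows how one list of `Variant => Type` pairs can produce both enums plus an accessor. This is not the real macro: the node structs hold only a `name`, and the `node_type()` helper is an addition made for this example.

```rust
/// Simplified node structs; the real ones also hold inputs, outputs, and config.
struct SoftmaxNode {
    name: String,
}

struct SqueezeNode {
    name: String,
}

/// Reduced illustration of the single-source-of-truth idea.
macro_rules! define_node_enum {
    ($($variant:ident => $ty:ty),* $(,)?) => {
        /// Unit variants, e.g. for matching on op types while parsing.
        #[derive(Debug, Clone, Copy, PartialEq)]
        enum NodeType {
            $($variant),*
        }

        /// Tuple variants wrapping the operation-specific structs.
        enum Node {
            $($variant($ty)),*
        }

        impl Node {
            /// Accessor generated once for every variant.
            fn name(&self) -> &str {
                match self {
                    $(Node::$variant(n) => n.name.as_str()),*
                }
            }

            /// Hypothetical helper mapping a node to its unit variant.
            fn node_type(&self) -> NodeType {
                match self {
                    $(Node::$variant(_) => NodeType::$variant),*
                }
            }
        }
    };
}

define_node_enum! {
    Softmax => SoftmaxNode,
    Squeeze => SqueezeNode,
}
```

Since both enums are expanded from the same repetition, a variant cannot be added to one and forgotten in the other, which is the "single source of truth" property listed above.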
## Resources
1. [PyTorch to ONNX](https://pytorch.org/docs/stable/onnx.html)
2. [ONNX to PyTorch](https://github.com/ENOT-AutoDL/onnx2torch)
3. [ONNX Introduction](https://onnx.ai/onnx/intro/)
4. [ONNX Operators](https://onnx.ai/onnx/operators/index.html)
5. [ONNX Protos](https://onnx.ai/onnx/api/classes.html)
6. [ONNX Optimizer](https://github.com/onnx/optimizer)
7. [Netron](https://github.com/lutzroeder/netron)


@@ -1,48 +0,0 @@
[package]
authors = [
"Dilshod Tadjibaev (@antimora)",
"Nathaniel Simard (@nathanielsimard)",
]
description = "Library for importing datamodels into the Burn framework"
edition.workspace = true
license.workspace = true
name = "burn-import"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-import"
documentation = "https://docs.rs/burn-import"
version.workspace = true
[lints]
workspace = true
[features]
default = ["pytorch", "safetensors", "onnx", "burn-onnx?/default"]
onnx = ["burn-onnx"]
onnx-mmap = ["burn-onnx", "burn-onnx/mmap"]
pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
safetensors = [
"burn/record-item-custom-serde",
"thiserror",
"zip",
"candle-core",
]
[dependencies]
burn = { path = "../burn", version = "=0.21.0", default-features = false, features = [
"std",
] }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0", default-features = false, optional = true }
burn-store = { path = "../burn-store", version = "=0.21.0", default-features = false, features = ["std", "pytorch", "burnpack"] }
burn-onnx = { path = "../burn-onnx", version = "=0.21.0", default-features = false, optional = true }
candle-core = { workspace = true, optional = true }
derive-new = { workspace = true }
regex = { workspace = true, features = ["default"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true, features = ["std"] }
thiserror = { workspace = true, optional = true }
zip = { workspace = true, optional = true }
[package.metadata.docs.rs]
features = ["default"]
rustdoc-args = ["--cfg", "docsrs"]


@@ -1 +0,0 @@
../../LICENSE-APACHE


@@ -1 +0,0 @@
../../LICENSE-MIT


@@ -1,24 +0,0 @@
# Burn Import
The `burn-import` crate enables seamless integration of pre-trained models from popular machine
learning frameworks into the Burn ecosystem. This functionality allows you to leverage existing
models while benefiting from Burn's performance optimizations and native Rust integration.
## Supported Import Formats
Burn currently supports three primary model import formats, each serving different use cases:
| Format | Description | Use Case |
| ----------------------------------------------------------------------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------ |
| [**ONNX** (Guide)](https://burn.dev/books/burn/import/onnx-model.html) | Open Neural Network Exchange format | Direct import of complete model architectures and weights from any framework that supports ONNX export |
| [**PyTorch** (Guide)](https://burn.dev/books/burn/import/pytorch-model.html) | PyTorch weights (.pt, .pth) | Loading weights from PyTorch models into a matching Burn architecture |
| [**Safetensors** (Guide)](https://burn.dev/books/burn/import/safetensors-model.html) | Hugging Face's model serialization format | Loading a model's tensor weights into a matching Burn architecture |
## ONNX Contributor Resources
- [ONNX to Burn conversion guide](https://burn.dev/books/contributor/guides/onnx-to-burn-conversion-tool.html) -
Instructions for adding support for additional ONNX operators
- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/onnx-tests/README.md) -
Testing procedures for ONNX operators
- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md) -
Complete list of currently supported ONNX operators


@@ -1,32 +0,0 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};

#[derive(Module, Debug)]
#[allow(unused)]
pub struct Net<B: Backend> {
    do_not_exist_in_pt: Conv2d<B>,
}

#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::PyTorchFileRecorder;

    use super::*;

    #[test]
    #[should_panic(
        expected = "Missing source values for the 'do_not_exist_in_pt' field of type 'Conv2dRecordItem'. Please verify the source data and ensure the field name is correct"
    )]
    fn should_fail_if_struct_field_is_missing() {
        let device = Default::default();

        let _record: NetRecord<TestBackend> =
            PyTorchFileRecorder::<FullPrecisionSettings>::default()
                .load(
                    "tests/missing_module_field/missing_module_field.pt".into(),
                    &device,
                )
                .expect("Should decode state successfully");
    }
}


@@ -1,39 +0,0 @@
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};

#[derive(Module, Debug)]
#[allow(unused)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
}

#[cfg(test)]
mod tests {
    use crate::backend::TestBackend;
    use burn::record::{FullPrecisionSettings, Recorder};
    use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

    use super::*;

    #[test]
    #[should_panic]
    fn should_fail_if_not_found() {
        let device = Default::default();
        let _record: NetRecord<TestBackend> =
            PyTorchFileRecorder::<FullPrecisionSettings>::default()
                .load("tests/top_level_key/top_level_key.pt".into(), &device)
                .expect("Should decode state successfully");
    }

    #[test]
    fn should_load() {
        let device = Default::default();
        let load_args = LoadArgs::new("tests/top_level_key/top_level_key.pt".into())
            .with_top_level_key("my_state_dict");

        let _record: NetRecord<TestBackend> =
            PyTorchFileRecorder::<FullPrecisionSettings>::default()
                .load(load_args, &device)
                .expect("Should decode state successfully");
    }
}


@@ -1,104 +0,0 @@
use burn::{
module::Param,
record::{PrecisionSettings, Record},
tensor::{Tensor, backend::Backend},
};
use burn::record::serde::{
adapter::{BurnModuleAdapter, DefaultAdapter},
data::NestedValue,
ser::Serializer,
};
use serde::Serialize;
/// A PyTorch adapter for the Burn module used during deserialization.
///
/// Not all Burn modules correspond to a PyTorch module. Therefore,
/// we need to adapt the Burn module to a PyTorch module. We implement
/// only those that differ.
pub struct PyTorchAdapter<PS: PrecisionSettings, B: Backend> {
_precision_settings: std::marker::PhantomData<(PS, B)>,
}
impl<PS: PrecisionSettings, B: Backend> BurnModuleAdapter for PyTorchAdapter<PS, B> {
fn adapt_linear(data: NestedValue) -> NestedValue {
// Get the current module in the form of map.
let mut map = data.as_map().expect("Failed to get map from NestedValue");
// Get/remove the weight parameter.
let weight = map
.remove("weight")
.expect("Failed to find 'weight' key in map");
// Convert the weight parameter to a tensor (use the default device, since it's a quick operation).
let weight: Param<Tensor<B, 2>> = weight
.try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
.expect("Failed to deserialize weight");
// Do not capture transpose op when using autodiff backend
let weight = weight.set_require_grad(false);
// Transpose the weight tensor.
let weight_transposed = Param::from_tensor(weight.val().transpose());
// Insert the transposed weight tensor back into the map.
map.insert(
"weight".to_owned(),
serialize::<PS, _, 2>(weight_transposed),
);
// Return the modified map.
NestedValue::Map(map)
}
fn adapt_group_norm(data: NestedValue) -> NestedValue {
rename_weight_bias(data)
}
fn adapt_batch_norm(data: NestedValue) -> NestedValue {
rename_weight_bias(data)
}
fn adapt_layer_norm(data: NestedValue) -> NestedValue {
rename_weight_bias(data)
}
}
/// Helper function to serialize a param tensor.
fn serialize<PS, B, const D: usize>(val: Param<Tensor<B, D>>) -> NestedValue
where
B: Backend,
PS: PrecisionSettings,
{
let serializer = Serializer::new();
val.into_item::<PS>()
.serialize(serializer)
.expect("Failed to serialize the item")
}
/// Helper function to rename the weight and bias parameters to gamma and beta.
///
/// This is needed because PyTorch uses different names for the normalizer parameter
/// than Burn. Burn uses gamma and beta, while PyTorch uses weight and bias.
fn rename_weight_bias(data: NestedValue) -> NestedValue {
// Get the current module in the form of map.
let mut map = data.as_map().expect("Failed to get map from NestedValue");
// Rename the weight parameter to gamma.
let weight = map
.remove("weight")
.expect("Failed to find 'weight' key in map");
map.insert("gamma".to_owned(), weight);
// Rename the bias parameter to beta.
let bias = map
.remove("bias")
.expect("Failed to find 'bias' key in map");
map.insert("beta".to_owned(), bias);
// Return the modified map.
NestedValue::Map(map)
}


@@ -1,156 +0,0 @@
use core::ops::Deref;
use std::collections::HashMap;
use burn::record::serde::{
data::{NestedValue, Serializable},
error,
ser::Serializer,
};
use burn::{
module::ParamId,
record::PrecisionSettings,
tensor::{Element, ElementConversion, TensorData, bf16, f16},
};
use candle_core::WithDType;
use serde::Serialize;
use burn::record::RecorderError;
use zip::result::ZipError;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Serde error: {0}")]
Serde(#[from] error::Error),
#[error("Candle Tensor error: {0}")]
CandleTensor(#[from] candle_core::Error),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Zip error: {0}")]
Zip(#[from] ZipError),
// Add other kinds of errors as needed
#[error("other error: {0}")]
Other(String),
}
// Implement From trait for Error to RecorderError
impl From<Error> for RecorderError {
fn from(error: Error) -> Self {
RecorderError::DeserializeError(error.to_string())
}
}
/// Serializes a candle tensor.
///
/// Tensors are wrapped in a `Param` struct (learnable parameters) and serialized as a `TensorData` struct.
///
/// Values are serialized as `FloatElem` or `IntElem` depending on the precision settings.
impl Serializable for CandleTensor {
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, error::Error>
where
PS: PrecisionSettings,
{
let shape = self.shape().clone().into_dims();
let flatten = CandleTensor(self.flatten_all().expect("Failed to flatten the tensor"));
let param_id = ParamId::new();
match self.dtype() {
candle_core::DType::U8 => {
serialize_data::<u8, PS::IntElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::U32 => {
serialize_data::<u32, PS::IntElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::I64 => {
serialize_data::<i64, PS::IntElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::BF16 => {
serialize_data::<bf16, PS::FloatElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::F16 => {
serialize_data::<f16, PS::FloatElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::F32 => {
serialize_data::<f32, PS::FloatElem>(flatten, shape, param_id, serializer)
}
candle_core::DType::F64 => {
serialize_data::<f64, PS::FloatElem>(flatten, shape, param_id, serializer)
}
}
}
}
/// Helper function to serialize a candle tensor data.
fn serialize_data<T, E>(
tensor: CandleTensor,
shape: Vec<usize>,
param_id: ParamId,
serializer: Serializer,
) -> Result<NestedValue, error::Error>
where
E: Element + Serialize,
T: WithDType + ElementConversion,
{
let data: Vec<E> = tensor
.to_vec1::<T>()
.map_err(|err| error::Error::Other(format!("Candle to vec1 error: {err}")))?
.into_iter()
.map(ElementConversion::elem)
.collect();
let data = TensorData::new(data, shape.clone());
let (dtype, bytes) = (data.dtype, data.into_bytes());
// Manually serialize the tensor instead of using the `ParamSerde` struct, such as:
// ParamSerde::new(param_id, TensorData::new(data, shape)).serialize(serializer)
// Because serializer copies individual elements of TensorData `value` into a new Vec<u8>,
// which is not necessary and inefficient.
let mut tensor_data: HashMap<String, NestedValue> = HashMap::new();
tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes));
tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?);
tensor_data.insert("dtype".into(), dtype.serialize(serializer)?);
let mut param: HashMap<String, NestedValue> = HashMap::new();
param.insert("id".into(), NestedValue::String(param_id.serialize()));
param.insert("param".into(), NestedValue::Map(tensor_data));
Ok(NestedValue::Map(param))
}
/// New type struct for Candle tensors because we need to implement the `Serializable` trait for it.
pub struct CandleTensor(pub candle_core::Tensor);
impl Deref for CandleTensor {
type Target = candle_core::Tensor;
fn deref(&self) -> &Self::Target {
&self.0
}
}
pub fn print_debug_info(
tensors: &HashMap<String, CandleTensor>,
remapped_keys: Vec<(String, String)>,
) {
let mut remapped_keys = remapped_keys;
remapped_keys.sort();
println!("Debug information of keys and tensor shapes:\n---");
for (new_key, old_key) in remapped_keys {
if old_key != new_key {
println!("Original Key: {old_key}");
println!("Remapped Key: {new_key}");
} else {
println!("Key: {new_key}");
}
let shape = tensors[&new_key].shape();
let dtype = tensors[&new_key].dtype();
println!("Shape: {shape:?}");
println!("Dtype: {dtype:?}");
println!("---");
}
}


@@ -1,4 +0,0 @@
pub mod adapter;
#[cfg(feature = "safetensors")]
pub mod candle;
pub mod tensor_snapshot;


@@ -1,81 +0,0 @@
//! TensorSnapshot support for burn-import.
use burn::record::PrecisionSettings;
use burn::record::serde::{
data::{NestedValue, Serializable},
error,
ser::Serializer,
};
use burn_store::TensorSnapshot;
use serde::Serialize;
use std::collections::HashMap;
use std::ops::Deref;
/// Wrapper for TensorSnapshot to implement Serializable
pub struct TensorSnapshotWrapper(pub TensorSnapshot);
impl Deref for TensorSnapshotWrapper {
type Target = TensorSnapshot;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Serializes a TensorSnapshot.
///
/// Tensors are wrapped in a `Param` struct (learnable parameters) and serialized as a `TensorData` struct.
impl Serializable for TensorSnapshotWrapper {
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, error::Error>
where
PS: PrecisionSettings,
{
// Get the tensor data
let data = self
.0
.to_data()
.map_err(|e| error::Error::Other(format!("Failed to get tensor data: {:?}", e)))?;
let shape = data.shape.clone();
let dtype = data.dtype;
let bytes = data.into_bytes();
// Create the tensor data structure
let mut tensor_data: HashMap<String, NestedValue> = HashMap::new();
tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes));
tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?);
tensor_data.insert("dtype".into(), dtype.serialize(serializer)?);
// Create the param structure
let param_id = self.0.tensor_id.unwrap_or_default();
let mut param: HashMap<String, NestedValue> = HashMap::new();
param.insert("id".into(), NestedValue::String(param_id.serialize()));
param.insert("param".into(), NestedValue::Map(tensor_data));
Ok(NestedValue::Map(param))
}
}
/// Print debug information about tensors
pub fn print_debug_info(
tensors: &HashMap<String, TensorSnapshotWrapper>,
remapped_keys: Vec<(String, String)>,
) {
let mut remapped_keys = remapped_keys;
remapped_keys.sort();
println!("Debug information of keys and tensor shapes:\n---");
for (new_key, old_key) in remapped_keys {
if old_key != new_key {
println!("Original Key: {old_key}");
println!("Remapped Key: {new_key}");
} else {
println!("Key: {new_key}");
}
let snapshot = &tensors[&new_key].0;
let shape = &snapshot.shape;
let dtype = &snapshot.dtype;
println!("Shape: {shape:?}");
println!("Dtype: {dtype:?}");
println!("---");
}
}


@@ -1,46 +0,0 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! `burn-import` is a crate designed to simplify the process of importing models trained in other
//! machine learning frameworks into the Burn framework.
#[cfg(any(feature = "pytorch", feature = "safetensors"))]
#[macro_use]
extern crate derive_new;
/// The onnx module.
#[cfg(feature = "onnx")]
#[deprecated(
since = "0.21.0",
note = "ONNX import was moved to `burn-onnx`. Use that crate instead."
)]
pub mod onnx {
#[deprecated(
since = "0.21.0",
note = "ONNX import was moved to `burn-onnx`. Use that crate instead."
)]
#[allow(missing_docs)]
pub type ModelGen = burn_onnx::ModelGen;
}
/// The module for generating the burn code.
#[cfg(feature = "onnx")]
#[deprecated(
since = "0.21.0",
note = "ONNX import was moved to `burn-onnx`. Use that crate instead."
)]
pub mod burn {
pub use burn_onnx::burn::*;
}
/// The PyTorch module for recorder.
#[cfg(feature = "pytorch")]
pub mod pytorch;
/// The Safetensors module for recorder.
#[cfg(feature = "safetensors")]
pub mod safetensors;
// Enabled when the `pytorch` or `safetensors` feature is enabled.
#[cfg(any(feature = "pytorch", feature = "safetensors"))]
mod common;


@@ -1,51 +0,0 @@
use std::path::Path;
use burn_store::pytorch::PytorchReader;
use serde::de::DeserializeOwned;
use super::reader::Error;
/// Loads configuration data from a PyTorch `.pth` file.
///
/// This function reads specific configuration or metadata stored in PyTorch checkpoint files.
/// It's particularly useful for extracting model configurations that might be saved alongside
/// the model weights.
///
/// # Arguments
///
/// * `file` - Path to the PyTorch `.pth` file.
/// * `key` - Optional key to filter specific data within the pickle file.
/// If `None`, the entire content is deserialized.
///
/// # Type Parameters
///
/// * `D` - The target type to deserialize into. Must implement `DeserializeOwned`.
///
/// # Returns
///
/// A `Result` containing the deserialized configuration data, or an `Error` if
/// reading or deserialization fails.
///
/// # Examples
///
/// ```ignore
/// use burn_import::pytorch::config::load_config_from_file;
/// use serde::Deserialize;
///
/// #[derive(Debug, Deserialize)]
/// struct ModelConfig {
///     hidden_size: usize,
///     num_layers: usize,
///     // ... other configuration fields
/// }
///
/// let config: ModelConfig = load_config_from_file("model.pth", Some("config"))?;
/// ```
pub fn load_config_from_file<D, P>(file: P, key: Option<&str>) -> Result<D, Error>
where
    D: DeserializeOwned,
    P: AsRef<Path>,
{
    // Use burn-store's PytorchReader to load and deserialize config
    PytorchReader::load_config(file, key).map_err(Error::Store)
}


@@ -1,5 +0,0 @@
mod config;
mod reader;
mod recorder;
pub use config::load_config_from_file;
pub use recorder::{LoadArgs, PyTorchFileRecorder};


@@ -1,110 +0,0 @@
use std::collections::HashMap;
use std::path::Path;
use crate::common::{
adapter::PyTorchAdapter,
tensor_snapshot::{TensorSnapshotWrapper, print_debug_info},
};
use burn::record::PrecisionSettings;
use burn::{
record::serde::{
data::{remap, unflatten},
de::Deserializer,
},
tensor::backend::Backend,
};
use burn_store::pytorch::PytorchReader;
use regex::Regex;
use serde::de::DeserializeOwned;
/// Error type for PyTorch file operations
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Store error: {0}")]
Store(#[from] burn_store::pytorch::PytorchError),
#[error("Serde error: {0}")]
Serde(#[from] burn::record::serde::error::Error),
#[error("Other error: {0}")]
Other(String),
}
/// Deserializes tensor data from a PyTorch file (`.pt` or `.pth`) into a Burn record.
///
/// This function reads tensors from a pickle file using burn-store's PyTorch reader,
/// optionally remaps their keys, and then deserializes them into the specified record type `D`.
///
/// # Arguments
///
/// * `path` - The path to the PyTorch file to load.
/// * `key_remap` - A list of rules for renaming tensor keys. Each rule is a tuple
/// containing a regular expression to match the original key and a replacement string.
/// * `top_level_key` - An optional key within the pickle file if the tensors are nested
/// under a specific dictionary key (e.g., "state_dict").
/// * `debug` - If `true`, prints information about the loaded tensors and remapped keys.
///
/// # Type Parameters
///
/// * `PS` - The precision settings to use during deserialization.
/// * `D` - The target Burn record type to deserialize into.
/// * `B` - The backend to use for tensor operations (primarily for type context).
///
/// # Returns
///
/// A `Result` containing the deserialized record `D` on success, or an `Error` if
/// reading, remapping, or deserialization fails.
pub fn from_file<PS, D, B>(
path: &Path,
key_remap: Vec<(Regex, String)>,
top_level_key: Option<&str>,
debug: bool,
) -> Result<D, Error>
where
D: DeserializeOwned,
PS: PrecisionSettings,
B: Backend,
{
// Use burn-store's PyTorch reader to load tensors
let reader = if let Some(key) = top_level_key {
PytorchReader::with_top_level_key(path, key)?
} else {
PytorchReader::new(path)?
};
// Get the tensors as TensorSnapshots and wrap them
let tensors: HashMap<String, TensorSnapshotWrapper> = reader
.into_tensors()
.into_iter()
.map(|(key, snapshot)| (key, TensorSnapshotWrapper(snapshot)))
.collect();
// Remap the tensor keys based on the provided rules
let (tensors, remapped_keys) = remap(tensors, key_remap);
// Print debug information if enabled
if debug {
print_debug_info(&tensors, remapped_keys);
}
// Convert the flat map of tensors into a nested data structure suitable for deserialization
let nested_value = unflatten::<PS, _>(tensors)?;
// Create a deserializer using the PyTorch adapter and the nested tensor data
let deserializer = Deserializer::<PyTorchAdapter<PS, B>>::new(nested_value, true);
// Deserialize the nested data structure into the target record type
let value = D::deserialize(deserializer)?;
Ok(value)
}
// Convert into `RecorderError` for compatibility with the `Recorder` trait
impl From<Error> for burn::record::RecorderError {
fn from(error: Error) -> Self {
burn::record::RecorderError::DeserializeError(error.to_string())
}
}
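
The flat-to-nested step in `from_file` above (the `unflatten` call) can be pictured with a small standalone sketch. This is not Burn's implementation: `Nested`, `insert`, and `unflatten_sketch` are hypothetical names, leaves carry only a tensor shape, and the sketch assumes keys are split on `.` with no special handling of numeric path segments.

```rust
use std::collections::BTreeMap;

/// Simplified stand-in for a nested value; a leaf holds only a tensor shape.
#[derive(Debug, PartialEq)]
enum Nested {
    Leaf(Vec<usize>),
    Map(BTreeMap<String, Nested>),
}

/// Insert `leaf` under a dotted `key`, creating intermediate maps as needed.
fn insert(node: &mut BTreeMap<String, Nested>, key: &str, leaf: Vec<usize>) {
    match key.split_once('.') {
        // Last path segment: store the leaf here.
        None => {
            node.insert(key.to_string(), Nested::Leaf(leaf));
        }
        // More segments remain: descend into (or create) the child map.
        Some((head, rest)) => {
            let child = node
                .entry(head.to_string())
                .or_insert_with(|| Nested::Map(BTreeMap::new()));
            // If `head` already names a leaf, the conflicting key is skipped.
            if let Nested::Map(inner) = child {
                insert(inner, rest, leaf);
            }
        }
    }
}

/// Build a nested structure from flat `(dotted key, shape)` pairs.
fn unflatten_sketch(flat: Vec<(&str, Vec<usize>)>) -> Nested {
    let mut root = BTreeMap::new();
    for (key, shape) in flat {
        insert(&mut root, key, shape);
    }
    Nested::Map(root)
}
```

With this sketch, the keys `conv.weight` and `conv.bias` end up as two leaves inside a single `conv` map, which is the shape a record deserializer can walk field by field.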

@@ -1,173 +0,0 @@
use core::marker::PhantomData;
use std::path::PathBuf;
use burn::{
record::{PrecisionSettings, Record, Recorder, RecorderError},
tensor::backend::Backend,
};
use regex::Regex;
use serde::{Serialize, de::DeserializeOwned};
use super::reader::from_file;
/// Recorder for loading PyTorch (`.pt`) files into Burn modules.
///
/// Load arguments ([`LoadArgs`]) can be used to specify the file path and
/// remap parameter keys during loading.
#[derive(new, Debug, Default, Clone)]
pub struct PyTorchFileRecorder<PS: PrecisionSettings> {
_settings: PhantomData<PS>,
}
impl<PS: PrecisionSettings, B: Backend> Recorder<B> for PyTorchFileRecorder<PS> {
type Settings = PS;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = LoadArgs;
fn save_item<I: Serialize>(
&self,
_item: I,
_file: Self::RecordArgs,
) -> Result<(), RecorderError> {
unimplemented!("Save operations are not supported by PyTorchFileRecorder.")
}
fn load_item<I: DeserializeOwned>(
&self,
_file: &mut Self::LoadArgs,
) -> Result<I, RecorderError> {
unimplemented!("load_item is not implemented for PyTorchFileRecorder; use load instead.")
}
fn load<R: Record<B>>(
&self,
args: Self::LoadArgs,
device: &B::Device,
) -> Result<R, RecorderError> {
let item = from_file::<PS, R::Item<Self::Settings>, B>(
&args.file,
args.key_remap,
args.top_level_key.as_deref(), // Convert Option<String> to Option<&str>
args.debug,
)?;
Ok(R::from_item(item, device))
}
}
/// Arguments for loading PyTorch model weights.
///
/// # Notes
///
/// Parameter keys within a PyTorch file (`.pt` extension) can be inspected using
/// tools like [Netron](https://github.com/lutzroeder/netron).
///
/// # Examples
///
/// ```rust,ignore
/// use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
/// use burn::record::{FullPrecisionSettings, Recorder};
///
/// // Create load arguments, specifying the file and a key remapping rule.
/// let args = LoadArgs::new("tests/key_remap/key_remap.pt".into())
/// // Remove "conv." prefix, e.g., "conv.weight" -> "weight"
/// .with_key_remap("conv\\.(.*)", "$1");
///
/// // Load the record using the default recorder and the backend's default device.
/// let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
/// .load(args, &Default::default())
/// .expect("Failed to decode state from file");
/// ```
#[derive(Debug, Clone)]
pub struct LoadArgs {
/// The path to the PyTorch file (`.pt`).
pub file: PathBuf,
/// A list of key remapping rules applied to the state dictionary keys.
/// Each rule consists of a regular expression and a replacement string.
/// See [regex::Regex::replace](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace)
/// for more details.
pub key_remap: Vec<(Regex, String)>,
/// Optional top-level key under which the state dictionary is nested within the file.
/// If `None`, the root object is assumed to be the state dictionary.
pub top_level_key: Option<String>,
/// If `true`, prints debug information during the loading process.
pub debug: bool,
}
impl LoadArgs {
/// Creates new load arguments with the given file path.
///
/// # Arguments
///
/// * `file` - The path to the PyTorch file to load.
pub fn new(file: PathBuf) -> Self {
Self {
file,
key_remap: Vec::new(),
top_level_key: None,
debug: false,
}
}
/// Adds a key remapping rule.
///
/// Keys from the PyTorch state dictionary are modified if they match the pattern.
///
/// # Arguments
///
/// * `pattern` - The regular expression pattern to match against state dictionary keys.
/// * `replacement` - The replacement string. Capture groups can be used (e.g., `$1`).
///
/// # Panics
///
/// Panics if the provided `pattern` is an invalid regular expression.
///
/// See the [regex crate documentation](https://docs.rs/regex/latest/regex/) for pattern syntax
/// and [replacement string syntax](https://docs.rs/regex/latest/regex/struct.Regex.html#replacement-string-syntax).
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
let regex = Regex::new(pattern).expect("Invalid regex pattern provided to with_key_remap");
self.key_remap.push((regex, replacement.into()));
self
}
/// Specifies a top-level key in the file under which the state dictionary is nested.
///
/// Some PyTorch files store the state dictionary within a larger structure (e.g., a dictionary).
/// Use this method if the weights are not at the root level of the file.
///
/// # Arguments
///
/// * `key` - The top-level key to access the state dictionary.
pub fn with_top_level_key(mut self, key: &str) -> Self {
self.top_level_key = Some(key.into());
self
}
/// Enables printing of debug information during loading.
pub fn with_debug_print(mut self) -> Self {
self.debug = true;
self
}
}
impl From<PathBuf> for LoadArgs {
fn from(val: PathBuf) -> Self {
LoadArgs::new(val)
}
}
impl From<String> for LoadArgs {
fn from(val: String) -> Self {
LoadArgs::new(val.into())
}
}
impl From<&str> for LoadArgs {
fn from(val: &str) -> Self {
LoadArgs::new(val.into())
}
}

@@ -1,3 +0,0 @@
mod reader;
mod recorder;
pub use recorder::{AdapterType, LoadArgs, SafetensorsFileRecorder};

@@ -1,72 +0,0 @@
use std::{collections::HashMap, path::Path};
use burn::{
record::{
PrecisionSettings,
serde::{
adapter::DefaultAdapter,
data::{remap, unflatten},
de::Deserializer,
},
},
tensor::backend::Backend,
};
use candle_core::{Device, safetensors};
use regex::Regex;
use serde::de::DeserializeOwned;
use crate::common::adapter::PyTorchAdapter;
use super::recorder::AdapterType;
use crate::common::candle::{CandleTensor, Error, print_debug_info};
/// Deserializes model state from a safetensors file.
///
/// # Arguments
///
/// * `path` - Path to the safetensors file.
/// * `key_remap` - A vector of tuples containing regular expressions and replacement strings
/// for remapping tensor keys.
/// * `debug` - If true, prints debug information about the loaded tensors and remapped keys.
/// * `adapter_type` - Specifies the adapter to use for deserialization (e.g., PyTorch, None).
pub fn from_file<PS, D, B>(
path: &Path,
key_remap: Vec<(Regex, String)>,
debug: bool,
adapter_type: AdapterType,
) -> Result<D, Error>
where
D: DeserializeOwned,
PS: PrecisionSettings,
B: Backend,
{
// Load tensors from the safetensors file into a HashMap.
let tensors: HashMap<String, CandleTensor> = safetensors::load(path, &Device::Cpu)?
.into_iter()
.map(|(key, tensor)| (key, CandleTensor(tensor)))
.collect();
// Remap tensor keys based on the provided patterns.
let (tensors, remapped_keys) = remap(tensors, key_remap);
// Optionally print debug information about tensors and key remapping.
if debug {
print_debug_info(&tensors, remapped_keys);
}
// Convert the flat map of tensors into a nested data structure suitable for deserialization.
let nested_value = unflatten::<PS, _>(tensors)?;
// Deserialize the nested data structure into the target type using the specified adapter.
let value = match adapter_type {
AdapterType::PyTorch => D::deserialize(Deserializer::<PyTorchAdapter<PS, B>>::new(
nested_value,
true, // `default_for_missing_fields`: fill fields missing from the file with default values
))?,
AdapterType::NoAdapter => {
D::deserialize(Deserializer::<DefaultAdapter>::new(nested_value, true))?
}
};
Ok(value)
}

@@ -1,187 +0,0 @@
use core::marker::PhantomData;
use std::path::PathBuf;
use burn::{
record::{PrecisionSettings, Record, Recorder, RecorderError},
tensor::backend::Backend,
};
use regex::Regex;
use serde::{Serialize, de::DeserializeOwned};
use super::reader::from_file;
/// Recorder for loading HuggingFace Safetensors files (`.safetensors`) into Burn modules.
///
/// This recorder uses [LoadArgs] to configure loading behavior, such as key remapping.
#[derive(new, Debug, Default, Clone)]
pub struct SafetensorsFileRecorder<PS: PrecisionSettings> {
_settings: PhantomData<PS>,
}
impl<PS: PrecisionSettings, B: Backend> Recorder<B> for SafetensorsFileRecorder<PS> {
type Settings = PS;
type RecordArgs = PathBuf;
type RecordOutput = ();
type LoadArgs = LoadArgs;
fn save_item<I: Serialize>(
&self,
_item: I,
_file: Self::RecordArgs,
) -> Result<(), RecorderError> {
unimplemented!("save_item not implemented for SafetensorsFileRecorder")
}
fn load_item<I: DeserializeOwned>(
&self,
_file: &mut Self::LoadArgs,
) -> Result<I, RecorderError> {
unimplemented!("load_item not implemented for SafetensorsFileRecorder")
}
fn load<R: Record<B>>(
&self,
args: Self::LoadArgs,
device: &B::Device,
) -> Result<R, RecorderError> {
let item = from_file::<PS, R::Item<Self::Settings>, B>(
&args.file,
args.key_remap,
args.debug,
args.adapter_type,
)?;
Ok(R::from_item(item, device))
}
}
/// Arguments for loading a Safetensors file using [SafetensorsFileRecorder].
///
/// # Example
///
/// ```rust,ignore
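/// use burn_import::safetensors::{AdapterType, LoadArgs, SafetensorsFileRecorder};
/// use burn::backend::NdArray;
/// use burn::record::{FullPrecisionSettings, Record, Recorder};
/// use burn::tensor::backend::Backend;
/// use std::path::PathBuf;
///
/// // Backend used in this example
/// type MyBackend = NdArray;
///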
/// // Dummy model record structure
/// #[derive(Record, Default)]
/// struct MyModelRecord<B: Backend> {
/// // fields matching the tensor names in the file
/// }
///
/// let device = Default::default(); // Replace with your actual device
///
/// // Example assuming a file named 'model.safetensors' exists
/// let args = LoadArgs::new(PathBuf::from("model.safetensors"))
/// // Example: Remove "model.encoder." prefix from keys
/// .with_key_remap("model\\.encoder\\.(.*)", "$1")
/// .with_adapter_type(AdapterType::PyTorch) // Specify if adaptation is needed
/// .with_debug_print(); // Enable debug output
///
/// let record: MyModelRecord<MyBackend> = SafetensorsFileRecorder::<FullPrecisionSettings>::default()
/// .load(args, &device)
/// .expect("Should decode state successfully");
/// ```
#[derive(Debug, Clone)]
pub struct LoadArgs {
/// The path to the Safetensors file to load.
pub file: PathBuf,
/// A list of key remapping rules applied sequentially. Each tuple contains a
/// regular expression ([`Regex`]) to match keys and a replacement string.
/// See [regex::Regex::replace_all](https://docs.rs/regex/latest/regex/struct.Regex.html#method.replace_all)
/// for replacement syntax details.
pub key_remap: Vec<(Regex, String)>,
/// If true, prints debug information during the loading process.
pub debug: bool,
/// The type of adapter to apply for potential framework-specific tensor transformations
/// (e.g., transposing certain weights).
pub adapter_type: AdapterType,
}
/// Specifies the type of adapter to use for tensor loading.
///
/// Adapters handle potential differences in tensor formats or naming conventions
/// between the source framework and Burn.
#[derive(Debug, Clone, Default, Copy)]
pub enum AdapterType {
/// Adapts tensors assuming they originated from PyTorch.
#[default]
PyTorch,
/// Loads tensors directly without any specific adaptation.
NoAdapter,
}
impl LoadArgs {
/// Creates new `LoadArgs` for the given file path.
///
/// By default, no key remapping is applied, debug printing is off,
/// and the adapter type is [AdapterType::PyTorch].
///
/// # Arguments
///
/// * `file` - The path to the Safetensors file.
pub fn new(file: PathBuf) -> Self {
Self {
file,
key_remap: Vec::new(),
debug: false,
adapter_type: Default::default(),
}
}
/// Adds a key remapping rule.
///
/// Rules are applied in the order they are added.
///
/// # Arguments
///
/// * `pattern` - The regular expression pattern to match tensor keys.
/// * `replacement` - The replacement string. Capture groups like `$1`, `$2` can be used.
///
/// # Panics
///
/// Panics if `pattern` is not a valid regular expression.
///
/// See [Regex syntax](https://docs.rs/regex/latest/regex/#syntax) and
/// [replacement string syntax](https://docs.rs/regex/latest/regex/struct.Regex.html#replacement-string-syntax).
pub fn with_key_remap(mut self, pattern: &str, replacement: &str) -> Self {
let regex = Regex::new(pattern).expect("Invalid regex pattern provided");
self.key_remap.push((regex, replacement.to_string()));
self
}
/// Enables printing of debug information during loading.
pub fn with_debug_print(mut self) -> Self {
self.debug = true;
self
}
/// Sets the adapter type to use for loading tensors.
pub fn with_adapter_type(mut self, adapter_type: AdapterType) -> Self {
self.adapter_type = adapter_type;
self
}
}
impl From<PathBuf> for LoadArgs {
fn from(val: PathBuf) -> Self {
LoadArgs::new(val)
}
}
impl From<String> for LoadArgs {
fn from(val: String) -> Self {
LoadArgs::new(val.into())
}
}
impl From<&str> for LoadArgs {
fn from(val: &str) -> Self {
LoadArgs::new(val.into())
}
}
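
The kind of transformation `AdapterType::PyTorch` refers to ("transposing certain weights") can be illustrated with a standalone sketch. It assumes the common case of a linear layer whose PyTorch weight is stored as `[out_features, in_features]` and must be flipped to `[in_features, out_features]`; `transpose` is a hypothetical helper written for this illustration and is not the adapter's actual code.

```rust
/// Transpose a row-major `rows x cols` matrix into a row-major `cols x rows` one.
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols, "shape does not match data length");
    let mut out = vec![0.0; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the input becomes element (c, r) of the output.
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}
```

With `AdapterType::NoAdapter`, no such reshaping is applied, so the file's tensors must already match the layout the target record expects.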

@@ -1,52 +0,0 @@
[package]
authors = [
"Dilshod Tadjibaev (@antimora)",
"Nathaniel Simard (@nathanielsimard)",
]
description = "Library for importing ONNX models into the Burn framework"
edition.workspace = true
license.workspace = true
name = "burn-onnx"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-onnx"
documentation = "https://docs.rs/burn-onnx"
version.workspace = true
default-run = "onnx2burn"
[lints]
workspace = true
[features]
default = ["mmap"]
mmap = ["onnx-ir/mmap"]
[dependencies]
burn = { path = "../burn", version = "=0.21.0", default-features = false, features = [
"std",
] }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0", default-features = false }
burn-store = { path = "../burn-store", version = "=0.21.0", default-features = false, features = [
"std",
"burnpack",
] }
onnx-ir = { path = "../onnx-ir", version = "=0.21.0", default-features = false }
derive-new = { workspace = true }
log = { workspace = true }
proc-macro2 = { workspace = true }
quote = { workspace = true }
rust-format = { workspace = true, features = ["pretty_please", "post_process"] }
tracing-core = { workspace = true }
tracing-subscriber = { workspace = true, features = [
"default",
"fmt",
"env-filter",
] }
syn = { workspace = true, features = ["parsing"] }
[dev-dependencies]
insta = { workspace = true }
[package.metadata.docs.rs]
features = ["default"]
rustdoc-args = ["--cfg", "docsrs"]

@@ -1 +0,0 @@
../../LICENSE-APACHE

@@ -1 +0,0 @@
../../LICENSE-MIT

@@ -1,13 +0,0 @@
# Burn ONNX
The `burn-onnx` crate enables seamless integration of pre-trained models from popular machine
learning frameworks into the Burn ecosystem via the ONNX format.
## ONNX Contributor Resources
- [ONNX to Burn conversion guide](https://burn.dev/books/contributor/guides/onnx-to-burn-conversion-tool.html) -
Instructions for adding support for additional ONNX operators
- [ONNX tests README](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/onnx-tests/README.md) -
Testing procedures for ONNX operators
- [Supported ONNX Operators table](https://github.com/tracel-ai/burn/blob/main/crates/burn-onnx/SUPPORTED-ONNX-OPS.md) -
Complete list of currently supported ONNX operators

@@ -1,416 +0,0 @@
# Supported ONNX Operators
This table lists the support status for ONNX operators in Burn. Note that some
entries marked with dimensional suffixes (such as `Conv1d`, `Conv2d`, etc.) or
other specialized names like `Linear` are not standard ONNX operators. These
represent Burn's implementation of dimension-specific versions of the
corresponding ONNX operators to make the mapping clearer between ONNX and Burn
functionality.
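
As a rough illustration of the naming convention, the sketch below derives a table entry from an ONNX operator name and the number of spatial dimensions. `table_entry` is a hypothetical helper written for this document, not a function in `burn-onnx`, and the rank limits simply mirror the rows listed in the table.

```rust
/// Hypothetical helper: map an ONNX op type and its spatial rank to the
/// dimension-suffixed name used in the table.
fn table_entry(onnx_op: &str, spatial_rank: usize) -> Option<String> {
    match onnx_op {
        // Listed as 1d, 2d and 3d variants.
        "Conv" | "ConvTranspose" if (1..=3).contains(&spatial_rank) => {
            Some(format!("{onnx_op}{spatial_rank}d"))
        }
        // Listed as 1d and 2d variants.
        "AveragePool" | "MaxPool" if (1..=2).contains(&spatial_rank) => {
            Some(format!("{onnx_op}{spatial_rank}d"))
        }
        // Ranks that the table does not list for these operators.
        "Conv" | "ConvTranspose" | "AveragePool" | "MaxPool" => None,
        // Every other operator keeps its ONNX name.
        other => Some(other.to_string()),
    }
}
```

For example, an ONNX `Conv` node with two spatial dimensions corresponds to the `Conv2d` row, while `Relu` is looked up under its own name.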
| ONNX OP | Import Support | Burn Support |
|----------------------------------|:--------------:|:------------:|
| [Abs][1] | ✅ | ✅ |
| [Acos][2] | ❌ | ✅ |
| [Acosh][3] | ❌ | ✅ |
| [Add][4] | ✅ | ✅ |
| [AffineGrid][195] | ❌ | ❌ |
| [And][5] | ✅ | ✅ |
| [ArgMax][6] | ✅ | ✅ |
| [ArgMin][7] | ✅ | ✅ |
| [Asin][8] | ❌ | ✅ |
| [Asinh][9] | ❌ | ✅ |
| [Atan][10] | ❌ | ✅ |
| [Atanh][11] | ❌ | ✅ |
| [Attention][194] | ✅ | ✅ |
| [AveragePool1d][12] | ✅ | ✅ |
| [AveragePool2d][12] | ✅ | ✅ |
| [BatchNormalization][14] | ✅ | ✅ |
| [Bernoulli][15] | ✅ | ✅ |
| [BitShift][16] | ✅ | ✅ |
| [BitwiseAnd][17] | ✅ | ✅ |
| [BitwiseNot][18] | ✅ | ✅ |
| [BitwiseOr][19] | ✅ | ✅ |
| [BitwiseXor][20] | ✅ | ✅ |
| [BlackmanWindow][21] | ❌ | ❌ |
| [Cast][22] | ✅ | ✅ |
| [CastLike][23] | ❌ | ❌ |
| [Ceil][24] | ✅ | ✅ |
| [Celu][25] | ❌ | ❌ |
| [CenterCropPad][26] | ❌ | ❌ |
| [Clip][27] | ✅ | ✅ |
| [Col2Im][28] | ❌ | ❌ |
| [Compress][29] | ❌ | ❌ |
| [Concat][30] | ✅ | ✅ |
| [ConcatFromSequence][31] | ❌ | ❌ |
| [Constant][32] | ✅ | ✅ |
| [ConstantOfShape][33] | ✅ | ✅ |
| [Conv1d][34] | ✅ | ✅ |
| [Conv2d][34] | ✅ | ✅ |
| [Conv3d][34] | ✅ | ✅ |
| [ConvInteger][37] | ❌ | ❌ |
| [ConvTranspose1d][38] | ✅ | ✅ |
| [ConvTranspose2d][38] | ✅ | ✅ |
| [ConvTranspose3d][38] | ✅ | ✅ |
| [Cos][39] | ✅ | ✅ |
| [Cosh][40] | ✅ | ✅ |
| [CumSum][41] | ✅ | ✅ |
| [DeformConv][196] | ❌ | ❌ |
| [DepthToSpace][42] | ✅ | ✅ |
| [DequantizeLinear][43] | ❌ | ❌ |
| [Det][44] | ❌ | ❌ |
| [DFT][45] | ❌ | ❌ |
| [Div][46] | ✅ | ✅ |
| [Dropout][47] | ✅ | ✅ |
| [DynamicQuantizeLinear][48] | ❌ | ❌ |
| [Einsum][49] | ❌ | ❌ |
| [Elu][50] | ❌ | ❌ |
| [Equal][51] | ✅ | ✅ |
| [Erf][52] | ✅ | ✅ |
| [Exp][53] | ✅ | ✅ |
| [Expand][54] | ✅ | ✅ |
| [EyeLike][55] | ✅ | ✅ |
| [Flatten][56] | ✅ | ✅ |
| [Floor][57] | ✅ | ✅ |
| [Gather][58] | ✅ | ✅ |
| [GatherElements][59] | ✅ | ✅ |
| [GatherND][60] | ❌ | ❌ |
| [Gelu][61] | ✅ | ✅ |
| [Gemm][62] | ✅ | ✅ |
| [GlobalAveragePool][63] | ✅ | ✅ |
| [GlobalLpPool][64] | ❌ | ❌ |
| [GlobalMaxPool][65] | ❌ | ❌ |
| [Greater][66] | ✅ | ✅ |
| [GreaterOrEqual][67] | ✅ | ✅ |
| [GridSample][68] | ✅ | ✅ |
| [GroupNormalization][69] | ✅ | ✅ |
| [GRU][70] | ❌ | ✅ |
| [HammingWindow][71] | ❌ | ❌ |
| [HannWindow][72] | ❌ | ❌ |
| [Hardmax][73] | ❌ | ❌ |
| [HardSigmoid][74] | ✅ | ✅ |
| [HardSwish][75] | ✅ | ✅ |
| [Identity][76] | ✅ | ✅ |
| [If][77] | ❌ | ✅ |
| [Im][78] | ❌ | ❌ |
| [ImageDecoder][197] | ❌ | ❌ |
| [InstanceNormalization][79] | ✅ | ✅ |
| [IsInf][80] | ✅ | ✅ |
| [IsNaN][81] | ✅ | ✅ |
| [LayerNormalization][82] | ✅ | ✅ |
| [LeakyRelu][83] | ✅ | ✅ |
| [Less][84] | ✅ | ✅ |
| [LessOrEqual][85] | ✅ | ✅ |
| Linear | ✅ | ✅ |
| [Log][87] | ✅ | ✅ |
| [LogSoftmax][88] | ✅ | ✅ |
| [Loop][89] | ✅ | ✅ |
| [LpNormalization][90] | ❌ | ❌ |
| [LpPool][91] | ❌ | ❌ |
| [LRN][92] | ❌ | ❌ |
| [LSTM][93] | ✅ | ✅ |
| [MatMul][94] | ✅ | ✅ |
| [MatMulInteger][95] | ✅ | ✅ |
| [Max][96] | ✅ | ✅ |
| [MaxPool1d][97] | ✅ | ✅ |
| [MaxPool2d][98] | ✅ | ✅ |
| [MaxRoiPool][99] | ❌ | ❌ |
| [MaxUnpool][100] | ❌ | ❌ |
| [Mean][101] | ✅ | ✅ |
| [MeanVarianceNormalization][102] | ❌ | ❌ |
| [MelWeightMatrix][103] | ❌ | ❌ |
| [Min][104] | ✅ | ✅ |
| [Mish][105] | ❌ | ❌ |
| [Mod][106] | ✅ | ✅ |
| [Mul][107] | ✅ | ✅ |
| [Multinomial][108] | ❌ | ❌ |
| [Neg][109] | ✅ | ✅ |
| [NegativeLogLikelihoodLoss][110] | ❌ | ❌ |
| [NonMaxSuppression][112] | ❌ | ❌ |
| [NonZero][113] | ✅ | ✅ |
| [Not][114] | ✅ | ✅ |
| [OneHot][115] | ✅ | ✅ |
| [Optional][116] | ❌ | ❌ |
| [OptionalGetElement][117] | ❌ | ❌ |
| [OptionalHasElement][118] | ❌ | ❌ |
| [Or][119] | ✅ | ✅ |
| [Pad][120] | ✅ | ✅ |
| [Pow][121] | ✅ | ✅ |
| [PRelu][122] | ✅ | ✅ |
| [QLinearConv][123] | ❌ | ❌ |
| [QLinearMatMul][124] | ❌ | ❌ |
| [QuantizeLinear][125] | ❌ | ❌ |
| [RMSNormalization][198] | ❌ | ❌ |
| [RNN][145] | ❌ | ✅ |
| [RandomNormal][126] | ✅ | ✅ |
| [RandomNormalLike][127] | ✅ | ✅ |
| [RandomUniform][128] | ✅ | ✅ |
| [RandomUniformLike][129] | ✅ | ✅ |
| [Range][130] | ✅ | ✅ |
| [Reciprocal][131] | ✅ | ✅ |
| [ReduceL][132] | ✅ | ✅ |
| [ReduceLogSum][133] | ✅ | ✅ |
| [ReduceLogSumExp][134] | ✅ | ✅ |
| [ReduceMax][135] | ✅ | ✅ |
| [ReduceMean][136] | ✅ | ✅ |
| [ReduceMin][137] | ✅ | ✅ |
| [ReduceProd][138] | ✅ | ✅ |
| [ReduceSum][139] | ✅ | ✅ |
| [ReduceSumSquare][140] | ✅ | ✅ |
| [RegexFullMatch][199] | ❌ | ❌ |
| [Relu][141] | ✅ | ✅ |
| [Reshape][142] | ✅ | ✅ |
| [Resize][143] | ✅ | ✅ |
| [ReverseSequence][144] | ❌ | ❌ |
| [RoiAlign][146] | ❌ | ❌ |
| [RotaryEmbedding][200] | ❌ | ❌ |
| [Round][147] | ✅ | ✅ |
| [Scan][148] | ✅ | ✅ |
| [Scatter][149] | ❌ | ✅ |
| [ScatterElements][150] | ❌ | ❌ |
| [ScatterND][151] | ❌ | ❌ |
| [Selu][152] | ❌ | ❌ |
| [SequenceAt][153] | ❌ | ❌ |
| [SequenceConstruct][154] | ❌ | ❌ |
| [SequenceEmpty][155] | ❌ | ❌ |
| [SequenceErase][156] | ❌ | ❌ |
| [SequenceInsert][157] | ❌ | ❌ |
| [SequenceLength][158] | ❌ | ❌ |
| [SequenceMap][159] | ❌ | ❌ |
| [Shape][160] | ✅ | ✅ |
| [Shrink][161] | ❌ | ❌ |
| [Sigmoid][162] | ✅ | ✅ |
| [Sign][163] | ✅ | ✅ |
| [Sin][164] | ✅ | ✅ |
| [Sinh][165] | ✅ | ✅ |
| [Size][166] | ✅ | ✅ |
| [Slice][167] | ✅ | ✅ |
| [Softmax][168] | ✅ | ✅ |
| [SoftmaxCrossEntropyLoss][169] | ❌ | ❌ |
| [Softplus][170] | ❌ | ❌ |
| [Softsign][171] | ❌ | ❌ |
| [SpaceToDepth][172] | ✅ | ✅ |
| [Split][173] | ✅ | ✅ |
| [SplitToSequence][174] | ❌ | ❌ |
| [Sqrt][175] | ✅ | ✅ |
| [Squeeze][176] | ✅ | ✅ |
| [STFT][177] | ❌ | ❌ |
| [StringConcat][201] | ❌ | ❌ |
| [StringNormalizer][178] | ❌ | ❌ |
| [StringSplit][202] | ❌ | ❌ |
| [Sub][179] | ✅ | ✅ |
| [Sum][180] | ✅ | ✅ |
| [Swish][203] | ❌ | ❌ |
| [Tan][181] | ✅ | ✅ |
| [Tanh][182] | ✅ | ✅ |
| [TensorScatter][204] | ❌ | ❌ |
| [TfIdfVectorizer][183] | ❌ | ❌ |
| [ThresholdedRelu][184] | ❌ | ❌ |
| [Tile][185] | ✅ | ✅ |
| [TopK][186] | ✅ | ✅ |
| [Transpose][187] | ✅ | ✅ |
| [Trilu][188] | ✅ | ✅ |
| [Unique][189] | ❌ | ❌ |
| [Unsqueeze][193] | ✅ | ✅ |
| [Upsample][190] | ❌ | ❌ |
| [Where][191] | ✅ | ✅ |
| [Xor][192] | ✅ | ✅ |
[1]: https://onnx.ai/onnx/operators/onnx__Abs.html "ONNX Abs"
[2]: https://onnx.ai/onnx/operators/onnx__Acos.html "ONNX Acos"
[3]: https://onnx.ai/onnx/operators/onnx__Acosh.html "ONNX Acosh"
[4]: https://onnx.ai/onnx/operators/onnx__Add.html "ONNX Add"
[5]: https://onnx.ai/onnx/operators/onnx__And.html "ONNX And"
[6]: https://onnx.ai/onnx/operators/onnx__ArgMax.html "ONNX ArgMax"
[7]: https://onnx.ai/onnx/operators/onnx__ArgMin.html "ONNX ArgMin"
[8]: https://onnx.ai/onnx/operators/onnx__Asin.html "ONNX Asin"
[9]: https://onnx.ai/onnx/operators/onnx__Asinh.html "ONNX Asinh"
[10]: https://onnx.ai/onnx/operators/onnx__Atan.html "ONNX Atan"
[11]: https://onnx.ai/onnx/operators/onnx__Atanh.html "ONNX Atanh"
[12]: https://onnx.ai/onnx/operators/onnx__AveragePool.html "ONNX AveragePool"
[14]: https://onnx.ai/onnx/operators/onnx__BatchNormalization.html "ONNX BatchNormalization"
[15]: https://onnx.ai/onnx/operators/onnx__Bernoulli.html "ONNX Bernoulli"
[16]: https://onnx.ai/onnx/operators/onnx__BitShift.html "ONNX BitShift"
[17]: https://onnx.ai/onnx/operators/onnx__BitwiseAnd.html "ONNX BitwiseAnd"
[18]: https://onnx.ai/onnx/operators/onnx__BitwiseNot.html "ONNX BitwiseNot"
[19]: https://onnx.ai/onnx/operators/onnx__BitwiseOr.html "ONNX BitwiseOr"
[20]: https://onnx.ai/onnx/operators/onnx__BitwiseXor.html "ONNX BitwiseXor"
[21]: https://onnx.ai/onnx/operators/onnx__BlackmanWindow.html "ONNX BlackmanWindow"
[22]: https://onnx.ai/onnx/operators/onnx__Cast.html "ONNX Cast"
[23]: https://onnx.ai/onnx/operators/onnx__CastLike.html "ONNX CastLike"
[24]: https://onnx.ai/onnx/operators/onnx__Ceil.html "ONNX Ceil"
[25]: https://onnx.ai/onnx/operators/onnx__Celu.html "ONNX Celu"
[26]: https://onnx.ai/onnx/operators/onnx__CenterCropPad.html "ONNX CenterCropPad"
[27]: https://onnx.ai/onnx/operators/onnx__Clip.html "ONNX Clip"
[28]: https://onnx.ai/onnx/operators/onnx__Col2Im.html "ONNX Col2Im"
[29]: https://onnx.ai/onnx/operators/onnx__Compress.html "ONNX Compress"
[30]: https://onnx.ai/onnx/operators/onnx__Concat.html "ONNX Concat"
[31]: https://onnx.ai/onnx/operators/onnx__ConcatFromSequence.html "ONNX ConcatFromSequence"
[32]: https://onnx.ai/onnx/operators/onnx__Constant.html "ONNX Constant"
[33]: https://onnx.ai/onnx/operators/onnx__ConstantOfShape.html "ONNX ConstantOfShape"
[34]: https://onnx.ai/onnx/operators/onnx__Conv.html "ONNX Conv"
[37]: https://onnx.ai/onnx/operators/onnx__ConvInteger.html "ONNX ConvInteger"
[38]: https://onnx.ai/onnx/operators/onnx__ConvTranspose.html "ONNX ConvTranspose"
[39]: https://onnx.ai/onnx/operators/onnx__Cos.html "ONNX Cos"
[40]: https://onnx.ai/onnx/operators/onnx__Cosh.html "ONNX Cosh"
[41]: https://onnx.ai/onnx/operators/onnx__CumSum.html "ONNX CumSum"
[42]: https://onnx.ai/onnx/operators/onnx__DepthToSpace.html "ONNX DepthToSpace"
[43]: https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html "ONNX DequantizeLinear"
[44]: https://onnx.ai/onnx/operators/onnx__Det.html "ONNX Det"
[45]: https://onnx.ai/onnx/operators/onnx__DFT.html "ONNX DFT"
[46]: https://onnx.ai/onnx/operators/onnx__Div.html "ONNX Div"
[47]: https://onnx.ai/onnx/operators/onnx__Dropout.html "ONNX Dropout"
[48]: https://onnx.ai/onnx/operators/onnx__DynamicQuantizeLinear.html "ONNX DynamicQuantizeLinear"
[49]: https://onnx.ai/onnx/operators/onnx__Einsum.html "ONNX Einsum"
[50]: https://onnx.ai/onnx/operators/onnx__Elu.html "ONNX Elu"
[51]: https://onnx.ai/onnx/operators/onnx__Equal.html "ONNX Equal"
[52]: https://onnx.ai/onnx/operators/onnx__Erf.html "ONNX Erf"
[53]: https://onnx.ai/onnx/operators/onnx__Exp.html "ONNX Exp"
[54]: https://onnx.ai/onnx/operators/onnx__Expand.html "ONNX Expand"
[55]: https://onnx.ai/onnx/operators/onnx__EyeLike.html "ONNX EyeLike"
[56]: https://onnx.ai/onnx/operators/onnx__Flatten.html "ONNX Flatten"
[57]: https://onnx.ai/onnx/operators/onnx__Floor.html "ONNX Floor"
[58]: https://onnx.ai/onnx/operators/onnx__Gather.html "ONNX Gather"
[59]: https://onnx.ai/onnx/operators/onnx__GatherElements.html "ONNX GatherElements"
[60]: https://onnx.ai/onnx/operators/onnx__GatherND.html "ONNX GatherND"
[61]: https://onnx.ai/onnx/operators/onnx__Gelu.html "ONNX Gelu"
[62]: https://onnx.ai/onnx/operators/onnx__Gemm.html "ONNX Gemm (Linear Layer)"
[63]: https://onnx.ai/onnx/operators/onnx__GlobalAveragePool.html "ONNX GlobalAveragePool"
[64]: https://onnx.ai/onnx/operators/onnx__GlobalLpPool.html "ONNX GlobalLpPool"
[65]: https://onnx.ai/onnx/operators/onnx__GlobalMaxPool.html "ONNX GlobalMaxPool"
[66]: https://onnx.ai/onnx/operators/onnx__Greater.html "ONNX Greater"
[67]: https://onnx.ai/onnx/operators/onnx__GreaterOrEqual.html "ONNX GreaterOrEqual"
[68]: https://onnx.ai/onnx/operators/onnx__GridSample.html "ONNX GridSample"
[69]: https://onnx.ai/onnx/operators/onnx__GroupNormalization.html "ONNX GroupNormalization"
[70]: https://onnx.ai/onnx/operators/onnx__GRU.html "ONNX GRU"
[71]: https://onnx.ai/onnx/operators/onnx__HammingWindow.html "ONNX HammingWindow"
[72]: https://onnx.ai/onnx/operators/onnx__HannWindow.html "ONNX HannWindow"
[73]: https://onnx.ai/onnx/operators/onnx__Hardmax.html "ONNX Hardmax"
[74]: https://onnx.ai/onnx/operators/onnx__HardSigmoid.html "ONNX HardSigmoid"
[75]: https://onnx.ai/onnx/operators/onnx__HardSwish.html "ONNX HardSwish"
[76]: https://onnx.ai/onnx/operators/onnx__Identity.html "ONNX Identity"
[77]: https://onnx.ai/onnx/operators/onnx__If.html "ONNX If"
[78]: https://onnx.ai/onnx/operators/onnx__Im.html "ONNX Im"
[79]: https://onnx.ai/onnx/operators/onnx__InstanceNormalization.html "ONNX InstanceNormalization"
[80]: https://onnx.ai/onnx/operators/onnx__IsInf.html "ONNX IsInf"
[81]: https://onnx.ai/onnx/operators/onnx__IsNaN.html "ONNX IsNaN"
[82]: https://onnx.ai/onnx/operators/onnx__LayerNormalization.html "ONNX LayerNormalization"
[83]: https://onnx.ai/onnx/operators/onnx__LeakyRelu.html "ONNX LeakyRelu"
[84]: https://onnx.ai/onnx/operators/onnx__Less.html "ONNX Less"
[85]: https://onnx.ai/onnx/operators/onnx__LessOrEqual.html "ONNX LessOrEqual"
[87]: https://onnx.ai/onnx/operators/onnx__Log.html "ONNX Log"
[88]: https://onnx.ai/onnx/operators/onnx__LogSoftmax.html "ONNX LogSoftmax"
[89]: https://onnx.ai/onnx/operators/onnx__Loop.html "ONNX Loop"
[90]: https://onnx.ai/onnx/operators/onnx__LpNormalization.html "ONNX LpNormalization"
[91]: https://onnx.ai/onnx/operators/onnx__LpPool.html "ONNX LpPool"
[92]: https://onnx.ai/onnx/operators/onnx__LRN.html "ONNX LRN"
[93]: https://onnx.ai/onnx/operators/onnx__LSTM.html "ONNX LSTM"
[94]: https://onnx.ai/onnx/operators/onnx__MatMul.html "ONNX MatMul"
[95]: https://onnx.ai/onnx/operators/onnx__MatMulInteger.html "ONNX MatMulInteger"
[96]: https://onnx.ai/onnx/operators/onnx__Max.html "ONNX Max"
[97]: https://onnx.ai/onnx/operators/onnx__MaxPool.html "ONNX MaxPool1d"
[98]: https://onnx.ai/onnx/operators/onnx__MaxPool.html "ONNX MaxPool2d"
[99]: https://onnx.ai/onnx/operators/onnx__MaxRoiPool.html "ONNX MaxRoiPool"
[100]: https://onnx.ai/onnx/operators/onnx__MaxUnpool.html "ONNX MaxUnpool"
[101]: https://onnx.ai/onnx/operators/onnx__Mean.html "ONNX Mean"
[102]: https://onnx.ai/onnx/operators/onnx__MeanVarianceNormalization.html "ONNX MeanVarianceNormalization"
[103]: https://onnx.ai/onnx/operators/onnx__MelWeightMatrix.html "ONNX MelWeightMatrix"
[104]: https://onnx.ai/onnx/operators/onnx__Min.html "ONNX Min"
[105]: https://onnx.ai/onnx/operators/onnx__Mish.html "ONNX Mish"
[106]: https://onnx.ai/onnx/operators/onnx__Mod.html "ONNX Mod"
[107]: https://onnx.ai/onnx/operators/onnx__Mul.html "ONNX Mul"
[108]: https://onnx.ai/onnx/operators/onnx__Multinomial.html "ONNX Multinomial"
[109]: https://onnx.ai/onnx/operators/onnx__Neg.html "ONNX Neg"
[110]: https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html "ONNX NegativeLogLikelihoodLoss"
[112]: https://onnx.ai/onnx/operators/onnx__NonMaxSuppression.html "ONNX NonMaxSuppression"
[113]: https://onnx.ai/onnx/operators/onnx__NonZero.html "ONNX NonZero"
[114]: https://onnx.ai/onnx/operators/onnx__Not.html "ONNX Not"
[115]: https://onnx.ai/onnx/operators/onnx__OneHot.html "ONNX OneHot"
[116]: https://onnx.ai/onnx/operators/onnx__Optional.html "ONNX Optional"
[117]: https://onnx.ai/onnx/operators/onnx__OptionalGetElement.html "ONNX OptionalGetElement"
[118]: https://onnx.ai/onnx/operators/onnx__OptionalHasElement.html "ONNX OptionalHasElement"
[119]: https://onnx.ai/onnx/operators/onnx__Or.html "ONNX Or"
[120]: https://onnx.ai/onnx/operators/onnx__Pad.html "ONNX Pad"
[121]: https://onnx.ai/onnx/operators/onnx__Pow.html "ONNX Pow"
[122]: https://onnx.ai/onnx/operators/onnx__PRelu.html "ONNX PRelu"
[123]: https://onnx.ai/onnx/operators/onnx__QLinearConv.html "ONNX QLinearConv"
[124]: https://onnx.ai/onnx/operators/onnx__QLinearMatMul.html "ONNX QLinearMatMul"
[125]: https://onnx.ai/onnx/operators/onnx__QuantizeLinear.html "ONNX QuantizeLinear"
[126]: https://onnx.ai/onnx/operators/onnx__RandomNormal.html "ONNX RandomNormal"
[127]: https://onnx.ai/onnx/operators/onnx__RandomNormalLike.html "ONNX RandomNormalLike"
[128]: https://onnx.ai/onnx/operators/onnx__RandomUniform.html "ONNX RandomUniform"
[129]: https://onnx.ai/onnx/operators/onnx__RandomUniformLike.html "ONNX RandomUniformLike"
[130]: https://onnx.ai/onnx/operators/onnx__Range.html "ONNX Range"
[131]: https://onnx.ai/onnx/operators/onnx__Reciprocal.html "ONNX Reciprocal"
[132]: https://onnx.ai/onnx/operators/onnx__ReduceL1.html "ONNX ReduceL"
[133]: https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html "ONNX ReduceLogSum"
[134]: https://onnx.ai/onnx/operators/onnx__ReduceLogSumExp.html "ONNX ReduceLogSumExp"
[135]: https://onnx.ai/onnx/operators/onnx__ReduceMax.html "ONNX ReduceMax"
[136]: https://onnx.ai/onnx/operators/onnx__ReduceMean.html "ONNX ReduceMean"
[137]: https://onnx.ai/onnx/operators/onnx__ReduceMin.html "ONNX ReduceMin"
[138]: https://onnx.ai/onnx/operators/onnx__ReduceProd.html "ONNX ReduceProd"
[139]: https://onnx.ai/onnx/operators/onnx__ReduceSum.html "ONNX ReduceSum"
[140]: https://onnx.ai/onnx/operators/onnx__ReduceSumSquare.html "ONNX ReduceSumSquare"
[141]: https://onnx.ai/onnx/operators/onnx__Relu.html "ONNX Relu"
[142]: https://onnx.ai/onnx/operators/onnx__Reshape.html "ONNX Reshape"
[143]: https://onnx.ai/onnx/operators/onnx__Resize.html "ONNX Resize"
[144]: https://onnx.ai/onnx/operators/onnx__ReverseSequence.html "ONNX ReverseSequence"
[145]: https://onnx.ai/onnx/operators/onnx__RNN.html "ONNX RNN"
[146]: https://onnx.ai/onnx/operators/onnx__RoiAlign.html "ONNX RoiAlign"
[147]: https://onnx.ai/onnx/operators/onnx__Round.html "ONNX Round"
[148]: https://onnx.ai/onnx/operators/onnx__Scan.html "ONNX Scan"
[149]: https://onnx.ai/onnx/operators/onnx__Scatter.html "ONNX Scatter"
[150]: https://onnx.ai/onnx/operators/onnx__ScatterElements.html "ONNX ScatterElements"
[151]: https://onnx.ai/onnx/operators/onnx__ScatterND.html "ONNX ScatterND"
[152]: https://onnx.ai/onnx/operators/onnx__Selu.html "ONNX Selu"
[153]: https://onnx.ai/onnx/operators/onnx__SequenceAt.html "ONNX SequenceAt"
[154]: https://onnx.ai/onnx/operators/onnx__SequenceConstruct.html "ONNX SequenceConstruct"
[155]: https://onnx.ai/onnx/operators/onnx__SequenceEmpty.html "ONNX SequenceEmpty"
[156]: https://onnx.ai/onnx/operators/onnx__SequenceErase.html "ONNX SequenceErase"
[157]: https://onnx.ai/onnx/operators/onnx__SequenceInsert.html "ONNX SequenceInsert"
[158]: https://onnx.ai/onnx/operators/onnx__SequenceLength.html "ONNX SequenceLength"
[159]: https://onnx.ai/onnx/operators/onnx__SequenceMap.html "ONNX SequenceMap"
[160]: https://onnx.ai/onnx/operators/onnx__Shape.html "ONNX Shape"
[161]: https://onnx.ai/onnx/operators/onnx__Shrink.html "ONNX Shrink"
[162]: https://onnx.ai/onnx/operators/onnx__Sigmoid.html "ONNX Sigmoid"
[163]: https://onnx.ai/onnx/operators/onnx__Sign.html "ONNX Sign"
[164]: https://onnx.ai/onnx/operators/onnx__Sin.html "ONNX Sin"
[165]: https://onnx.ai/onnx/operators/onnx__Sinh.html "ONNX Sinh"
[166]: https://onnx.ai/onnx/operators/onnx__Size.html "ONNX Size"
[167]: https://onnx.ai/onnx/operators/onnx__Slice.html "ONNX Slice"
[168]: https://onnx.ai/onnx/operators/onnx__Softmax.html "ONNX Softmax"
[169]: https://onnx.ai/onnx/operators/onnx__SoftmaxCrossEntropyLoss.html "ONNX SoftmaxCrossEntropyLoss"
[170]: https://onnx.ai/onnx/operators/onnx__Softplus.html "ONNX Softplus"
[171]: https://onnx.ai/onnx/operators/onnx__Softsign.html "ONNX Softsign"
[172]: https://onnx.ai/onnx/operators/onnx__SpaceToDepth.html "ONNX SpaceToDepth"
[173]: https://onnx.ai/onnx/operators/onnx__Split.html "ONNX Split"
[174]: https://onnx.ai/onnx/operators/onnx__SplitToSequence.html "ONNX SplitToSequence"
[175]: https://onnx.ai/onnx/operators/onnx__Sqrt.html "ONNX Sqrt"
[176]: https://onnx.ai/onnx/operators/onnx__Squeeze.html "ONNX Squeeze"
[177]: https://onnx.ai/onnx/operators/onnx__STFT.html "ONNX STFT"
[178]: https://onnx.ai/onnx/operators/onnx__StringNormalizer.html "ONNX StringNormalizer"
[179]: https://onnx.ai/onnx/operators/onnx__Sub.html "ONNX Sub"
[180]: https://onnx.ai/onnx/operators/onnx__Sum.html "ONNX Sum"
[181]: https://onnx.ai/onnx/operators/onnx__Tan.html "ONNX Tan"
[182]: https://onnx.ai/onnx/operators/onnx__Tanh.html "ONNX Tanh"
[183]: https://onnx.ai/onnx/operators/onnx__TfIdfVectorizer.html "ONNX TfIdfVectorizer"
[184]: https://onnx.ai/onnx/operators/onnx__ThresholdedRelu.html "ONNX ThresholdedRelu"
[185]: https://onnx.ai/onnx/operators/onnx__Tile.html "ONNX Tile"
[186]: https://onnx.ai/onnx/operators/onnx__TopK.html "ONNX TopK"
[187]: https://onnx.ai/onnx/operators/onnx__Transpose.html "ONNX Transpose"
[188]: https://onnx.ai/onnx/operators/onnx__Trilu.html "ONNX Trilu"
[189]: https://onnx.ai/onnx/operators/onnx__Unique.html "ONNX Unique"
[190]: https://onnx.ai/onnx/operators/onnx__Upsample.html "ONNX Upsample"
[191]: https://onnx.ai/onnx/operators/onnx__Where.html "ONNX Where"
[192]: https://onnx.ai/onnx/operators/onnx__Xor.html "ONNX Xor"
[193]: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html "ONNX Unsqueeze"
[194]: https://onnx.ai/onnx/operators/onnx__Attention.html "ONNX Attention"
[195]: https://onnx.ai/onnx/operators/onnx__AffineGrid.html "ONNX AffineGrid"
[196]: https://onnx.ai/onnx/operators/onnx__DeformConv.html "ONNX DeformConv"
[197]: https://onnx.ai/onnx/operators/onnx__ImageDecoder.html "ONNX ImageDecoder"
[198]: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html "ONNX RMSNormalization"
[199]: https://onnx.ai/onnx/operators/onnx__RegexFullMatch.html "ONNX RegexFullMatch"
[200]: https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html "ONNX RotaryEmbedding"
[201]: https://onnx.ai/onnx/operators/onnx__StringConcat.html "ONNX StringConcat"
[202]: https://onnx.ai/onnx/operators/onnx__StringSplit.html "ONNX StringSplit"
[203]: https://onnx.ai/onnx/operators/onnx__Swish.html "ONNX Swish"
[204]: https://onnx.ai/onnx/operators/onnx__TensorScatter.html "ONNX TensorScatter"


@@ -1,25 +0,0 @@
Cargo.lock
# Ignore model artifacts and temporary resources in all subdirectories
*/artifacts/
# Ignore downloaded model files
**/*.onnx
**/*.pt
**/*.pth
**/*.ckpt
**/*.safetensors
**/*.pb
**/*.h5
# Python cache
**/__pycache__/
**/*.pyc
**/*.pyo
**/*.pyd
**/.Python
# UV/pip
**/.venv/
**/venv/
**/*.egg-info/


@@ -1,66 +0,0 @@
# Model Checks
This directory contains model verification and validation tests for burn-onnx. Each subdirectory
represents a different model that we test to ensure burn-onnx can correctly:
1. Import ONNX models
2. Generate Rust code from the models
3. Build and run the generated code
## Purpose
The model checks serve as integration tests to verify that burn-onnx works correctly with
real-world models. These tests help catch regressions and ensure compatibility with various ONNX
operators and model architectures.
## Structure
Each model directory typically contains:
- Model download/preparation script (e.g., `get_model.py`)
- `build.rs` - Build script that uses burn-onnx to generate Rust code
- `src/main.rs` - Test code that runs the generated model
- `Cargo.toml` - Package configuration
- `artifacts/` - Directory for downloaded ONNX models (created by the script)
Generated files (not tracked in git):
- `target/` - Build artifacts and generated model code
## Two-Step Process
### Step 1: Download and Prepare the Model
First, download the model and convert it to the required ONNX format:
```bash
cd model-checks/<model-name>
python get_model.py
# or using uv:
uv run get_model.py
```
The model preparation script typically:
- Downloads the model (if not already present)
- Converts it to ONNX format with the appropriate opset version
- Validates the model structure
- Saves the prepared model to `artifacts/`

The scripts skip the download when the ONNX model already exists, which saves time and
bandwidth.
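
The skip-if-present behavior described above can be sketched roughly as follows. This is a minimal illustration, not the code of any particular `get_model.py`: the helper name `ensure_model` and the `fetch` callback are hypothetical stand-ins for the real download and conversion step.

```python
from pathlib import Path
from typing import Callable


def ensure_model(onnx_path: Path, fetch: Callable[[Path], None]) -> bool:
    """Run `fetch` only when the ONNX file is missing.

    Returns True if `fetch` was called, False if the existing file was reused.
    `fetch` is a hypothetical stand-in for the download/convert step.
    """
    if onnx_path.exists():
        # Model already prepared: skip the download entirely.
        return False
    # Make sure the artifacts directory exists before writing into it.
    onnx_path.parent.mkdir(parents=True, exist_ok=True)
    fetch(onnx_path)
    if not onnx_path.exists():
        raise FileNotFoundError(f"Failed to create ONNX file at {onnx_path}")
    return True
```

Calling `ensure_model` a second time with the same path returns `False` without invoking `fetch`, which is the property that keeps repeated runs cheap.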
### Step 2: Build and Run the Model
Once the ONNX model is ready, build and run the Rust code:
```bash
cargo build
cargo run
```
The build process will:
- Check that the ONNX model exists (with helpful error messages if not)
- Generate Rust code from the ONNX model using burn-onnx
- Compile the generated code


@@ -1,26 +0,0 @@
[package]
name = "burn-onnx-model-checks-albert"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["tch"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,69 +0,0 @@
# ALBERT Model Checks
This crate provides a unified interface for testing ALBERT model variants with Burn.
## Supported Models
- `albert-base-v2` - ALBERT Base v2 from HuggingFace (https://huggingface.co/albert/albert-base-v2)
## Usage
### 1. Download and prepare a model
```bash
# Using Python directly
python get_model.py --model albert-base-v2
# Or using uv
uv run get_model.py --model albert-base-v2
# List available models
uv run get_model.py --list
```
### 2. Build and run the model test
```bash
# Build the model
ALBERT_MODEL=albert-base-v2 cargo build
# Run the test
ALBERT_MODEL=albert-base-v2 cargo run --release
```
## Directory Structure
```
albert/
├── artifacts/ # Downloaded ONNX models and test data
│ ├── albert-base-v2_opset16.onnx
│ ├── albert-base-v2_test_data.pt
│ └── ...
├── src/
│ └── main.rs # Test runner
├── build.rs # Build script that generates model code
├── get_model.py # Model download and preparation script
└── Cargo.toml
```
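
The artifact names in the tree above follow a `<model>_opset16.onnx` / `<model>_test_data.pt` pattern. A small sketch of that mapping, assuming only the single supported model listed in this README (the helper `artifact_paths` is hypothetical and exists only for illustration):

```python
from pathlib import Path

# Assumption: mirrors the "Supported Models" list in this README.
SUPPORTED_MODELS = ("albert-base-v2",)


def artifact_paths(model_name: str, artifacts_dir: str = "artifacts") -> dict:
    """Map a model name to the artifact files expected for it."""
    if model_name not in SUPPORTED_MODELS:
        raise ValueError(
            f"Unsupported model '{model_name}'. Supported models: {list(SUPPORTED_MODELS)}"
        )
    base = Path(artifacts_dir)
    return {
        "onnx": base / f"{model_name}_opset16.onnx",
        "test_data": base / f"{model_name}_test_data.pt",
    }
```

For example, `artifact_paths("albert-base-v2")["onnx"]` points at `artifacts/albert-base-v2_opset16.onnx`, the file the build step expects to find.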
## Model Architecture
ALBERT (A Lite BERT) is a lighter version of BERT that uses parameter-sharing techniques to reduce
model size while maintaining performance. The model has:
- **Inputs**:
  - `input_ids`: Token IDs (shape: [batch_size, sequence_length])
  - `attention_mask`: Attention mask (shape: [batch_size, sequence_length])
  - `token_type_ids`: Token type IDs (shape: [batch_size, sequence_length])
- **Outputs**:
  - `last_hidden_state`: Sequence of hidden states (shape: [batch_size, sequence_length,
    hidden_size])
  - `pooler_output`: Pooled output for classification tasks (shape: [batch_size, hidden_size])
## Notes
- The default sequence length is 128 tokens
- ALBERT Base v2 has a hidden size of 768
- The model uses ONNX opset 16
- Test data is generated from random inputs with a fixed seed (seed=42) for reproducibility
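
To illustrate how a fixed seed makes the test inputs reproducible, here is a minimal sketch. It swaps in the standard-library `random` module, so the values will not match those produced by the actual preparation script; the function name `make_test_inputs` and its defaults are assumptions taken from the notes above (sequence length 128, batch size 1, a vocabulary of roughly 30000 tokens).

```python
import random


def make_test_inputs(seq_length=128, vocab_size=30000, batch_size=1, seed=42):
    """Build reproducible dummy inputs with the shapes listed above."""
    rng = random.Random(seed)  # local generator, so global state is untouched
    input_ids = [
        [rng.randrange(vocab_size) for _ in range(seq_length)]
        for _ in range(batch_size)
    ]
    # All tokens are treated as valid (mask of ones) and as one segment (zeros).
    attention_mask = [[1] * seq_length for _ in range(batch_size)]
    token_type_ids = [[0] * seq_length for _ in range(batch_size)]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }
```

Two calls with the same seed return identical inputs, which is what allows stored reference outputs to be compared against a later run.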


@@ -1,85 +0,0 @@
use burn_onnx::ModelGen;
use std::env;
use std::fs;
use std::path::Path;
fn main() {
// Supported models
let supported_models = vec!["albert-base-v2"];
// Get the model name from environment variable (required)
let model_name = env::var("ALBERT_MODEL").unwrap_or_else(|_| {
eprintln!("Error: ALBERT_MODEL environment variable is not set.");
eprintln!();
eprintln!("Please specify which ALBERT model to build:");
eprintln!(" ALBERT_MODEL=albert-base-v2 cargo build");
eprintln!();
eprintln!("Available models: {}", supported_models.join(", "));
std::process::exit(1);
});
if !supported_models.contains(&model_name.as_str()) {
eprintln!(
"Error: Unsupported model '{}'. Supported models: {:?}",
model_name, supported_models
);
std::process::exit(1);
}
let onnx_path = format!("artifacts/{}_opset16.onnx", model_name);
let test_data_path = format!("artifacts/{}_test_data.pt", model_name);
// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed={}", test_data_path);
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-env-changed=ALBERT_MODEL");
// Check if the ONNX model file exists
if !Path::new(&onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!(
"Please run the following command to download and prepare the {} model:",
model_name
);
eprintln!(" python get_model.py --model {}", model_name);
eprintln!();
eprintln!("Or if you prefer using uv:");
eprintln!(" uv run get_model.py --model {}", model_name);
eprintln!();
eprintln!("Available models: {}", supported_models.join(", "));
std::process::exit(1);
}
// Generate the model code from the ONNX file
ModelGen::new()
.input(&onnx_path)
.out_dir("model/")
.run_from_script();
// Write the model name to a file so main.rs can access it
let out_dir = env::var("OUT_DIR").unwrap();
let model_info_path = Path::new(&out_dir).join("model_info.rs");
// Generate the include path for the model
let model_include = format!(
"include!(concat!(env!(\"OUT_DIR\"), \"/model/{}_opset16.rs\"));",
model_name
);
fs::write(
model_info_path,
format!(
r#"pub const MODEL_NAME: &str = "{}";
pub const TEST_DATA_FILE: &str = "{}_test_data.pt";
// Include the generated model
pub mod albert_model {{
{}
}}"#,
model_name, model_name, model_include
),
)
.expect("Failed to write model info");
}


@@ -1,234 +0,0 @@
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "onnx>=1.17.0",
# "onnxruntime>=1.18.0",
# "transformers>=4.44.0",
# "sentencepiece>=0.2.0",
# "numpy",
# "torch",
# ]
# ///
import os
import sys
import onnx
from onnx import shape_inference, version_converter
import numpy as np
from pathlib import Path
import argparse
# Supported ALBERT models configuration
SUPPORTED_MODELS = {
    'albert-base-v2': {
        'hf_name': 'albert/albert-base-v2',
        'display_name': 'ALBERT Base v2',
        'seq_length': 128,
    },
}
def download_and_convert_model(model_name, output_path):
    """Download ALBERT model from HuggingFace and export to ONNX format."""
    from transformers import AlbertModel, AlbertTokenizer
    import torch
    model_config = SUPPORTED_MODELS[model_name]
    display_name = model_config['display_name']
    hf_name = model_config['hf_name']
    seq_length = model_config['seq_length']
    print(f"Downloading {display_name} model from HuggingFace...")
    tokenizer = AlbertTokenizer.from_pretrained(hf_name)
    model = AlbertModel.from_pretrained(hf_name)
    model.eval()
    print("Exporting to ONNX format...")
    # Create dummy inputs
    dummy_text = "This is a sample text for ONNX export."
    inputs = tokenizer(
        dummy_text,
        padding='max_length',
        max_length=seq_length,
        truncation=True,
        return_tensors="pt"
    )
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    token_type_ids = inputs['token_type_ids']
    # Export to ONNX
    torch.onnx.export(
        model,
        (input_ids, attention_mask, token_type_ids),
        output_path,
        input_names=['input_ids', 'attention_mask', 'token_type_ids'],
        output_names=['last_hidden_state', 'pooler_output'],
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence'},
            'attention_mask': {0: 'batch_size', 1: 'sequence'},
            'token_type_ids': {0: 'batch_size', 1: 'sequence'},
            'last_hidden_state': {0: 'batch_size', 1: 'sequence'},
            'pooler_output': {0: 'batch_size'},
        },
        opset_version=16,
        do_constant_folding=True,
    )
    if not output_path.exists():
        raise FileNotFoundError(f"Failed to create ONNX file at {output_path}")
def process_model(input_path, output_path, target_opset=16):
    """Load, upgrade opset, and apply shape inference to model."""
    print(f"Loading model from {input_path}...")
    model = onnx.load(input_path)
    # Check and upgrade opset if needed
    current_opset = model.opset_import[0].version
    if current_opset < target_opset:
        print(f"Upgrading opset from {current_opset} to {target_opset}...")
        model = version_converter.convert_version(model, target_opset)
    # Apply shape inference
    print("Applying shape inference...")
    model = shape_inference.infer_shapes(model)
    # Save processed model
    onnx.save(model, output_path)
    print(f"✓ Processed model saved to: {output_path}")
    return model
def generate_test_data(model_path, output_path, model_name):
    """Generate test input/output data and save as PyTorch tensors."""
    import torch
    import onnxruntime as ort
    print("\nGenerating test data...")
    model_config = SUPPORTED_MODELS[model_name]
    seq_length = model_config['seq_length']
    # Create reproducible test input
    np.random.seed(42)
    batch_size = 1
    # Generate random token IDs (typical vocabulary size is 30000 for ALBERT)
    input_ids = np.random.randint(0, 30000, size=(batch_size, seq_length), dtype=np.int64)
    attention_mask = np.ones((batch_size, seq_length), dtype=np.int64)
    token_type_ids = np.zeros((batch_size, seq_length), dtype=np.int64)
    print(f" Input shapes:")
    print(f" input_ids: {input_ids.shape}")
    print(f" attention_mask: {attention_mask.shape}")
    print(f" token_type_ids: {token_type_ids.shape}")
    # Run inference to get output
    session = ort.InferenceSession(model_path)
    outputs = session.run(
        None,
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
        }
    )
    # Save as PyTorch tensors
    test_data = {
        'input_ids': torch.from_numpy(input_ids),
        'attention_mask': torch.from_numpy(attention_mask),
        'token_type_ids': torch.from_numpy(token_type_ids),
        'last_hidden_state': torch.from_numpy(outputs[0]),
        'pooler_output': torch.from_numpy(outputs[1]),
    }
    torch.save(test_data, output_path)
    print(f" ✓ Test data saved to: {output_path}")
    print(f" last_hidden_state shape: {outputs[0].shape}")
    print(f" pooler_output shape: {outputs[1].shape}")
def main():
    parser = argparse.ArgumentParser(description='ALBERT Model Preparation Tool')
    parser.add_argument('--model', type=str, default='albert-base-v2',
                        choices=list(SUPPORTED_MODELS.keys()),
                        help=f'ALBERT model to download and prepare (default: albert-base-v2). Choices: {", ".join(SUPPORTED_MODELS.keys())}')
    parser.add_argument('--list', action='store_true',
                        help='List all supported models')
    args = parser.parse_args()
    if args.list:
        print("Supported ALBERT models:")
        for model_id, config in SUPPORTED_MODELS.items():
            print(f" - {model_id:20s} ({config['display_name']})")
        return
    model_name = args.model
    display_name = SUPPORTED_MODELS[model_name]['display_name']
    print("=" * 60)
    print(f"{display_name} Model Preparation Tool")
    print("=" * 60)
    # Setup paths
    artifacts_dir = Path("artifacts")
    artifacts_dir.mkdir(exist_ok=True)
    original_path = artifacts_dir / f"{model_name}.onnx"
    processed_path = artifacts_dir / f"{model_name}_opset16.onnx"
    test_data_path = artifacts_dir / f"{model_name}_test_data.pt"
    # Check if we already have everything
    if processed_path.exists() and test_data_path.exists():
        print(f"\n✓ All files already exist for {display_name}:")
        print(f" Model: {processed_path}")
        print(f" Test data: {test_data_path}")
        print("\nNothing to do!")
        return
    # Download and convert if needed
    if not original_path.exists() and not processed_path.exists():
        print(f"\nStep 1: Downloading and converting {display_name} model...")
        download_and_convert_model(model_name, original_path)
    # Process model if needed
    if not processed_path.exists():
        print("\nStep 2: Processing model...")
        process_model(original_path, processed_path, target_opset=16)
        # Clean up original if we have the processed version
        if original_path.exists():
            original_path.unlink()
    # Generate test data if needed
    if not test_data_path.exists():
        print("\nStep 3: Generating test data...")
        generate_test_data(processed_path, test_data_path, model_name)
    print("\n" + "=" * 60)
    print(f"{display_name} model preparation completed!")
    print(f" Model: {processed_path}")
    print(f" Test data: {test_data_path}")
    print("=" * 60)
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n⚠ Operation cancelled by user.")
        sys.exit(1)
    except Exception as e:
        print(f"\n✗ Error: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


@@ -1,209 +0,0 @@
extern crate alloc;
use burn::module::{Initializer, Param};
use burn::prelude::*;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Import model info generated by build.rs (includes the albert_model module)
include!(concat!(env!("OUT_DIR"), "/model_info.rs"));
// Use the albert_model module from model_info.rs
use albert_model::Model;
#[derive(Debug, Module)]
struct TestData<B: Backend> {
input_ids: Param<Tensor<B, 2, Int>>,
attention_mask: Param<Tensor<B, 2, Int>>,
token_type_ids: Param<Tensor<B, 2, Int>>,
last_hidden_state: Param<Tensor<B, 3>>,
pooler_output: Param<Tensor<B, 2>>,
}
impl<B: Backend> TestData<B> {
fn new(device: &B::Device) -> Self {
use burn::module::ParamId;
// Initialize with correct shapes matching the test data
// ALBERT base uses sequence_length=128, hidden_size=768
// Note: Initializer only works for float tensors, Int tensors need manual init
Self {
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
token_type_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
last_hidden_state: Initializer::Zeros.init([1, 128, 768], device),
pooler_output: Initializer::Zeros.init([1, 768], device),
}
}
}
fn get_model_display_name(model_name: &str) -> &str {
match model_name {
"albert-base-v2" => "ALBERT Base v2",
_ => model_name,
}
}
fn main() {
// MODEL_NAME is set at build time from ALBERT_MODEL env var
let model_name = MODEL_NAME;
let display_name = get_model_display_name(model_name);
println!("========================================");
println!("{} Burn Model Test", display_name);
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model and test data.");
eprintln!("Example: uv run get_model.py --model {}", model_name);
std::process::exit(1);
}
// Check if model files exist for this specific model
let model_file = artifacts_dir.join(format!("{}_opset16.onnx", model_name));
let test_data_file = artifacts_dir.join(format!("{}_test_data.pt", model_name));
if !model_file.exists() || !test_data_file.exists() {
eprintln!("Error: Model files not found for {}!", display_name);
eprintln!("Please run: uv run get_model.py --model {}", model_name);
eprintln!();
eprintln!("Available models:");
eprintln!(" - albert-base-v2");
std::process::exit(1);
}
// Initialize the model (without weights for now)
println!("Initializing {} model...", display_name);
let start = Instant::now();
let device = Default::default();
let model: Model<MyBackend> = Model::default();
let init_time = start.elapsed();
println!(" Model initialized in {:.2?}", init_time);
// Save model structure to file
let model_txt_path = artifacts_dir.join(format!("{}_model.txt", model_name));
println!(
"\nSaving model structure to {}...",
model_txt_path.display()
);
let model_str = format!("{}", model);
std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file");
println!(" Model structure saved");
// Load test data from PyTorch file
println!("\nLoading test data from {}...", test_data_file.display());
let start = Instant::now();
let mut test_data = TestData::<MyBackend>::new(&device);
let mut store = PytorchStore::from_file(&test_data_file);
test_data.load_from(&mut store).expect("Failed to load test data");
let load_time = start.elapsed();
println!(" Data loaded in {:.2?}", load_time);
// Get the input tensors from test data
let input_ids = test_data.input_ids.val();
let attention_mask = test_data.attention_mask.val();
let token_type_ids = test_data.token_type_ids.val();
println!(" Loaded input tensors:");
println!(" input_ids shape: {:?}", input_ids.shape().dims);
println!(
" attention_mask shape: {:?}",
attention_mask.shape().dims
);
println!(
" token_type_ids shape: {:?}",
token_type_ids.shape().dims
);
// Get the reference outputs from test data
let reference_last_hidden = test_data.last_hidden_state.val();
let reference_pooler = test_data.pooler_output.val();
println!(" Loaded reference outputs:");
println!(
" last_hidden_state shape: {:?}",
reference_last_hidden.shape().dims
);
println!(
" pooler_output shape: {:?}",
reference_pooler.shape().dims
);
// Run inference with the loaded input
println!("\nRunning model inference with test input...");
let start = Instant::now();
let outputs = model.forward(input_ids, attention_mask, token_type_ids);
let inference_time = start.elapsed();
println!(" Inference completed in {:.2?}", inference_time);
// ALBERT models typically return (last_hidden_state, pooler_output)
// The outputs tuple should have 2 elements
println!("\n Model output shapes:");
println!(
" output 0 (last_hidden_state): {:?}",
outputs.0.shape().dims
);
println!(" output 1 (pooler_output): {:?}", outputs.1.shape().dims);
// Compare outputs
println!("\nComparing model outputs with reference data...");
// Compare last_hidden_state
println!(" Checking last_hidden_state...");
if outputs
.0
.clone()
.all_close(reference_last_hidden.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ last_hidden_state differs from reference data!");
let diff = outputs.0.clone() - reference_last_hidden.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
}
// Compare pooler_output
println!(" Checking pooler_output...");
if outputs
.1
.clone()
.all_close(reference_pooler.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ pooler_output matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ pooler_output differs from reference data!");
let diff = outputs.1.clone() - reference_pooler.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
}
println!("\n========================================");
println!("Model test completed!");
println!("========================================");
}


@@ -1,26 +0,0 @@
[package]
name = "burn-onnx-model-checks-all-minilm-l6-v2"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["tch"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,68 +0,0 @@
# all-MiniLM-L6-v2 Model Check
This crate provides testing for the all-MiniLM-L6-v2 sentence transformer model with Burn.
## Model
- `all-MiniLM-L6-v2` - Sentence transformer model from HuggingFace
(https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
## Usage
### 1. Download and prepare the model
```bash
# Using Python directly
python get_model.py
# Or using uv
uv run get_model.py
```
### 2. Build and run the model test
```bash
# Build the model
cargo build
# Run the test
cargo run --release
```
## Directory Structure
```
all-minilm-l6-v2/
├── artifacts/ # Downloaded ONNX model and test data
│ ├── all-minilm-l6-v2_opset16.onnx
│ ├── test_data.pt
│ └── model-python.txt
├── src/
│ └── main.rs # Test runner
├── build.rs # Build script that generates model code
├── get_model.py # Model download and preparation script
├── Cargo.toml
└── README.md
```
## Model Architecture
all-MiniLM-L6-v2 is a sentence transformer model based on Microsoft's MiniLM architecture. It maps
sentences and paragraphs to a 384-dimensional dense vector space and is commonly used for semantic
search, clustering, and similarity tasks.
- **Inputs**:
  - `input_ids`: Token IDs (shape: [batch_size, sequence_length])
  - `attention_mask`: Attention mask (shape: [batch_size, sequence_length])
  - `token_type_ids`: Token type IDs (shape: [batch_size, sequence_length])
- **Outputs**:
  - `last_hidden_state`: Sequence of hidden states (shape: [batch_size, sequence_length, 384])
## Notes
- The default sequence length is 128 tokens
- The model has a hidden size of 384
- The model uses ONNX opset 16
- Test data is generated from random inputs with a fixed seed (seed=42) for reproducibility
- For sentence embeddings, you typically use mean pooling on the last_hidden_state
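
The mean pooling mentioned in the last note can be sketched as below. This is a plain-Python illustration on nested lists, not the tensor code a real pipeline would use; the function name `mean_pool` and the `eps` guard value are assumptions made for the example.

```python
def mean_pool(last_hidden_state, attention_mask, eps=1e-9):
    """Average token vectors per sequence, ignoring masked-out positions.

    last_hidden_state: nested lists shaped [batch][sequence][hidden]
    attention_mask:    nested lists shaped [batch][sequence] with 0/1 entries
    """
    pooled = []
    for tokens, mask in zip(last_hidden_state, attention_mask):
        hidden_size = len(tokens[0])
        sums = [0.0] * hidden_size
        for vector, keep in zip(tokens, mask):
            for i, value in enumerate(vector):
                sums[i] += value * keep  # padding positions contribute zero
        # Guard against division by zero when every position is masked.
        count = max(float(sum(mask)), eps)
        pooled.append([s / count for s in sums])
    return pooled
```

With a mask of `[1, 1, 0]`, only the first two token vectors are averaged, so padding does not distort the sentence embedding.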


@@ -1,32 +0,0 @@
use burn_onnx::ModelGen;
use std::path::Path;
fn main() {
    let onnx_path = "artifacts/all-minilm-l6-v2_opset16.onnx";
    let test_data_path = "artifacts/test_data.pt";
    // Tell Cargo to only rebuild if these files change
    println!("cargo:rerun-if-changed={}", onnx_path);
    println!("cargo:rerun-if-changed={}", test_data_path);
    println!("cargo:rerun-if-changed=build.rs");
    // Check if the ONNX model file exists
    if !Path::new(onnx_path).exists() {
        eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
        eprintln!();
        eprintln!("Please run the following command to download and prepare the model:");
        eprintln!(" python get_model.py");
        eprintln!();
        eprintln!("Or if you prefer using uv:");
        eprintln!(" uv run get_model.py");
        eprintln!();
        eprintln!("This will download the all-MiniLM-L6-v2 model and convert it to ONNX format.");
        std::process::exit(1);
    }
    // Generate the model code from the ONNX file
    ModelGen::new()
        .input(onnx_path)
        .out_dir("model/")
        .run_from_script();
}


@@ -1,316 +0,0 @@
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "onnx==1.19.0",
# "onnxruntime>=1.22.0",
# "huggingface-hub>=0.20.0",
# "numpy",
# "torch",
# ]
# ///
import os
import sys
import onnx
from onnx import shape_inference, version_converter
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
def download_minilm_model(output_path):
    """Download all-MiniLM-L6-v2 model from Hugging Face."""
    print("Downloading all-MiniLM-L6-v2 model from Hugging Face...")
    # Download the ONNX model from Hugging Face
    model_path = hf_hub_download(
        repo_id="Xenova/all-MiniLM-L6-v2",
        filename="onnx/model.onnx",
        cache_dir="./artifacts/cache",
    )
    # Copy to artifacts
    import shutil
    shutil.copy(model_path, output_path)
    if not output_path.exists():
        raise FileNotFoundError(f"Failed to download ONNX file to {output_path}")
    print(f"✓ Model downloaded to: {output_path}")
def process_model(input_path, output_path, target_opset=16):
    """Load, upgrade opset, and apply shape inference to model."""
    print(f"Loading model from {input_path}...")
    model = onnx.load(input_path)
    # Check and upgrade opset if needed
    current_opset = model.opset_import[0].version
    if current_opset < target_opset:
        print(f"Upgrading opset from {current_opset} to {target_opset}...")
        model = version_converter.convert_version(model, target_opset)
    # Apply shape inference
    print("Applying shape inference...")
    model = shape_inference.infer_shapes(model)
    # Save processed model
    onnx.save(model, output_path)
    print(f"✓ Processed model saved to: {output_path}")
    return model
def get_input_info(model):
    """Extract input information from ONNX model."""
    inputs = []
    for input_info in model.graph.input:
        shape = []
        for dim in input_info.type.tensor_type.shape.dim:
            if dim.HasField("dim_value"):
                shape.append(dim.dim_value)
            else:
                # Use proper defaults for sentence transformers
                if (
                    "input_ids" in input_info.name
                    or "attention_mask" in input_info.name
                    or "token_type_ids" in input_info.name
                ):
                    # Default sequence length for all text inputs
                    shape.append(1 if len(shape) == 0 else 128)
                else:
                    shape.append(1)  # Default to 1 for other dynamic dimensions
        inputs.append(
            {
                "name": input_info.name,
                "shape": shape,
                "dtype": input_info.type.tensor_type.elem_type,
            }
        )
    return inputs
def generate_test_data(model_path, output_dir):
"""Generate test input/output data and save as PyTorch tensors."""
import torch
import onnxruntime as ort
print("\nGenerating test data...")
# Load model to get input shapes
model = onnx.load(model_path)
input_infos = get_input_info(model)
print(f" Model has {len(input_infos)} inputs:")
for info in input_infos:
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
# Create reproducible test inputs
np.random.seed(42)
test_inputs = {}
for info in input_infos:
if info["dtype"] == onnx.TensorProto.INT64:
# Handle different integer inputs appropriately
if "attention_mask" in info["name"]:
# Attention mask should be 1 for valid tokens, 0 for padding
# For testing, use all 1s (all valid tokens)
test_input = np.ones(info["shape"], dtype=np.int64)
elif "token_type_ids" in info["name"]:
# Token type IDs should be 0 or 1 (typically 0 for single-sequence tasks)
test_input = np.zeros(info["shape"], dtype=np.int64)
else:
# For input_ids and any other INT64 input, use random integers in the
# BERT vocabulary range (30522 tokens)
test_input = np.random.randint(0, 30522, size=info["shape"], dtype=np.int64)
else:
# For float inputs, use random floats
test_input = np.random.rand(*info["shape"]).astype(np.float32)
test_inputs[info["name"]] = test_input
# Run inference to get output
session = ort.InferenceSession(model_path)
outputs = session.run(None, test_inputs)
# Save in a format that's easier to load in Rust
# For all-MiniLM-L6-v2, we expect:
# - Inputs: input_ids, attention_mask, token_type_ids
# - Outputs: last_hidden_state (3D)
# Compute mean pooled embeddings (what sentence-transformers uses)
last_hidden_state = outputs[0]
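# Fall back to an all-ones mask if the model does not expose an input named "attention_mask"
attention_mask_np = test_inputs.get("attention_mask", np.ones(last_hidden_state.shape[:2], dtype=np.int64))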
# Mean pooling - take attention mask into account for correct averaging
input_mask_expanded = np.expand_dims(attention_mask_np, axis=-1).astype(np.float32)
sum_embeddings = np.sum(last_hidden_state * input_mask_expanded, axis=1)
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
pooled_embeddings = sum_embeddings / sum_mask
# Create a more structured format for Rust
test_data = {
"input_ids": torch.from_numpy(
test_inputs.get(
"input_ids", test_inputs.get("inputs.0", list(test_inputs.values())[0])
)
),
"attention_mask": torch.from_numpy(
test_inputs.get(
"attention_mask",
test_inputs.get("inputs.1", list(test_inputs.values())[1]),
)
),
"token_type_ids": torch.from_numpy(
test_inputs.get(
"token_type_ids",
test_inputs.get("inputs.2", np.zeros((1, 128), dtype=np.int64)),
)
),
"last_hidden_state": torch.from_numpy(outputs[0]),
"pooled_embeddings": torch.from_numpy(pooled_embeddings),
}
test_data_path = Path(output_dir) / "test_data.pt"
torch.save(test_data, test_data_path)
print(f" ✓ Test data saved to: {test_data_path}")
print(" Input shapes:")
print(f" input_ids: {test_data['input_ids'].shape}")
print(f" attention_mask: {test_data['attention_mask'].shape}")
print(f" token_type_ids: {test_data['token_type_ids'].shape}")
print(" Output shapes:")
print(f" last_hidden_state: {test_data['last_hidden_state'].shape}")
print(f" pooled_embeddings: {test_data['pooled_embeddings'].shape}")
def save_model_info(model_path, output_dir):
"""Save model structure information to a text file."""
print("\nSaving model information...")
model = onnx.load(model_path)
info_path = Path(output_dir) / "model-python.txt"
with open(info_path, "w") as f:
f.write("all-MiniLM-L6-v2 Model Information\n")
f.write("=" * 60 + "\n\n")
# Input information
f.write("Inputs:\n")
for input_info in model.graph.input:
f.write(f" - {input_info.name}\n")
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
)
# Output information
f.write("\nOutputs:\n")
for output_info in model.graph.output:
f.write(f" - {output_info.name}\n")
shape = []
for dim in output_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
)
# Model statistics
f.write("\nModel Statistics:\n")
f.write(f" Opset version: {model.opset_import[0].version}\n")
f.write(f" Number of nodes: {len(model.graph.node)}\n")
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
# Node types summary
node_types = {}
for node in model.graph.node:
op_type = node.op_type
node_types[op_type] = node_types.get(op_type, 0) + 1
f.write("\nNode types:\n")
for op_type, count in sorted(node_types.items()):
f.write(f" {op_type}: {count}\n")
print(f" ✓ Model info saved to: {info_path}")
def main():
print("=" * 60)
print("all-MiniLM-L6-v2 Model Preparation Tool")
print("=" * 60)
# Setup paths
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
original_path = artifacts_dir / "all-minilm-l6-v2.onnx"
processed_path = artifacts_dir / "all-minilm-l6-v2_opset16.onnx"
test_data_path = artifacts_dir / "test_data.pt"
model_info_path = artifacts_dir / "model-python.txt"
# Check if we already have everything
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
print("\n✓ All files already exist:")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("\nNothing to do!")
return
# Download model if needed
if not original_path.exists() and not processed_path.exists():
print("\nStep 1: Downloading all-MiniLM-L6-v2 model...")
download_minilm_model(original_path)
# Process model if needed
if not processed_path.exists():
print("\nStep 2: Processing model...")
process_model(original_path, processed_path, target_opset=16)
# Clean up original if we have the processed version
if original_path.exists() and processed_path.exists():
original_path.unlink()
# Generate test data if needed
if not test_data_path.exists():
print("\nStep 3: Generating test data...")
generate_test_data(processed_path, artifacts_dir)
# Save model info if needed
if not model_info_path.exists():
print("\nStep 4: Saving model information...")
save_model_info(processed_path, artifacts_dir)
print("\n" + "=" * 60)
print("✓ all-MiniLM-L6-v2 model preparation completed!")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("=" * 60)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n⚠ Operation cancelled by user.")
sys.exit(1)
except Exception as e:
print(f"\n✗ Error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
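
The mean pooling step in `generate_test_data` averages token embeddings while ignoring padded positions. A minimal, dependency-free sketch of the same arithmetic on nested lists follows. `mean_pool` and the toy values are illustrative and are not part of the original script, which operates on NumPy arrays.

```python
def mean_pool(last_hidden_state, attention_mask, eps=1e-9):
    """Masked mean over the sequence axis.

    last_hidden_state: [batch][seq][hidden] nested lists of floats
    attention_mask:    [batch][seq] nested lists of 0/1 ints
    Returns [batch][hidden].
    """
    pooled = []
    for tokens, mask in zip(last_hidden_state, attention_mask):
        hidden_size = len(tokens[0])
        sums = [0.0] * hidden_size
        for vec, m in zip(tokens, mask):
            for j, value in enumerate(vec):
                # Padded positions (m == 0) contribute nothing to the sum
                sums[j] += value * m
        # Clamp the token count so an all-padding row does not divide by zero
        count = max(float(sum(mask)), eps)
        pooled.append([s / count for s in sums])
    return pooled


# One sentence, three positions, the last one is padding
hidden = [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]
mask = [[1, 1, 0]]
result = mean_pool(hidden, mask)
print(result)  # [[2.0, 3.0]]
```

The padded vector `[100.0, 100.0]` does not affect the result, which is why the mask is applied before summing rather than averaging over the full sequence length.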


@@ -1,284 +0,0 @@
extern crate alloc;
use burn::module::{Initializer, Param};
use burn::prelude::*;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Import the generated model code as a module
pub mod all_minilm_l6_v2 {
include!(concat!(
env!("OUT_DIR"),
"/model/all-minilm-l6-v2_opset16.rs"
));
}
#[derive(Debug, Module)]
struct TestData<B: Backend> {
input_ids: Param<Tensor<B, 2, Int>>,
attention_mask: Param<Tensor<B, 2, Int>>,
token_type_ids: Param<Tensor<B, 2, Int>>,
last_hidden_state: Param<Tensor<B, 3>>,
pooled_embeddings: Param<Tensor<B, 2>>,
}
impl<B: Backend> TestData<B> {
fn new(device: &B::Device) -> Self {
use burn::module::ParamId;
// Initialize with correct shapes matching the test data
// Note: Initializer only works for float tensors, Int tensors need manual init
Self {
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
token_type_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 128], device)),
last_hidden_state: Initializer::Zeros.init([1, 128, 384], device),
pooled_embeddings: Initializer::Zeros.init([1, 384], device),
}
}
}
/// Apply mean pooling to get sentence embeddings
fn mean_pool<B: Backend>(
last_hidden_state: Tensor<B, 3>,
attention_mask: Tensor<B, 2, Int>,
) -> Tensor<B, 2> {
// Convert attention_mask to float and expand dimensions to match hidden_state
let attention_mask_float = attention_mask.float().unsqueeze_dim::<3>(2);
// Multiply hidden states by attention mask
let masked_embeddings = last_hidden_state * attention_mask_float.clone();
// Sum along sequence dimension (dim 1)
let sum_embeddings = masked_embeddings.sum_dim(1);
// Sum attention mask to get count of non-padding tokens
let sum_mask = attention_mask_float.sum_dim(1).clamp_min(1e-9);
// Divide to get mean - result is [batch, 1, hidden]
let pooled = sum_embeddings / sum_mask;
// Get the shape to reshape to [batch, hidden]
let shape = pooled.shape();
let batch_size = shape.dims[0];
let hidden_size = shape.dims[2];
pooled.reshape([batch_size, hidden_size])
}
fn main() {
println!("========================================");
println!("all-MiniLM-L6-v2 Burn Model Test");
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model and test data.");
std::process::exit(1);
}
// Initialize the model (using default which includes the converted weights)
println!("Initializing all-MiniLM-L6-v2 model...");
let start = Instant::now();
let device = Default::default();
let model: all_minilm_l6_v2::Model<MyBackend> = all_minilm_l6_v2::Model::default();
let init_time = start.elapsed();
println!(" Model initialized in {:.2?}", init_time);
// Save model structure to file
println!("\nSaving model structure to artifacts/model.txt...");
let model_str = format!("{}", model);
std::fs::write("artifacts/model.txt", &model_str)
.expect("Failed to write model structure to file");
println!(" Model structure saved");
// Load test data from PyTorch file
println!("\nLoading test data from artifacts/test_data.pt...");
let start = Instant::now();
let mut test_data = TestData::<MyBackend>::new(&device);
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
test_data.load_from(&mut store).expect("Failed to load test data");
let load_time = start.elapsed();
println!(" Data loaded in {:.2?}", load_time);
// Get the input tensors from test data
let input_ids = test_data.input_ids.val();
let attention_mask = test_data.attention_mask.val();
let token_type_ids = test_data.token_type_ids.val();
let input_ids_shape = input_ids.shape();
let attention_mask_shape = attention_mask.shape();
let token_type_ids_shape = token_type_ids.shape();
println!(" Loaded input_ids with shape: {:?}", input_ids_shape.dims);
println!(
" Loaded attention_mask with shape: {:?}",
attention_mask_shape.dims
);
println!(
" Loaded token_type_ids with shape: {:?}",
token_type_ids_shape.dims
);
// Get the reference outputs from test data
let reference_last_hidden_state = test_data.last_hidden_state.val();
let reference_pooled_embeddings = test_data.pooled_embeddings.val();
let ref_last_hidden_shape = reference_last_hidden_state.shape();
let ref_pooled_shape = reference_pooled_embeddings.shape();
println!(
" Loaded reference last_hidden_state with shape: {:?}",
ref_last_hidden_shape.dims
);
println!(
" Loaded reference pooled_embeddings with shape: {:?}",
ref_pooled_shape.dims
);
// Run inference with the loaded input
println!("\nRunning model inference with test input...");
let start = Instant::now();
let last_hidden_state = model.forward(
input_ids.clone(),
attention_mask.clone(),
token_type_ids.clone(),
);
let inference_time = start.elapsed();
println!(" Inference completed in {:.2?}", inference_time);
// Compute pooled embeddings (mean pooling)
println!("\nComputing pooled embeddings (mean pooling)...");
let start = Instant::now();
let pooled_embeddings = mean_pool(last_hidden_state.clone(), attention_mask.clone());
let pooling_time = start.elapsed();
println!(" Pooling completed in {:.2?}", pooling_time);
// Display output shapes
let last_hidden_shape = last_hidden_state.shape();
let pooled_shape = pooled_embeddings.shape();
println!("\n Model output shapes:");
println!(" last_hidden_state: {:?}", last_hidden_shape.dims);
println!(" pooled_embeddings: {:?}", pooled_shape.dims);
// Verify expected output shapes match
if last_hidden_shape.dims == ref_last_hidden_shape.dims {
println!(
" ✓ last_hidden_state shape matches expected: {:?}",
ref_last_hidden_shape.dims
);
} else {
println!(
" ⚠ Warning: Expected last_hidden_state shape {:?}, got {:?}",
ref_last_hidden_shape.dims, last_hidden_shape.dims
);
}
if pooled_shape.dims == ref_pooled_shape.dims {
println!(
" ✓ pooled_embeddings shape matches expected: {:?}",
ref_pooled_shape.dims
);
} else {
println!(
" ⚠ Warning: Expected pooled_embeddings shape {:?}, got {:?}",
ref_pooled_shape.dims, pooled_shape.dims
);
}
// Compare outputs
println!("\nComparing model outputs with reference data...");
// Check if last_hidden_state is close
println!("\n Checking last_hidden_state:");
if last_hidden_state.clone().all_close(
reference_last_hidden_state.clone(),
Some(1e-4),
Some(1e-4),
) {
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ last_hidden_state differs from reference data!");
// Calculate and display the difference statistics
let diff = last_hidden_state.clone() - reference_last_hidden_state.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = last_hidden_state.clone().flatten::<1>(0, 2);
let reference_flat = reference_last_hidden_state.clone().flatten::<1>(0, 2);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
// Check if pooled_embeddings is close
println!("\n Checking pooled_embeddings:");
if pooled_embeddings.clone().all_close(
reference_pooled_embeddings.clone(),
Some(1e-4),
Some(1e-4),
) {
println!(" ✓ pooled_embeddings matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ pooled_embeddings differs from reference data!");
// Calculate and display the difference statistics
let diff = pooled_embeddings.clone() - reference_pooled_embeddings.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = pooled_embeddings.clone().flatten::<1>(0, 1);
let reference_flat = reference_pooled_embeddings.clone().flatten::<1>(0, 1);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
println!("\n========================================");
println!("Model test completed!");
println!("========================================");
}
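
The comparisons in this file pass `Some(1e-4)` for both tolerances. As a rough sketch, and assuming the usual NumPy-style criterion `|a - b| <= atol + rtol * |b|` (an assumption about how `all_close` behaves, not something this file states), the check can be written in plain Python. `all_close` below is a hypothetical helper over flat lists, not the Burn method.

```python
def all_close(actual, reference, rtol=1e-4, atol=1e-4):
    """Return True if every element satisfies |a - r| <= atol + rtol * |r|."""
    if len(actual) != len(reference):
        return False
    return all(abs(a - r) <= atol + rtol * abs(r) for a, r in zip(actual, reference))


reference = [0.5, -1.25, 10.0]
print(all_close([0.50005, -1.25, 10.0005], reference))  # True
print(all_close([0.5, -1.25, 10.01], reference))        # False
```

Under this criterion the allowed error grows with the magnitude of the reference value, so a deviation of 5e-4 is accepted at 10.0 but would be rejected near zero.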


@@ -1,26 +0,0 @@
[package]
name = "burn-onnx-model-checks-clip-vit-b-32-text"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["tch"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,33 +0,0 @@
use burn_onnx::ModelGen;
use std::path::Path;
fn main() {
let onnx_path = "artifacts/clip-vit-b-32-text_opset16.onnx";
let test_data_path = "artifacts/test_data.pt";
// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed={}", test_data_path);
println!("cargo:rerun-if-changed=build.rs");
// Check if the ONNX model file exists
if !Path::new(onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!("Please run the following command to download and prepare the model:");
eprintln!(" python get_model.py");
eprintln!();
eprintln!("Or if you prefer using uv:");
eprintln!(" uv run get_model.py");
eprintln!();
eprintln!("This will download the CLIP ViT-B-32-text ONNX model and upgrade it to opset 16.");
std::process::exit(1);
}
// Generate the model code from the ONNX file
ModelGen::new()
.input(onnx_path)
.out_dir("model/")
.run_from_script();
}


@@ -1,288 +0,0 @@
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "onnx==1.19.0",
# "onnxruntime>=1.22.0",
# "huggingface-hub>=0.20.0",
# "numpy",
# "torch",
# ]
# ///
import sys
import onnx
from onnx import shape_inference, version_converter
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
def download_clip_model(output_path):
"""Download CLIP ViT-B-32-text model from Hugging Face."""
print("Downloading CLIP ViT-B-32-text model from Hugging Face...")
# Download the ONNX model from Hugging Face
model_path = hf_hub_download(
repo_id="Qdrant/clip-ViT-B-32-text",
filename="model.onnx",
cache_dir="./artifacts/cache",
)
# Copy to artifacts
import shutil
shutil.copy(model_path, output_path)
if not output_path.exists():
raise FileNotFoundError(f"Failed to download ONNX file to {output_path}")
print(f"✓ Model downloaded to: {output_path}")
def process_model(input_path, output_path, target_opset=16):
"""Load, upgrade opset, and apply shape inference to model."""
print(f"Loading model from {input_path}...")
model = onnx.load(input_path)
# Check and upgrade opset if needed
current_opset = model.opset_import[0].version
if current_opset < target_opset:
print(f"Upgrading opset from {current_opset} to {target_opset}...")
model = version_converter.convert_version(model, target_opset)
# Apply shape inference
print("Applying shape inference...")
model = shape_inference.infer_shapes(model)
# Save processed model
onnx.save(model, output_path)
print(f"✓ Processed model saved to: {output_path}")
return model
def get_input_info(model):
"""Extract input information from ONNX model."""
inputs = []
for input_info in model.graph.input:
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
# Use proper defaults for CLIP model
if (
"input_ids" in input_info.name
or "attention_mask" in input_info.name
):
# CLIP uses sequence length of 77
shape.append(1 if len(shape) == 0 else 77)
else:
shape.append(1) # Default to 1 for other dynamic dimensions
inputs.append(
{
"name": input_info.name,
"shape": shape,
"dtype": input_info.type.tensor_type.elem_type,
}
)
return inputs
def generate_test_data(model_path, output_dir):
"""Generate test input/output data and save as PyTorch tensors."""
import torch
import onnxruntime as ort
print("\nGenerating test data...")
# Load model to get input shapes
model = onnx.load(model_path)
input_infos = get_input_info(model)
print(f" Model has {len(input_infos)} inputs:")
for info in input_infos:
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
# Create reproducible test inputs
np.random.seed(42)
test_inputs = {}
for info in input_infos:
if info["dtype"] == onnx.TensorProto.INT64:
# For INT64 inputs (like input_ids), use random integers
test_input = np.random.randint(0, 1000, size=info["shape"], dtype=np.int64)
else:
# For float inputs, use random floats
test_input = np.random.rand(*info["shape"]).astype(np.float32)
test_inputs[info["name"]] = test_input
# Run inference to get output
session = ort.InferenceSession(model_path)
outputs = session.run(None, test_inputs)
# Save in a format that's easier to load in Rust
# For CLIP, we expect:
# - Inputs: input_ids, attention_mask
# - Outputs: text_embeds (2D), last_hidden_state (3D)
# Create a more structured format for Rust
test_data = {
"input_ids": torch.from_numpy(
test_inputs.get(
"input_ids", test_inputs.get("inputs.0", list(test_inputs.values())[0])
)
),
"attention_mask": torch.from_numpy(
test_inputs.get(
"attention_mask",
test_inputs.get("inputs.1", list(test_inputs.values())[1]),
)
),
"text_embeds": torch.from_numpy(outputs[0]),
"last_hidden_state": torch.from_numpy(outputs[1])
if len(outputs) > 1
else torch.zeros(1, 77, 512),
}
test_data_path = Path(output_dir) / "test_data.pt"
torch.save(test_data, test_data_path)
print(f" ✓ Test data saved to: {test_data_path}")
print(" Input shapes:")
print(f" input_ids: {test_data['input_ids'].shape}")
print(f" attention_mask: {test_data['attention_mask'].shape}")
print(" Output shapes:")
print(f" text_embeds: {test_data['text_embeds'].shape}")
print(f" last_hidden_state: {test_data['last_hidden_state'].shape}")
def save_model_info(model_path, output_dir):
"""Save model structure information to a text file."""
print("\nSaving model information...")
model = onnx.load(model_path)
info_path = Path(output_dir) / "model-python.txt"
with open(info_path, "w") as f:
f.write("CLIP ViT-B-32-text Model Information\n")
f.write("=" * 60 + "\n\n")
# Input information
f.write("Inputs:\n")
for input_info in model.graph.input:
f.write(f" - {input_info.name}\n")
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
)
# Output information
f.write("\nOutputs:\n")
for output_info in model.graph.output:
f.write(f" - {output_info.name}\n")
shape = []
for dim in output_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
)
# Model statistics
f.write("\nModel Statistics:\n")
f.write(f" Opset version: {model.opset_import[0].version}\n")
f.write(f" Number of nodes: {len(model.graph.node)}\n")
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
# Node types summary
node_types = {}
for node in model.graph.node:
op_type = node.op_type
node_types[op_type] = node_types.get(op_type, 0) + 1
f.write("\nNode types:\n")
for op_type, count in sorted(node_types.items()):
f.write(f" {op_type}: {count}\n")
print(f" ✓ Model info saved to: {info_path}")
def main():
print("=" * 60)
print("CLIP ViT-B-32-text Model Preparation Tool")
print("=" * 60)
# Setup paths
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
original_path = artifacts_dir / "clip-vit-b-32-text.onnx"
processed_path = artifacts_dir / "clip-vit-b-32-text_opset16.onnx"
test_data_path = artifacts_dir / "test_data.pt"
model_info_path = artifacts_dir / "model-python.txt"
# Check if we already have everything
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
print("\n✓ All files already exist:")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("\nNothing to do!")
return
# Download model if needed
if not original_path.exists() and not processed_path.exists():
print("\nStep 1: Downloading CLIP model...")
download_clip_model(original_path)
# Process model if needed
if not processed_path.exists():
print("\nStep 2: Processing model...")
process_model(original_path, processed_path, target_opset=16)
# Clean up original if we have the processed version
if original_path.exists() and processed_path.exists():
original_path.unlink()
# Generate test data if needed
if not test_data_path.exists():
print("\nStep 3: Generating test data...")
generate_test_data(processed_path, artifacts_dir)
# Save model info if needed
if not model_info_path.exists():
print("\nStep 4: Saving model information...")
save_model_info(processed_path, artifacts_dir)
print("\n" + "=" * 60)
print("✓ CLIP model preparation completed!")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("=" * 60)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n⚠ Operation cancelled by user.")
sys.exit(1)
except Exception as e:
print(f"\n✗ Error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)
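
`get_input_info` substitutes concrete sizes for dynamic ONNX dimensions so that random test inputs can be built. A simplified sketch of that rule follows, using `None` to stand for a dynamic dimension instead of the protobuf `dim` objects. `resolve_shape`, its parameters, and the `pixel_values` input name are hypothetical and chosen only for illustration.

```python
def resolve_shape(name, dims, seq_len=77, text_inputs=("input_ids", "attention_mask")):
    """Replace dynamic dims (None) with defaults.

    For text inputs the first dimension (batch) becomes 1 and any later
    dynamic dimension becomes seq_len; every other dynamic dim becomes 1.
    """
    is_text = any(key in name for key in text_inputs)
    shape = []
    for dim in dims:
        if dim is not None:
            # Static dimension: keep the value recorded in the model
            shape.append(dim)
        elif is_text:
            shape.append(1 if len(shape) == 0 else seq_len)
        else:
            shape.append(1)
    return shape


print(resolve_shape("input_ids", [None, None]))            # [1, 77]
print(resolve_shape("pixel_values", [None, 3, 224, 224]))  # [1, 3, 224, 224]
```

Passing `seq_len=128` reproduces the default used by the sentence-transformer variant of the script.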


@@ -1,236 +0,0 @@
extern crate alloc;
use burn::module::{Initializer, Param};
use burn::prelude::*;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Import the generated model code as a module
pub mod clip_vit_b_32_text {
include!(concat!(
env!("OUT_DIR"),
"/model/clip-vit-b-32-text_opset16.rs"
));
}
#[derive(Debug, Module)]
struct TestData<B: Backend> {
input_ids: Param<Tensor<B, 2, Int>>,
attention_mask: Param<Tensor<B, 2, Int>>,
text_embeds: Param<Tensor<B, 2>>,
last_hidden_state: Param<Tensor<B, 3>>,
}
impl<B: Backend> TestData<B> {
fn new(device: &B::Device) -> Self {
use burn::module::ParamId;
// CLIP ViT-B-32 text: sequence_length=77, embed_dim=512
// Note: Initializer only works for float tensors, Int tensors need manual init
Self {
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 77], device)),
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 77], device)),
text_embeds: Initializer::Zeros.init([1, 512], device),
last_hidden_state: Initializer::Zeros.init([1, 77, 512], device),
}
}
}
fn main() {
println!("========================================");
println!("CLIP ViT-B-32-text Burn Model Test");
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model and test data.");
std::process::exit(1);
}
// Initialize the model (using default which includes the converted weights)
println!("Initializing CLIP model...");
let start = Instant::now();
let device = Default::default();
let model: clip_vit_b_32_text::Model<MyBackend> = clip_vit_b_32_text::Model::default();
let init_time = start.elapsed();
println!(" Model initialized in {:.2?}", init_time);
// Save model structure to file
println!("\nSaving model structure to artifacts/model.txt...");
let model_str = format!("{}", model);
std::fs::write("artifacts/model.txt", &model_str)
.expect("Failed to write model structure to file");
println!(" Model structure saved");
// Load test data from PyTorch file
println!("\nLoading test data from artifacts/test_data.pt...");
let start = Instant::now();
let mut test_data = TestData::<MyBackend>::new(&device);
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
test_data.load_from(&mut store).expect("Failed to load test data");
let load_time = start.elapsed();
println!(" Data loaded in {:.2?}", load_time);
// Get the input tensors from test data
let input_ids = test_data.input_ids.val();
let attention_mask = test_data.attention_mask.val();
let input_ids_shape = input_ids.shape();
let attention_mask_shape = attention_mask.shape();
println!(" Loaded input_ids with shape: {:?}", input_ids_shape.dims);
println!(
" Loaded attention_mask with shape: {:?}",
attention_mask_shape.dims
);
// Get the reference outputs from test data
let reference_text_embeds = test_data.text_embeds.val();
let reference_last_hidden_state = test_data.last_hidden_state.val();
let ref_text_embeds_shape = reference_text_embeds.shape();
let ref_last_hidden_shape = reference_last_hidden_state.shape();
println!(
" Loaded reference text_embeds with shape: {:?}",
ref_text_embeds_shape.dims
);
println!(
" Loaded reference last_hidden_state with shape: {:?}",
ref_last_hidden_shape.dims
);
// Run inference with the loaded input
println!("\nRunning model inference with test input...");
let start = Instant::now();
let (text_embeds, last_hidden_state) = model.forward(input_ids, attention_mask);
let inference_time = start.elapsed();
println!(" Inference completed in {:.2?}", inference_time);
// Display output shapes
let text_embeds_shape = text_embeds.shape();
let last_hidden_shape = last_hidden_state.shape();
println!("\n Model output shapes:");
println!(" text_embeds: {:?}", text_embeds_shape.dims);
println!(" last_hidden_state: {:?}", last_hidden_shape.dims);
// Verify expected output shapes match
if text_embeds_shape.dims == ref_text_embeds_shape.dims {
println!(
" ✓ text_embeds shape matches expected: {:?}",
ref_text_embeds_shape.dims
);
} else {
println!(
" ⚠ Warning: Expected text_embeds shape {:?}, got {:?}",
ref_text_embeds_shape.dims, text_embeds_shape.dims
);
}
if last_hidden_shape.dims == ref_last_hidden_shape.dims {
println!(
" ✓ last_hidden_state shape matches expected: {:?}",
ref_last_hidden_shape.dims
);
} else {
println!(
" ⚠ Warning: Expected last_hidden_state shape {:?}, got {:?}",
ref_last_hidden_shape.dims, last_hidden_shape.dims
);
}
// Compare outputs
println!("\nComparing model outputs with reference data...");
// Check if text_embeds are close
println!("\n Checking text_embeds:");
if text_embeds
.clone()
.all_close(reference_text_embeds.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ text_embeds matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ text_embeds differs from reference data!");
// Calculate and display the difference statistics
let diff = text_embeds.clone() - reference_text_embeds.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = text_embeds.clone().flatten::<1>(0, 1);
let reference_flat = reference_text_embeds.clone().flatten::<1>(0, 1);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
// Check if last_hidden_state is close
println!("\n Checking last_hidden_state:");
if last_hidden_state.clone().all_close(
reference_last_hidden_state.clone(),
Some(1e-4),
Some(1e-4),
) {
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ last_hidden_state differs from reference data!");
// Calculate and display the difference statistics
let diff = last_hidden_state.clone() - reference_last_hidden_state.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = last_hidden_state.clone().flatten::<1>(0, 2);
let reference_flat = reference_last_hidden_state.clone().flatten::<1>(0, 2);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
println!("\n========================================");
println!("Model test completed!");
println!("========================================");
}


@@ -1,26 +0,0 @@
[package]
name = "burn-onnx-model-checks-clip-vit-b-32-vision"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["tch"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,35 +0,0 @@
use burn_onnx::ModelGen;
use std::path::Path;
fn main() {
let onnx_path = "artifacts/clip-vit-b-32-vision_opset16.onnx";
let test_data_path = "artifacts/test_data.pt";
// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed={}", test_data_path);
println!("cargo:rerun-if-changed=build.rs");
// Check if the ONNX model file exists
if !Path::new(onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!("Please run the following command to download and prepare the model:");
eprintln!(" python get_model.py");
eprintln!();
eprintln!("Or if you prefer using uv:");
eprintln!(" uv run get_model.py");
eprintln!();
eprintln!(
"This will download the CLIP ViT-B-32-vision model and convert it to ONNX format."
);
std::process::exit(1);
}
// Generate the model code from the ONNX file
ModelGen::new()
.input(onnx_path)
.out_dir("model/")
.run_from_script();
}


@@ -1,279 +0,0 @@
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "onnx==1.19.0",
# "onnxruntime>=1.22.0",
# "huggingface-hub>=0.20.0",
# "numpy",
# "torch",
# ]
# ///
import os
import sys
import onnx
from onnx import shape_inference, version_converter
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
def download_clip_model(output_path):
"""Download CLIP ViT-B-32-vision model from Hugging Face."""
print("Downloading CLIP ViT-B-32-vision model from Hugging Face...")
# Download the ONNX model from Hugging Face
model_path = hf_hub_download(
repo_id="Qdrant/clip-ViT-B-32-vision",
filename="model.onnx",
cache_dir="./artifacts/cache",
)
# Copy to artifacts
import shutil
shutil.copy(model_path, output_path)
if not output_path.exists():
raise FileNotFoundError(f"Failed to download ONNX file to {output_path}")
print(f"✓ Model downloaded to: {output_path}")
def process_model(input_path, output_path, target_opset=16):
"""Load, upgrade opset, and apply shape inference to model."""
print(f"Loading model from {input_path}...")
model = onnx.load(input_path)
# Check and upgrade opset if needed
current_opset = model.opset_import[0].version
if current_opset < target_opset:
print(f"Upgrading opset from {current_opset} to {target_opset}...")
model = version_converter.convert_version(model, target_opset)
# Apply shape inference
print("Applying shape inference...")
model = shape_inference.infer_shapes(model)
# Save processed model
onnx.save(model, output_path)
print(f"✓ Processed model saved to: {output_path}")
return model
def get_input_info(model):
"""Extract input information from ONNX model."""
inputs = []
for input_info in model.graph.input:
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
# Use proper defaults for CLIP vision model
if "pixel_values" in input_info.name:
# CLIP vision uses [batch, channels, height, width]
if len(shape) == 0:
shape.append(1) # batch
elif len(shape) == 1:
shape.append(3) # channels
elif len(shape) == 2:
shape.append(224) # height
elif len(shape) == 3:
shape.append(224) # width
else:
shape.append(1) # Default to 1 for other dynamic dimensions
inputs.append(
{
"name": input_info.name,
"shape": shape,
"dtype": input_info.type.tensor_type.elem_type,
}
)
return inputs
def generate_test_data(model_path, output_dir):
"""Generate test input/output data and save as PyTorch tensors."""
import torch
import onnxruntime as ort
print("\nGenerating test data...")
# Load model to get input shapes
model = onnx.load(model_path)
input_infos = get_input_info(model)
print(f" Model has {len(input_infos)} inputs:")
for info in input_infos:
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
# Create reproducible test inputs
np.random.seed(42)
test_inputs = {}
for info in input_infos:
if info["dtype"] == onnx.TensorProto.INT64:
# For INT64 inputs, use random integers
test_input = np.random.randint(0, 1000, size=info["shape"], dtype=np.int64)
else:
# For float inputs (like pixel_values), use random floats
test_input = np.random.rand(*info["shape"]).astype(np.float32)
test_inputs[info["name"]] = test_input
# Run inference to get output
session = ort.InferenceSession(model_path)
outputs = session.run(None, test_inputs)
# Save in a format that's easier to load in Rust
# For CLIP vision, we expect:
# - Inputs: pixel_values
# - Outputs: image_embeds (2D)
# Create a more structured format for Rust
test_data = {
"pixel_values": torch.from_numpy(
test_inputs.get("pixel_values", list(test_inputs.values())[0])
),
"image_embeds": torch.from_numpy(outputs[0]),
}
test_data_path = Path(output_dir) / "test_data.pt"
torch.save(test_data, test_data_path)
print(f" ✓ Test data saved to: {test_data_path}")
print(f" Input shapes:")
print(f" pixel_values: {test_data['pixel_values'].shape}")
print(f" Output shapes:")
print(f" image_embeds: {test_data['image_embeds'].shape}")
def save_model_info(model_path, output_dir):
"""Save model structure information to a text file."""
print("\nSaving model information...")
model = onnx.load(model_path)
info_path = Path(output_dir) / "model-python.txt"
with open(info_path, "w") as f:
f.write("CLIP ViT-B-32-vision Model Information\n")
f.write("=" * 60 + "\n\n")
# Input information
f.write("Inputs:\n")
for input_info in model.graph.input:
f.write(f" - {input_info.name}\n")
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
)
# Output information
f.write("\nOutputs:\n")
for output_info in model.graph.output:
f.write(f" - {output_info.name}\n")
shape = []
for dim in output_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
)
# Model statistics
f.write(f"\nModel Statistics:\n")
f.write(f" Opset version: {model.opset_import[0].version}\n")
f.write(f" Number of nodes: {len(model.graph.node)}\n")
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
# Node types summary
node_types = {}
for node in model.graph.node:
op_type = node.op_type
node_types[op_type] = node_types.get(op_type, 0) + 1
f.write(f"\nNode types:\n")
for op_type, count in sorted(node_types.items()):
f.write(f" {op_type}: {count}\n")
print(f" ✓ Model info saved to: {info_path}")
def main():
print("=" * 60)
print("CLIP ViT-B-32-vision Model Preparation Tool")
print("=" * 60)
# Setup paths
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
original_path = artifacts_dir / "clip-vit-b-32-vision.onnx"
processed_path = artifacts_dir / "clip-vit-b-32-vision_opset16.onnx"
test_data_path = artifacts_dir / "test_data.pt"
model_info_path = artifacts_dir / "model-python.txt"
# Check if we already have everything
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
print(f"\n✓ All files already exist:")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("\nNothing to do!")
return
# Download model if needed
if not original_path.exists() and not processed_path.exists():
print("\nStep 1: Downloading CLIP model...")
download_clip_model(original_path)
# Process model if needed
if not processed_path.exists():
print("\nStep 2: Processing model...")
process_model(original_path, processed_path, target_opset=16)
# Clean up original if we have the processed version
if original_path.exists() and processed_path.exists():
original_path.unlink()
# Generate test data if needed
if not test_data_path.exists():
print("\nStep 3: Generating test data...")
generate_test_data(processed_path, artifacts_dir)
# Save model info if needed
if not model_info_path.exists():
print("\nStep 4: Saving model information...")
save_model_info(processed_path, artifacts_dir)
print("\n" + "=" * 60)
print("✓ CLIP model preparation completed!")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("=" * 60)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n⚠ Operation cancelled by user.")
sys.exit(1)
except Exception as e:
print(f"\n✗ Error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)


@@ -1,176 +0,0 @@
extern crate alloc;
use burn::module::{Initializer, Param};
use burn::prelude::*;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Import the generated model code as a module
pub mod clip_vit_b_32_vision {
include!(concat!(
env!("OUT_DIR"),
"/model/clip-vit-b-32-vision_opset16.rs"
));
}
#[derive(Debug, Module)]
struct TestData<B: Backend> {
pixel_values: Param<Tensor<B, 4>>,
image_embeds: Param<Tensor<B, 2>>,
}
impl<B: Backend> TestData<B> {
fn new(device: &B::Device) -> Self {
// CLIP ViT-B-32 vision: image_size=224, embed_dim=512
Self {
pixel_values: Initializer::Zeros.init([1, 3, 224, 224], device),
image_embeds: Initializer::Zeros.init([1, 512], device),
}
}
}
fn main() {
println!("========================================");
println!("CLIP ViT-B-32-vision Burn Model Test");
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model and test data.");
std::process::exit(1);
}
// Initialize the model (using default which includes the converted weights)
println!("Initializing CLIP vision model...");
let start = Instant::now();
let device = Default::default();
let model: clip_vit_b_32_vision::Model<MyBackend> = clip_vit_b_32_vision::Model::default();
let init_time = start.elapsed();
println!(" Model initialized in {:.2?}", init_time);
// Save model structure to file
println!("\nSaving model structure to artifacts/model.txt...");
let model_str = format!("{}", model);
std::fs::write("artifacts/model.txt", &model_str)
.expect("Failed to write model structure to file");
println!(" Model structure saved");
// Load test data from PyTorch file
println!("\nLoading test data from artifacts/test_data.pt...");
let start = Instant::now();
let mut test_data = TestData::<MyBackend>::new(&device);
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
test_data.load_from(&mut store).expect("Failed to load test data");
let load_time = start.elapsed();
println!(" Data loaded in {:.2?}", load_time);
// Get the input tensors from test data
let pixel_values = test_data.pixel_values.val();
let pixel_values_shape = pixel_values.shape();
println!(
" Loaded pixel_values with shape: {:?}",
pixel_values_shape.dims
);
// Get the reference outputs from test data
let reference_image_embeds = test_data.image_embeds.val();
let ref_image_embeds_shape = reference_image_embeds.shape();
println!(
" Loaded reference image_embeds with shape: {:?}",
ref_image_embeds_shape.dims
);
// Run inference with the loaded input
println!("\nRunning model inference with test input...");
let start = Instant::now();
let image_embeds = model.forward(pixel_values);
let inference_time = start.elapsed();
println!(" Inference completed in {:.2?}", inference_time);
// Display output shapes
let image_embeds_shape = image_embeds.shape();
println!("\n Model output shapes:");
println!(" image_embeds: {:?}", image_embeds_shape.dims);
// Verify expected output shapes match
if image_embeds_shape.dims == ref_image_embeds_shape.dims {
println!(
" ✓ image_embeds shape matches expected: {:?}",
ref_image_embeds_shape.dims
);
} else {
println!(
" ⚠ Warning: Expected image_embeds shape {:?}, got {:?}",
ref_image_embeds_shape.dims, image_embeds_shape.dims
);
}
// Compare outputs
println!("\nComparing model outputs with reference data...");
// Check if image_embeds are close
println!("\n Checking image_embeds:");
if image_embeds
.clone()
.all_close(reference_image_embeds.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ image_embeds matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ image_embeds differs from reference data!");
// Calculate and display the difference statistics
let diff = image_embeds.clone() - reference_image_embeds.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = image_embeds.clone().flatten::<1>(0, 1);
let reference_flat = reference_image_embeds.clone().flatten::<1>(0, 1);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
println!("\n========================================");
println!("Summary:");
println!(" - Model initialization: {:.2?}", init_time);
println!(" - Data loading: {:.2?}", load_time);
println!(" - Inference time: {:.2?}", inference_time);
println!(" - Output validation: see comparison results above");
println!("========================================");
println!("Model test completed!");
println!("========================================");
}


@@ -1,26 +0,0 @@
[package]
name = "burn-onnx-model-checks-modernbert-base"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["tch"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,72 +0,0 @@
# ModernBERT-base Model Check
This crate provides testing for the ModernBERT-base model with Burn.
## Model
- `ModernBERT-base` - Modern BERT variant from Answer.AI
(https://huggingface.co/answerdotai/ModernBERT-base)
## Usage
### 1. Download and prepare the model
```bash
# Using Python directly
python get_model.py
# Or using uv
uv run get_model.py
```
**Note:** This will download the model from Hugging Face and export it to ONNX format using PyTorch.
When running with plain `python`, make sure the `transformers` library is installed; `uv run`
resolves it from the script's inline dependency list.
### 2. Build and run the model test
```bash
# Build the model
cargo build
# Run the test
cargo run --release
```
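What the test run checks can be sketched roughly as follows. This is a plain-Python illustration, not the runner's code: `all_close` and `diff_stats` are hypothetical helpers, the 1e-4 tolerances are borrowed from the sibling CLIP checks on the assumption that this crate uses the same values, and the closeness rule `|a - e| <= atol + rtol * |e|` is the NumPy-style convention, assumed here rather than confirmed for the tensor library.

```python
def all_close(actual, expected, rtol=1e-4, atol=1e-4):
    """Return True when every element is within the combined tolerance.

    Assumed rule (NumPy-style): |a - e| <= atol + rtol * |e|.
    """
    return len(actual) == len(expected) and all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )


def diff_stats(actual, expected):
    """Return (max, mean) of the absolute element-wise differences."""
    diffs = [abs(a - e) for a, e in zip(actual, expected)]
    return max(diffs), sum(diffs) / len(diffs)


model_out = [1.0, 2.0]
reference = [1.5, 2.0]
if all_close(model_out, reference):
    print("outputs match within tolerance")
else:
    max_diff, mean_diff = diff_stats(model_out, reference)
    print(f"max abs diff: {max_diff:.6f}, mean abs diff: {mean_diff:.6f}")
```

When the check fails, the maximum and mean absolute differences are usually the first numbers worth looking at: a small mean with a large maximum tends to point at a few outlier elements rather than a systematic mismatch.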
## Directory Structure
```
modernbert-base/
├── artifacts/ # Downloaded ONNX model and test data
│ ├── modernbert-base_opset16.onnx
│ ├── test_data.pt
│ └── model-python.txt
├── src/
│ └── main.rs # Test runner
├── build.rs # Build script that generates model code
├── get_model.py # Model download and ONNX export script
├── Cargo.toml
└── README.md
```
## Model Architecture
ModernBERT is a modern variant of BERT designed by Answer.AI with improved efficiency and
performance. Its architectural changes over the original BERT include rotary positional embeddings
(RoPE) and attention layers that alternate between local and global attention.
- **Inputs**:
- `input_ids`: Token IDs (shape: [batch_size, sequence_length])
- `attention_mask`: Attention mask (shape: [batch_size, sequence_length])
- **Outputs**:
- `last_hidden_state`: Sequence of hidden states (shape: [batch_size, sequence_length, 768])
- `pooled_output`: Mean-pooled sentence embeddings (computed, shape: [batch_size, 768])
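
The computed `pooled_output` can be illustrated with a small dependency-free sketch. It assumes nested Python lists in place of tensors and a 0/1 attention mask; the function name `mean_pool` mirrors the helper in the test runner, but this version is only an illustration of the masked-mean idea, not the crate's implementation.

```python
def mean_pool(last_hidden_state, attention_mask):
    """Masked mean over the sequence axis.

    last_hidden_state: [batch][seq][hidden] nested lists of floats
    attention_mask:    [batch][seq] nested lists of 0/1 ints
    returns:           [batch][hidden] nested lists of floats
    """
    pooled = []
    for states, mask in zip(last_hidden_state, attention_mask):
        hidden = len(states[0])
        sums = [0.0] * hidden
        for vec, m in zip(states, mask):
            if m:  # skip padding positions (mask == 0)
                for j, v in enumerate(vec):
                    sums[j] += v
        # Clamp the token count so an all-padding row does not divide by zero.
        count = max(float(sum(mask)), 1e-9)
        pooled.append([s / count for s in sums])
    return pooled


hidden = [[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]
mask = [[1, 1, 0]]  # the third token is padding and is ignored
print(mean_pool(hidden, mask))  # [[2.0, 3.0]]
```

Dividing by the number of unmasked tokens, rather than by the full sequence length, keeps padded positions from pulling the embedding toward zero.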
## Notes
- The default sequence length is 512 tokens
- The model has a hidden size of 768
- The model uses ONNX opset 16
- Test data is generated from random inputs with a fixed seed (seed=42) for reproducibility
- Vocabulary size: 50,368 tokens
- The pooled output is computed using mean pooling over the last hidden state
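
As a rough sketch of how the notes above fit together, the snippet below builds seeded inputs with the stated sequence length and vocabulary size. It uses only the standard library, so the values it produces will not match the ones in `test_data.pt`; `make_test_inputs` is a hypothetical helper introduced for illustration.

```python
import random

VOCAB_SIZE = 50_368  # vocabulary size from the notes above
SEQ_LEN = 512        # default sequence length from the notes above


def make_test_inputs(seed=42, batch=1, seq_len=SEQ_LEN, vocab_size=VOCAB_SIZE):
    """Build reproducible (input_ids, attention_mask) as nested lists."""
    rng = random.Random(seed)  # local generator, so global state is untouched
    input_ids = [
        [rng.randrange(vocab_size) for _ in range(seq_len)] for _ in range(batch)
    ]
    # All ones: every position is treated as a valid (non-padding) token.
    attention_mask = [[1] * seq_len for _ in range(batch)]
    return input_ids, attention_mask


ids, mask = make_test_inputs()
print(len(ids), len(ids[0]), len(mask[0]))  # 1 512 512
```

Fixing the seed means the same inputs, and therefore the same reference outputs, can be regenerated whenever the artifacts are rebuilt.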


@@ -1,32 +0,0 @@
use burn_onnx::ModelGen;
use std::path::Path;
fn main() {
let onnx_path = "artifacts/modernbert-base_opset16.onnx";
let test_data_path = "artifacts/test_data.pt";
// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed={}", test_data_path);
println!("cargo:rerun-if-changed=build.rs");
// Check if the ONNX model file exists
if !Path::new(onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!("Please run the following command to download and prepare the model:");
eprintln!(" python get_model.py");
eprintln!();
eprintln!("Or if you prefer using uv:");
eprintln!(" uv run get_model.py");
eprintln!();
eprintln!("This will download the ModernBERT-base model and convert it to ONNX format.");
std::process::exit(1);
}
// Generate the model code from the ONNX file
ModelGen::new()
.input(onnx_path)
.out_dir("model/")
.run_from_script();
}


@@ -1,335 +0,0 @@
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "onnx==1.19.0",
# "onnxruntime>=1.22.0",
# "huggingface-hub>=0.20.0",
# "numpy",
# "torch",
# "transformers>=4.46.0",
# ]
# ///
import os
import sys
import onnx
from onnx import shape_inference, version_converter
import numpy as np
from pathlib import Path
from huggingface_hub import hf_hub_download
def download_modernbert_model(output_path):
"""Download ModernBERT-base model from Hugging Face and export to ONNX."""
print("Downloading ModernBERT-base model from Hugging Face...")
try:
import torch
from transformers import AutoModel, AutoTokenizer
# Download the model
model_name = "answerdotai/ModernBERT-base"
print(f" Loading {model_name}...")
model = AutoModel.from_pretrained(model_name, cache_dir="./artifacts/cache")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./artifacts/cache")
# Set model to evaluation mode
model.eval()
# Create dummy input
dummy_text = "This is a test sentence."
inputs = tokenizer(
dummy_text, return_tensors="pt", padding="max_length", max_length=512
)
# Export to ONNX
print(f" Exporting model to ONNX...")
torch.onnx.export(
model,
(inputs["input_ids"], inputs["attention_mask"]),
str(output_path),
input_names=["input_ids", "attention_mask"],
output_names=["last_hidden_state"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
"attention_mask": {0: "batch", 1: "sequence"},
"last_hidden_state": {0: "batch", 1: "sequence"},
},
opset_version=16,
)
print(f"✓ Model exported to ONNX: {output_path}")
except ImportError as e:
print(f"\n✗ Error: Missing required package.")
print(f" {str(e)}")
print("\nPlease install transformers:")
print(" uv pip install transformers")
sys.exit(1)
def process_model(input_path, output_path, target_opset=16):
"""Load, upgrade opset, and apply shape inference to model."""
print(f"Loading model from {input_path}...")
model = onnx.load(input_path)
# Check and upgrade opset if needed
current_opset = model.opset_import[0].version
if current_opset < target_opset:
print(f"Upgrading opset from {current_opset} to {target_opset}...")
model = version_converter.convert_version(model, target_opset)
# Apply shape inference
print("Applying shape inference...")
model = shape_inference.infer_shapes(model)
# Save processed model
onnx.save(model, output_path)
print(f"✓ Processed model saved to: {output_path}")
return model
def get_input_info(model):
"""Extract input information from ONNX model."""
inputs = []
for input_info in model.graph.input:
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
# Use proper defaults for ModernBERT
if (
"input_ids" in input_info.name
or "attention_mask" in input_info.name
):
# Default sequence length
shape.append(1 if len(shape) == 0 else 512)
else:
shape.append(1) # Default to 1 for other dynamic dimensions
inputs.append(
{
"name": input_info.name,
"shape": shape,
"dtype": input_info.type.tensor_type.elem_type,
}
)
return inputs
def generate_test_data(model_path, output_dir):
"""Generate test input/output data and save as PyTorch tensors."""
import torch
import onnxruntime as ort
print("\nGenerating test data...")
# Load model to get input shapes
model = onnx.load(model_path)
input_infos = get_input_info(model)
print(f" Model has {len(input_infos)} inputs:")
for info in input_infos:
print(f" - {info['name']}: shape={info['shape']}, dtype={info['dtype']}")
# Create reproducible test inputs
np.random.seed(42)
test_inputs = {}
for info in input_infos:
if info["dtype"] == onnx.TensorProto.INT64:
# Handle different integer inputs appropriately
if "attention_mask" in info["name"]:
# Attention mask should be 1 for valid tokens, 0 for padding
# For testing, use all 1s (all valid tokens)
test_input = np.ones(info["shape"], dtype=np.int64)
elif "input_ids" in info["name"]:
# For input_ids, use random integers in vocabulary range
# ModernBERT uses a 50368 vocabulary
test_input = np.random.randint(0, 50368, size=info["shape"], dtype=np.int64)
else:
# For other INT64 inputs, use random integers
test_input = np.random.randint(0, 50368, size=info["shape"], dtype=np.int64)
else:
# For float inputs, use random floats
test_input = np.random.rand(*info["shape"]).astype(np.float32)
test_inputs[info["name"]] = test_input
# Run inference to get output
session = ort.InferenceSession(model_path)
outputs = session.run(None, test_inputs)
# Save in a format that's easier to load in Rust
# For ModernBERT, we expect:
# - Inputs: input_ids, attention_mask
# - Outputs: last_hidden_state (3D)
# Compute mean pooled output
last_hidden_state = outputs[0]
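# Use the same name fallback as the saved test data so a renamed input is not None
attention_mask_np = test_inputs.get(
"attention_mask", test_inputs.get("inputs.1", list(test_inputs.values())[1])
)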
# Mean pooling - take attention mask into account for correct averaging
input_mask_expanded = np.expand_dims(attention_mask_np, axis=-1).astype(np.float32)
sum_embeddings = np.sum(last_hidden_state * input_mask_expanded, axis=1)
sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
pooled_output = sum_embeddings / sum_mask
# Create a more structured format for Rust
test_data = {
"input_ids": torch.from_numpy(
test_inputs.get(
"input_ids", test_inputs.get("inputs.0", list(test_inputs.values())[0])
)
),
"attention_mask": torch.from_numpy(
test_inputs.get(
"attention_mask",
test_inputs.get("inputs.1", list(test_inputs.values())[1]),
)
),
"last_hidden_state": torch.from_numpy(outputs[0]),
"pooled_output": torch.from_numpy(pooled_output),
}
test_data_path = Path(output_dir) / "test_data.pt"
torch.save(test_data, test_data_path)
print(f" ✓ Test data saved to: {test_data_path}")
print(f" Input shapes:")
print(f" input_ids: {test_data['input_ids'].shape}")
print(f" attention_mask: {test_data['attention_mask'].shape}")
print(f" Output shapes:")
print(f" last_hidden_state: {test_data['last_hidden_state'].shape}")
print(f" pooled_output: {test_data['pooled_output'].shape}")
def save_model_info(model_path, output_dir):
"""Save model structure information to a text file."""
print("\nSaving model information...")
model = onnx.load(model_path)
info_path = Path(output_dir) / "model-python.txt"
with open(info_path, "w") as f:
f.write("ModernBERT-base Model Information\n")
f.write("=" * 60 + "\n\n")
# Input information
f.write("Inputs:\n")
for input_info in model.graph.input:
f.write(f" - {input_info.name}\n")
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(input_info.type.tensor_type.elem_type)}\n"
)
# Output information
f.write("\nOutputs:\n")
for output_info in model.graph.output:
f.write(f" - {output_info.name}\n")
shape = []
for dim in output_info.type.tensor_type.shape.dim:
if dim.HasField("dim_value"):
shape.append(dim.dim_value)
else:
shape.append("dynamic")
f.write(f" Shape: {shape}\n")
f.write(
f" Type: {onnx.TensorProto.DataType.Name(output_info.type.tensor_type.elem_type)}\n"
)
# Model statistics
f.write(f"\nModel Statistics:\n")
f.write(f" Opset version: {model.opset_import[0].version}\n")
f.write(f" Number of nodes: {len(model.graph.node)}\n")
f.write(f" Number of initializers: {len(model.graph.initializer)}\n")
# Node types summary
node_types = {}
for node in model.graph.node:
op_type = node.op_type
node_types[op_type] = node_types.get(op_type, 0) + 1
f.write(f"\nNode types:\n")
for op_type, count in sorted(node_types.items()):
f.write(f" {op_type}: {count}\n")
print(f" ✓ Model info saved to: {info_path}")
def main():
print("=" * 60)
print("ModernBERT-base Model Preparation Tool")
print("=" * 60)
# Setup paths
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
original_path = artifacts_dir / "modernbert-base.onnx"
processed_path = artifacts_dir / "modernbert-base_opset16.onnx"
test_data_path = artifacts_dir / "test_data.pt"
model_info_path = artifacts_dir / "model-python.txt"
# Check if we already have everything
if processed_path.exists() and test_data_path.exists() and model_info_path.exists():
print(f"\n✓ All files already exist:")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("\nNothing to do!")
return
# Download and export model if needed
if not original_path.exists() and not processed_path.exists():
print("\nStep 1: Downloading and exporting ModernBERT-base model...")
download_modernbert_model(original_path)
# Process model if needed (already at opset 16, but apply shape inference)
if not processed_path.exists():
if original_path.exists():
print("\nStep 2: Processing model...")
process_model(original_path, processed_path, target_opset=16)
# Clean up original if we have the processed version
if original_path.exists() and processed_path.exists():
original_path.unlink()
# Generate test data if needed
if not test_data_path.exists():
print("\nStep 3: Generating test data...")
generate_test_data(processed_path, artifacts_dir)
# Save model info if needed
if not model_info_path.exists():
print("\nStep 4: Saving model information...")
save_model_info(processed_path, artifacts_dir)
print("\n" + "=" * 60)
print("✓ ModernBERT-base model preparation completed!")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print(f" Model info: {model_info_path}")
print("=" * 60)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n⚠ Operation cancelled by user.")
sys.exit(1)
except Exception as e:
print(f"\n✗ Error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)


@@ -1,271 +0,0 @@
extern crate alloc;
use burn::module::{Initializer, Param};
use burn::prelude::*;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Import the generated model code as a module
pub mod modernbert_base {
include!(concat!(
env!("OUT_DIR"),
"/model/modernbert-base_opset16.rs"
));
}
#[derive(Debug, Module)]
struct TestData<B: Backend> {
input_ids: Param<Tensor<B, 2, Int>>,
attention_mask: Param<Tensor<B, 2, Int>>,
last_hidden_state: Param<Tensor<B, 3>>,
pooled_output: Param<Tensor<B, 2>>,
}
impl<B: Backend> TestData<B> {
fn new(device: &B::Device) -> Self {
use burn::module::ParamId;
// ModernBERT-base: sequence_length=512, hidden_size=768
// Note: Initializer only works for float tensors, Int tensors need manual init
Self {
input_ids: Param::initialized(ParamId::new(), Tensor::zeros([1, 512], device)),
attention_mask: Param::initialized(ParamId::new(), Tensor::zeros([1, 512], device)),
last_hidden_state: Initializer::Zeros.init([1, 512, 768], device),
pooled_output: Initializer::Zeros.init([1, 768], device),
}
}
}
/// Apply mean pooling to get sentence embeddings
fn mean_pool<B: Backend>(
last_hidden_state: Tensor<B, 3>,
attention_mask: Tensor<B, 2, Int>,
) -> Tensor<B, 2> {
// Convert attention_mask to float and expand dimensions to match hidden_state
let attention_mask_float = attention_mask.float().unsqueeze_dim::<3>(2);
// Multiply hidden states by attention mask
let masked_embeddings = last_hidden_state * attention_mask_float.clone();
// Sum along sequence dimension (dim 1)
let sum_embeddings = masked_embeddings.sum_dim(1);
// Sum attention mask to get count of non-padding tokens
let sum_mask = attention_mask_float.sum_dim(1).clamp_min(1e-9);
// Divide to get mean - result is [batch, 1, hidden]
let pooled = sum_embeddings / sum_mask;
// Get the shape to reshape to [batch, hidden]
let shape = pooled.shape();
let batch_size = shape.dims[0];
let hidden_size = shape.dims[2];
pooled.reshape([batch_size, hidden_size])
}
fn main() {
println!("========================================");
println!("ModernBERT-base Burn Model Test");
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model and test data.");
std::process::exit(1);
}
// Initialize the model (using default which includes the converted weights)
println!("Initializing ModernBERT-base model...");
let start = Instant::now();
let device = Default::default();
let model: modernbert_base::Model<MyBackend> = modernbert_base::Model::default();
let init_time = start.elapsed();
println!(" Model initialized in {:.2?}", init_time);
// Save model structure to file
println!("\nSaving model structure to artifacts/model.txt...");
let model_str = format!("{}", model);
std::fs::write("artifacts/model.txt", &model_str)
.expect("Failed to write model structure to file");
println!(" Model structure saved");
// Load test data from PyTorch file
println!("\nLoading test data from artifacts/test_data.pt...");
let start = Instant::now();
let mut test_data = TestData::<MyBackend>::new(&device);
let mut store = PytorchStore::from_file("artifacts/test_data.pt");
test_data.load_from(&mut store).expect("Failed to load test data");
let load_time = start.elapsed();
println!(" Data loaded in {:.2?}", load_time);
// Get the input tensors from test data
let input_ids = test_data.input_ids.val();
let attention_mask = test_data.attention_mask.val();
let input_ids_shape = input_ids.shape();
let attention_mask_shape = attention_mask.shape();
println!(" Loaded input_ids with shape: {:?}", input_ids_shape.dims);
println!(
" Loaded attention_mask with shape: {:?}",
attention_mask_shape.dims
);
// Get the reference outputs from test data
let reference_last_hidden_state = test_data.last_hidden_state.val();
let reference_pooled_output = test_data.pooled_output.val();
let ref_last_hidden_shape = reference_last_hidden_state.shape();
let ref_pooled_shape = reference_pooled_output.shape();
println!(
" Loaded reference last_hidden_state with shape: {:?}",
ref_last_hidden_shape.dims
);
println!(
" Loaded reference pooled_output with shape: {:?}",
ref_pooled_shape.dims
);
// Run inference with the loaded input
println!("\nRunning model inference with test input...");
let start = Instant::now();
let last_hidden_state = model.forward(input_ids.clone(), attention_mask.clone());
let inference_time = start.elapsed();
println!(" Inference completed in {:.2?}", inference_time);
// Compute pooled output (mean pooling)
println!("\nComputing pooled output (mean pooling)...");
let start = Instant::now();
let pooled_output = mean_pool(last_hidden_state.clone(), attention_mask.clone());
let pooling_time = start.elapsed();
println!(" Pooling completed in {:.2?}", pooling_time);
// Display output shapes
let last_hidden_shape = last_hidden_state.shape();
let pooled_shape = pooled_output.shape();
println!("\n Model output shapes:");
println!(" last_hidden_state: {:?}", last_hidden_shape.dims);
println!(" pooled_output: {:?}", pooled_shape.dims);
// Verify expected output shapes match
if last_hidden_shape.dims == ref_last_hidden_shape.dims {
println!(
" ✓ last_hidden_state shape matches expected: {:?}",
ref_last_hidden_shape.dims
);
} else {
println!(
" ⚠ Warning: Expected last_hidden_state shape {:?}, got {:?}",
ref_last_hidden_shape.dims, last_hidden_shape.dims
);
}
if pooled_shape.dims == ref_pooled_shape.dims {
println!(
" ✓ pooled_output shape matches expected: {:?}",
ref_pooled_shape.dims
);
} else {
println!(
" ⚠ Warning: Expected pooled_output shape {:?}, got {:?}",
ref_pooled_shape.dims, pooled_shape.dims
);
}
// Compare outputs
println!("\nComparing model outputs with reference data...");
// Check if last_hidden_state is close
println!("\n Checking last_hidden_state:");
if last_hidden_state.clone().all_close(
reference_last_hidden_state.clone(),
Some(1e-4),
Some(1e-4),
) {
println!(" ✓ last_hidden_state matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ last_hidden_state differs from reference data!");
// Calculate and display the difference statistics
let diff = last_hidden_state.clone() - reference_last_hidden_state.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = last_hidden_state.clone().flatten::<1>(0, 2);
let reference_flat = reference_last_hidden_state.clone().flatten::<1>(0, 2);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
// Check if pooled_output is close
println!("\n Checking pooled_output:");
if pooled_output
.clone()
.all_close(reference_pooled_output.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ pooled_output matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ pooled_output differs from reference data!");
// Calculate and display the difference statistics
let diff = pooled_output.clone() - reference_pooled_output.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = pooled_output.clone().flatten::<1>(0, 1);
let reference_flat = reference_pooled_output.clone().flatten::<1>(0, 1);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
println!("\n========================================");
println!("Model test completed!");
println!("========================================");
}


@@ -1,26 +0,0 @@
[package]
name = "burn-onnx-model-checks-rf-detr"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["ndarray"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,85 +0,0 @@
# RF-DETR Model Check
This crate tests burn-onnx's ability to handle the RF-DETR (Roboflow DETR) object detection model.
## About RF-DETR
RF-DETR is a real-time object detection model based on the DETR (Detection Transformer)
architecture, developed by Roboflow. It combines transformer-based detection with optimizations for
speed and accuracy.
Key features:
- Transformer-based architecture with multi-head attention
- Deformable attention mechanisms
- Object queries for detection
- End-to-end trainable without anchor boxes or NMS (Non-Maximum Suppression)
## Related Issue
This model check was created to track and test the fix for:
- [Issue #4052](https://github.com/tracel-ai/burn/issues/4052): RF-DETR ONNX import fails with "axis
2 is out of bounds for rank 1"
## Usage
### 1. Download and prepare the model
Requires Python 3.11:
```bash
# Using uv (recommended)
uv run --python 3.11 get_model.py
```
### 2. Build and run the model test
```bash
cargo build
cargo run
```
## Directory Structure
```
rf-detr/
├── artifacts/                          # Downloaded ONNX model and test data
│   ├── rf_detr_small.onnx              # ONNX model (119 MB)
│   ├── rf_detr_small_test_data.pt      # Test input/output tensors (3.1 MB)
│   └── node_info.json                  # ONNX node analysis
├── src/
│   └── main.rs                         # Test runner with output comparison
├── build.rs                            # Build script that generates model code
├── get_model.py                        # Model download, export, and test data generation
└── Cargo.toml
```
## Model Details
- **Model**: RF-DETR Small
- **Input**: `[1, 3, 512, 512]` (RGB image)
- **Outputs**:
  - `dets`: `[1, 300, 4]` - 300 bounding boxes (x, y, w, h)
  - `labels`: `[1, 300, 91]` - class scores for each of the 300 boxes (91 COCO classes)
- **Architecture**: DETR with deformable attention
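To make the output layout concrete, here is a minimal sketch of how a consumer might read the two tensors for a single image (batch index 0). It is an illustration only and rests on assumptions this README does not state: that the class scores are raw logits needing a sigmoid, and that `(x, y, w, h)` is center-based. `decode_detections` and `threshold` are hypothetical names, not part of burn-onnx or RF-DETR.

```python
import math


def sigmoid(x):
    """Logistic function, applied on the assumption that scores are logits."""
    return 1.0 / (1.0 + math.exp(-x))


def decode_detections(dets, labels, threshold=0.5):
    """Pick the best class per query and keep only confident detections.

    dets:   list of [x, y, w, h] per query (assumed center-based)
    labels: list of per-class scores per query (assumed raw logits)
    """
    results = []
    for box, scores in zip(dets, labels):
        best_class = max(range(len(scores)), key=lambda c: scores[c])
        confidence = sigmoid(scores[best_class])
        if confidence < threshold:
            continue
        cx, cy, w, h = box
        # Convert center/size to corners (x_min, y_min, x_max, y_max).
        corners = (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)
        results.append((best_class, confidence, corners))
    return results


# Two toy queries with three classes each (the real model has 300 and 91).
dets = [[0.5, 0.5, 0.2, 0.4], [0.1, 0.1, 0.05, 0.05]]
labels = [[-2.0, 3.0, -1.0], [-4.0, -3.0, -5.0]]
detections = decode_detections(dets, labels)
print(detections)
```

Only the first toy query survives the threshold here; the second has no class score above it.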
## Test Data
The `get_model.py` script generates reference test data by:
1. Creating a reproducible random input tensor (seed 42)
2. Running inference with ONNX Runtime
3. Saving both input and outputs as PyTorch tensors
When the ONNX import issue is fixed, `cargo run` will:
1. Load the test data
2. Run inference with the burn-generated model
3. Compare outputs against ONNX Runtime reference within tolerance (1e-4)
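The comparison step can be sketched in a few lines. The example below assumes the common NumPy-style rule `|output - reference| <= atol + rtol * |reference|` with both tolerances set to 1e-4; whether the test runner uses exactly this formula is not stated in this README. The helper names and sample values are hypothetical, and the inputs are flat lists of equal length.

```python
def all_close(output, reference, rtol=1e-4, atol=1e-4):
    """Return True if every element satisfies |o - r| <= atol + rtol * |r|."""
    return all(
        abs(o - r) <= atol + rtol * abs(r)
        for o, r in zip(output, reference)
    )


def diff_stats(output, reference):
    """Return (max, mean) absolute difference, useful when the check fails."""
    diffs = [abs(o - r) for o, r in zip(output, reference)]
    return max(diffs), sum(diffs) / len(diffs)


reference = [0.25, -1.5, 3.0]
close = [0.25003, -1.50005, 3.0002]
far = [0.25, -1.5, 3.01]

print(all_close(close, reference))  # True
print(all_close(far, reference))    # False
print(diff_stats(far, reference))   # max and mean absolute difference
```

Printing the maximum and mean absolute difference alongside the pass/fail result makes it easier to tell a small numerical drift from a structural import error.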
## Notes
- The model uses transformer layers which test burn-onnx's handling of attention mechanisms
- The model includes complex operations such as multi-head attention and deformable attention
- Currently fails at build time due to issue #4052 (axis out of bounds during type inference)


@@ -1,29 +0,0 @@
use burn_onnx::ModelGen;
use std::path::Path;
fn main() {
let onnx_path = "artifacts/rf_detr_small.onnx";
let test_data_path = "artifacts/rf_detr_small_test_data.pt";
// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed={}", test_data_path);
println!("cargo:rerun-if-changed=build.rs");
// Check if the ONNX model file exists
if !Path::new(onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!("Please run the following command to download and prepare the RF-DETR model:");
eprintln!(" uv run --python 3.11 get_model.py");
eprintln!();
eprintln!("This will download and export the RF-DETR Small model to ONNX format.");
std::process::exit(1);
}
// Generate the model code from the ONNX file
ModelGen::new()
.input(onnx_path)
.out_dir("model/")
.run_from_script();
}


@@ -1,299 +0,0 @@
#!/usr/bin/env -S uv run --python 3.11 --script
# /// script
# requires-python = "==3.11.*"
# dependencies = [
# "torch==2.6.*",
# "torchvision==0.21.*",
# "onnx<1.17",
# "onnxruntime",
# "numpy",
# "rfdetr[onnxexport]",
# ]
# ///
"""
Download and prepare the RF-DETR model for testing with burn-onnx.
RF-DETR (Roboflow DETR) is a real-time object detection model that combines
the DETR (Detection Transformer) architecture with optimizations for speed.
This script exports the RFDETRSmall model to ONNX format for testing burn-onnx's
ability to handle transformer-based detection models.
Related issue: https://github.com/tracel-ai/burn/issues/4052
"""
import json
import shutil
from collections import defaultdict
from pathlib import Path
import numpy as np
import onnx
def get_input_shape(model):
    """Extract input shape from ONNX model."""
    input_info = model.graph.input[0]
    shape = []
    for dim in input_info.type.tensor_type.shape.dim:
        if dim.HasField("dim_value"):
            shape.append(dim.dim_value)
        else:
            shape.append(1)  # Default to 1 for dynamic dimensions
    # Ensure valid RF-DETR input shape (batch, channels, height, width).
    # Fall back to 512x512, the RF-DETR Small resolution that the README and
    # main.rs in this crate expect.
    if len(shape) != 4 or shape[2] == 0 or shape[2] > 2000:
        return [1, 3, 512, 512]
    return shape
def extract_node_info(model_path, artifacts_dir):
"""Extract node types and configurations from the ONNX model."""
print("Extracting node information from ONNX model...")
# Load the ONNX model
model = onnx.load(str(model_path))
# Collect node information
node_types = defaultdict(int)
node_details = []
def process_graph(graph, graph_name="main"):
"""Recursively process a graph and its subgraphs."""
for idx, node in enumerate(graph.node):
node_types[node.op_type] += 1
# Extract node details
node_info = {
"graph": graph_name,
"index": idx,
"op_type": node.op_type,
"name": node.name if node.name else f"{node.op_type}_{idx}",
"inputs": list(node.input),
"outputs": list(node.output),
"attributes": {},
}
# Extract attributes
for attr in node.attribute:
attr_name = attr.name
# Get attribute value based on type
if attr.HasField("f"):
node_info["attributes"][attr_name] = float(attr.f)
elif attr.HasField("i"):
node_info["attributes"][attr_name] = int(attr.i)
elif attr.HasField("s"):
node_info["attributes"][attr_name] = (
attr.s.decode("utf-8") if attr.s else ""
)
elif attr.HasField("t"):
node_info["attributes"][attr_name] = "<tensor>"
elif attr.floats:
node_info["attributes"][attr_name] = list(attr.floats)
elif attr.ints:
node_info["attributes"][attr_name] = list(attr.ints)
elif attr.strings:
node_info["attributes"][attr_name] = [
s.decode("utf-8") for s in attr.strings
]
elif attr.HasField("g"):
# Subgraph - recursively process it
subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}"
node_info["attributes"][attr_name] = f"<subgraph: {subgraph_name}>"
process_graph(attr.g, subgraph_name)
elif attr.graphs:
subgraph_names = []
for g_idx, subgraph in enumerate(attr.graphs):
subgraph_name = (
f"{graph_name}.{node.op_type}_{idx}.{attr_name}_{g_idx}"
)
subgraph_names.append(subgraph_name)
process_graph(subgraph, subgraph_name)
node_info["attributes"][attr_name] = (
f"<subgraphs: {', '.join(subgraph_names)}>"
)
else:
node_info["attributes"][attr_name] = "<unknown>"
node_details.append(node_info)
# Process the main graph
process_graph(model.graph, "main")
# Create summary
summary = {
"model_name": model.graph.name,
"opset_version": model.opset_import[0].version
if model.opset_import
else "unknown",
"total_nodes": len(node_details),
"node_type_counts": dict(sorted(node_types.items())),
"nodes": node_details,
}
# Save to JSON file
output_path = artifacts_dir / "node_info.json"
with open(output_path, "w") as f:
json.dump(summary, f, indent=2)
print(f" Node information extracted to {output_path}")
print(f" Total nodes: {summary['total_nodes']}")
print(f" Unique node types: {len(node_types)}")
print(f" Node type distribution:")
for op_type, count in sorted(node_types.items(), key=lambda x: x[1], reverse=True)[
:15
]:
print(f" - {op_type}: {count}")
if len(node_types) > 15:
print(f" ... and {len(node_types) - 15} more types")
return summary
def generate_test_data(model_path, output_path):
    """Generate test input/output data and save as PyTorch tensors."""
    import onnxruntime as ort
    import torch
    print("Generating test data...")
    # Load model to get input shape
    model = onnx.load(str(model_path))
    input_shape = get_input_shape(model)
    print(f" Input shape: {input_shape}")
    # Create reproducible test input
    np.random.seed(42)
    test_input = np.random.rand(*input_shape).astype(np.float32)
    # Run inference to get output
    session = ort.InferenceSession(str(model_path))
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: test_input})
    # RF-DETR has two outputs: dets (boxes) and labels (class scores)
    # Save as PyTorch tensors
    test_data = {
        "input": torch.from_numpy(test_input),
        "output_dets": torch.from_numpy(outputs[0]),
        "output_labels": torch.from_numpy(outputs[1]),
    }
    torch.save(test_data, output_path)
    print(f" Test data saved to: {output_path}")
    print(f" Input shape: {test_input.shape}")
    print(f" Output dets shape: {outputs[0].shape}")
    print(f" Output labels shape: {outputs[1].shape}")
def download_and_export_model():
"""Download RF-DETR model and export to ONNX format."""
from rfdetr import RFDETRSmall
# Create artifacts directory
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
model_path = artifacts_dir / "rf_detr_small.onnx"
test_data_path = artifacts_dir / "rf_detr_small_test_data.pt"
# Check if we already have everything
if model_path.exists() and test_data_path.exists():
print(f"All files already exist:")
print(f" Model: {model_path}")
print(f" Test data: {test_data_path}")
print("\nTo re-download, delete the artifacts directory and run again.")
return
print("=" * 60)
print("RF-DETR Small Model Preparation Tool")
print("=" * 60)
print()
# Download and export if model doesn't exist
if not model_path.exists():
# Download and initialize the model
print("Step 1: Downloading RF-DETR Small model...")
model = RFDETRSmall()
print(" Model downloaded and initialized")
# Export to ONNX using RF-DETR's built-in export method
print()
print("Step 2: Exporting model to ONNX format...")
# RF-DETR exports to output/inference_model.onnx by default
# Note: model.export() returns None, but exports to output/inference_model.onnx
model.export()
default_export_path = Path("output/inference_model.onnx")
print(f" Exported to: {default_export_path}")
# Move the exported file to artifacts directory
exported_file = default_export_path
if exported_file.exists():
shutil.move(str(exported_file), str(model_path))
# Clean up the output directory if empty
output_dir = exported_file.parent
if output_dir.exists() and not any(output_dir.iterdir()):
output_dir.rmdir()
# Clean up any downloaded weights file
pth_files = list(Path(".").glob("*.pth"))
for pth_file in pth_files:
pth_file.unlink()
print(f" Cleaned up: {pth_file}")
print(f" Model saved to {model_path}")
print(f" File size: {model_path.stat().st_size / 1024 / 1024:.1f} MB")
# Extract node information
print()
print("Step 3: Analyzing ONNX model structure...")
extract_node_info(model_path, artifacts_dir)
# Generate test data if needed
if not test_data_path.exists():
print()
print("Step 4: Generating test data...")
generate_test_data(model_path, test_data_path)
print()
print("=" * 60)
print("RF-DETR model preparation completed!")
print("=" * 60)
print()
print("The RF-DETR model is a transformer-based object detector that uses:")
print(" - Multi-head self-attention layers")
print(" - Cross-attention for object queries")
print(" - Deformable attention mechanisms")
print()
print("Generated files:")
print(f" - {model_path} (ONNX model)")
print(f" - {test_data_path} (test input/output data)")
print(f" - {artifacts_dir / 'node_info.json'} (node analysis)")
print()
print("Next steps:")
print(" 1. Build the model: cargo build")
print(" 2. Run the test: cargo run")
print()
print("Note: This model is used to test burn-onnx's handling of")
print("transformer-based architectures. Related issue:")
print(" https://github.com/tracel-ai/burn/issues/4052")
print()
if __name__ == "__main__":
    try:
        download_and_export_model()
    except KeyboardInterrupt:
        print("\nOperation cancelled by user.")
        exit(1)
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        exit(1)


@@ -1,238 +0,0 @@
extern crate alloc;
use burn::module::{Initializer, Param};
use burn::prelude::*;
use burn_store::PytorchStore;
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Include the generated model
include!(concat!(env!("OUT_DIR"), "/model/rf_detr_small.rs"));
/// Test data structure matching the PyTorch saved format
/// RF-DETR has two outputs: dets (bounding boxes) and labels (class scores)
#[derive(Debug, Module)]
struct TestData<B: Backend> {
input: Param<Tensor<B, 4>>,
output_dets: Param<Tensor<B, 3>>,
output_labels: Param<Tensor<B, 3>>,
}
impl<B: Backend> TestData<B> {
fn new(device: &B::Device) -> Self {
// RF-DETR Small: input 512x512, 300 queries, 4 bbox coords, 91 classes (COCO)
Self {
input: Initializer::Zeros.init([1, 3, 512, 512], device),
output_dets: Initializer::Zeros.init([1, 300, 4], device),
output_labels: Initializer::Zeros.init([1, 300, 91], device),
}
}
}
fn main() {
println!("========================================");
println!("RF-DETR Small Model Test");
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model.");
eprintln!("Example: uv run --python 3.11 get_model.py");
std::process::exit(1);
}
// Check if model file exists
let model_file = artifacts_dir.join("rf_detr_small.onnx");
let test_data_file = artifacts_dir.join("rf_detr_small_test_data.pt");
if !model_file.exists() {
eprintln!("Error: Model file not found!");
eprintln!("Please run: uv run --python 3.11 get_model.py");
std::process::exit(1);
}
if !test_data_file.exists() {
eprintln!("Error: Test data file not found!");
eprintln!("Please run: uv run --python 3.11 get_model.py");
std::process::exit(1);
}
// Initialize the model with weights
println!("Initializing RF-DETR Small model...");
let start = Instant::now();
let device = Default::default();
// The model weights are generated at build time and written to OUT_DIR.
// Load them at runtime from the burnpack (.bpk) file next to the generated code.
let weights_path = concat!(env!("OUT_DIR"), "/model/rf_detr_small.bpk");
let model: Model<MyBackend> = Model::from_file(weights_path, &device);
let init_time = start.elapsed();
println!(" Model initialized in {:.2?}", init_time);
// Save model structure to file
let model_txt_path = artifacts_dir.join("rf_detr_small_model.txt");
println!(
"\nSaving model structure to {}...",
model_txt_path.display()
);
let model_str = format!("{}", model);
std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file");
println!(" Model structure saved");
// Load test data from PyTorch file
println!("\nLoading test data from {}...", test_data_file.display());
let start = Instant::now();
let mut test_data = TestData::<MyBackend>::new(&device);
let mut store = PytorchStore::from_file(&test_data_file);
test_data.load_from(&mut store).expect("Failed to load test data");
let load_time = start.elapsed();
println!(" Data loaded in {:.2?}", load_time);
// Get the input tensor from test data
let input = test_data.input.val();
let input_shape = input.shape();
println!(" Loaded input tensor with shape: {:?}", input_shape.dims);
// Get the reference outputs from test data
let reference_dets = test_data.output_dets.val();
let reference_labels = test_data.output_labels.val();
println!(
" Loaded reference dets with shape: {:?}",
reference_dets.shape().dims
);
println!(
" Loaded reference labels with shape: {:?}",
reference_labels.shape().dims
);
// Run inference with the loaded input
println!("\nRunning model inference with test input...");
let start = Instant::now();
let (output_dets, output_labels) = model.forward(input);
let inference_time = start.elapsed();
println!(" Inference completed in {:.2?}", inference_time);
// Display output shapes
println!("\nModel outputs:");
println!(" Dets shape: {:?}", output_dets.shape().dims);
println!(" Labels shape: {:?}", output_labels.shape().dims);
// Compare outputs
println!("\nComparing model outputs with reference data...");
let mut dets_passed = false;
let mut labels_passed = false;
// Check if dets are close
println!("\n Checking dets (bounding boxes):");
if output_dets
.clone()
.all_close(reference_dets.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ dets matches reference data within tolerance (1e-4)!");
dets_passed = true;
} else {
println!(" ⚠ dets differs from reference data!");
// Calculate and display the difference statistics
let diff = output_dets.clone() - reference_dets.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = output_dets.clone().flatten::<1>(0, 2);
let reference_flat = reference_dets.clone().flatten::<1>(0, 2);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
// Check if labels are close
println!("\n Checking labels (class scores):");
if output_labels
.clone()
.all_close(reference_labels.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ labels matches reference data within tolerance (1e-4)!");
labels_passed = true;
} else {
println!(" ⚠ labels differs from reference data!");
// Calculate and display the difference statistics
let diff = output_labels.clone() - reference_labels.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = output_labels.clone().flatten::<1>(0, 2);
let reference_flat = reference_labels.clone().flatten::<1>(0, 2);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
println!("\n========================================");
println!("Summary:");
println!(" - Model initialization: {:.2?}", init_time);
println!(" - Data loading: {:.2?}", load_time);
println!(" - Inference time: {:.2?}", inference_time);
if dets_passed && labels_passed {
println!(" - Output validation: ✓ All outputs match!");
} else {
println!(
" - Output validation: {} dets, {} labels",
if dets_passed { "✓" } else { "✗" },
if labels_passed { "✓" } else { "✗" }
);
}
println!("========================================");
if dets_passed && labels_passed {
println!("Model test completed successfully!");
} else {
println!("Model test completed with differences.");
}
println!("========================================");
}


@@ -1,28 +0,0 @@
[package]
name = "burn-onnx-model-checks-silero-vad"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["ndarray"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,102 +0,0 @@
# Silero VAD Model Check
This model check verifies that burn-onnx can correctly handle the Silero VAD (Voice Activity
Detection) model and produces outputs matching ONNX Runtime.
## Current Status
**Working**: The model is successfully imported and produces outputs matching ONNX Runtime.
## Model Information
- **Model**: Silero VAD (opset 18, if-less version)
- **Source**: https://github.com/snakers4/silero-vad
- **Purpose**: Voice Activity Detection
- **Key Features**: Uses Conv, Gemm, and a single If node for sample rate selection
The if-less version has only one If node (for 16kHz vs 8kHz sample rate selection), making it
much simpler than the full model, which has 25 If nodes.
See: https://github.com/snakers4/silero-vad/issues/728 for compatibility discussion.
## Setup
### Step 1: Download the Model and Generate Reference Outputs
```bash
# Using Python
python get_model.py
# Or using uv
uv run get_model.py
```
This downloads or generates the following files:
- `artifacts/silero_vad.onnx` - The ONNX model file
- `artifacts/node_info.json` - Detailed analysis of all nodes, operators, and configurations
- `artifacts/test.wav` - Test audio file from silero-vad repository
- `artifacts/reference_outputs.json` - Reference outputs from ONNX Runtime for validation
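As a quick sanity check, `node_info.json` can be inspected to confirm which operators the model uses, for example how many `If` nodes it contains. The sketch below relies only on the `node_type_counts` field that `get_model.py` writes; the helper names are hypothetical, and the counts in the stand-in data are made up for illustration rather than taken from the real model.

```python
import json


def summarize_ops(node_info, top=5):
    """Return the most frequent operator types from a node_info-style dict."""
    counts = node_info.get("node_type_counts", {})
    return sorted(counts.items(), key=lambda item: item[1], reverse=True)[:top]


def count_op(node_info, op_type):
    """Return how many nodes of a given operator type the model contains."""
    return node_info.get("node_type_counts", {}).get(op_type, 0)


# In practice the data would come from artifacts/node_info.json, e.g.:
#   with open("artifacts/node_info.json") as f:
#       node_info = json.load(f)
# A tiny hand-written stand-in with made-up counts is used here instead.
node_info = json.loads(
    '{"opset_version": 18, "total_nodes": 6,'
    ' "node_type_counts": {"Conv": 3, "Gemm": 2, "If": 1}}'
)
print(summarize_ops(node_info))
print(count_op(node_info, "If"))
```

Because the script walks subgraphs recursively, the counts include nodes inside the branches of an `If` node as well as those in the main graph.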
### Step 2: Build and Run Tests
```bash
cargo build
cargo run
```
The test suite runs 12 test cases:
- 10 audio chunks from the test.wav file
- 1 random input test (reproducible with seed 42)
- 1 silence test (all zeros)
Each test compares the Burn model output against ONNX Runtime reference outputs with a 1% tolerance.
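One way such a per-case check could look is sketched below. How the test runner actually applies the 1% tolerance is not shown in this README, so the rule here (relative error with a small absolute floor for expected values near zero) is an assumption, as are the helper names and the sample values other than `chunk_0`.

```python
def within_tolerance(output, expected, rel_tol=0.01, abs_floor=1e-6):
    """True if output is within rel_tol of expected.

    abs_floor keeps the check meaningful when expected is (almost) zero.
    """
    return abs(output - expected) <= max(rel_tol * abs(expected), abs_floor)


def run_cases(cases):
    """cases: list of (name, output, expected). Returns (passed, failed)."""
    passed = 0
    for name, output, expected in cases:
        ok = within_tolerance(output, expected)
        if ok:
            passed += 1
        status = "PASS" if ok else "FAIL"
        print(f"[{status}] {name}: output={output:.6f} (expected={expected:.6f})")
    return passed, len(cases) - passed


# Illustrative values only; the last case is deliberately out of tolerance.
cases = [
    ("chunk_0", 0.000589, 0.000589),
    ("random", 0.504000, 0.500000),
    ("silence", 0.000700, 0.000592),
]
passed, failed = run_cases(cases)
print(f"Passed: {passed}, Failed: {failed}")  # Passed: 2, Failed: 1
```

A purely relative tolerance is strict for very small probabilities like the ones a VAD produces on silence, which is why the sketch makes the floor an explicit parameter.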
## Backend Support
This model check supports multiple backends:
```bash
# NdArray backend (default, CPU)
cargo run
# LibTorch backend (CPU/CUDA)
cargo run --features tch --no-default-features
# WGPU backend (GPU via WebGPU)
cargo run --features wgpu --no-default-features
# Metal backend (Apple Silicon GPU)
cargo run --features metal --no-default-features
```
## Test Output Example
```
========================================
Silero VAD Model Test Suite
========================================
Loading reference outputs...
Loaded 12 test cases (sample rate: 16000 Hz)
Initializing Silero VAD model...
Model initialized
Running test cases...
------------------------------------------------------------
[PASS] chunk_0: output=0.000589 (expected=0.000589)
[PASS] chunk_1: output=0.000589 (expected=0.000528)
...
[PASS] silence: output=0.000592 (expected=0.000592)
------------------------------------------------------------
========================================
Test Summary
========================================
Total tests: 12
Passed: 12
Failed: 0
All tests passed!
The Burn model produces outputs matching ONNX Runtime.
```


@@ -1,30 +0,0 @@
use burn_onnx::ModelGen;
use std::path::Path;
fn main() {
let onnx_path = "artifacts/silero_vad.onnx";
// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed=build.rs");
// Check if the ONNX model file exists
if !Path::new(onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!("Please run the following command to download the model:");
eprintln!(" python get_model.py");
eprintln!();
eprintln!("Or if you prefer using uv:");
eprintln!(" uv run get_model.py");
eprintln!();
eprintln!("This will download the Silero VAD model.");
std::process::exit(1);
}
// Generate the model code from the ONNX file
ModelGen::new()
.input(onnx_path)
.out_dir("model/")
.run_from_script();
}


@@ -1,399 +0,0 @@
#!/usr/bin/env python3
"""
Download and prepare the Silero VAD model for testing.
This script downloads the Silero VAD ONNX model (opset 18, if-less version) and prepares
it for use with burn-onnx. This version has only 1 If node (for sample rate selection)
making it compatible with static type inference.
See: https://github.com/snakers4/silero-vad/issues/728 for compatibility discussion.
"""
import json
import struct
import urllib.request
import wave
from pathlib import Path
from collections import defaultdict
import numpy as np
try:
    import onnx
except ImportError:
    print("Error: onnx package not found. Please install it with:")
    print(" pip install onnx")
    exit(1)
try:
    import onnxruntime as ort
except ImportError:
    print("Error: onnxruntime package not found. Please install it with:")
    print(" pip install onnxruntime")
    exit(1)
def extract_node_info(model_path, artifacts_dir):
"""Extract node types and configurations from the ONNX model."""
print("Extracting node information from ONNX model...")
# Load the ONNX model (without external data since we only need structure)
model = onnx.load(str(model_path), load_external_data=False)
# Check for external data
external_files = set()
for init in model.graph.initializer:
if init.data_location == onnx.TensorProto.EXTERNAL:
for ext_data in init.external_data:
if ext_data.key == 'location':
external_files.add(ext_data.value)
if external_files:
print(f"⚠️ Model requires external data files: {external_files}")
print(" These files are missing from the repository!")
# Collect node information
node_types = defaultdict(int)
node_details = []
def process_graph(graph, graph_name="main"):
"""Recursively process a graph and its subgraphs."""
for idx, node in enumerate(graph.node):
node_types[node.op_type] += 1
# Extract node details
node_info = {
"graph": graph_name,
"index": idx,
"op_type": node.op_type,
"name": node.name if node.name else f"{node.op_type}_{idx}",
"inputs": list(node.input),
"outputs": list(node.output),
"attributes": {}
}
# Extract attributes
for attr in node.attribute:
attr_name = attr.name
# Get attribute value based on type
if attr.HasField('f'):
node_info["attributes"][attr_name] = float(attr.f)
elif attr.HasField('i'):
node_info["attributes"][attr_name] = int(attr.i)
elif attr.HasField('s'):
node_info["attributes"][attr_name] = attr.s.decode('utf-8') if attr.s else ""
elif attr.HasField('t'):
node_info["attributes"][attr_name] = "<tensor>"
elif attr.floats:
node_info["attributes"][attr_name] = list(attr.floats)
elif attr.ints:
node_info["attributes"][attr_name] = list(attr.ints)
elif attr.strings:
node_info["attributes"][attr_name] = [s.decode('utf-8') for s in attr.strings]
elif attr.HasField('g'):
# Subgraph - recursively process it
subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}"
node_info["attributes"][attr_name] = f"<subgraph: {subgraph_name}>"
process_graph(attr.g, subgraph_name)
elif attr.graphs:
subgraph_names = []
for g_idx, subgraph in enumerate(attr.graphs):
subgraph_name = f"{graph_name}.{node.op_type}_{idx}.{attr_name}_{g_idx}"
subgraph_names.append(subgraph_name)
process_graph(subgraph, subgraph_name)
node_info["attributes"][attr_name] = f"<subgraphs: {', '.join(subgraph_names)}>"
else:
node_info["attributes"][attr_name] = "<unknown>"
node_details.append(node_info)
# Process the main graph
process_graph(model.graph, "main")
# Create summary
summary = {
"model_name": model.graph.name,
"opset_version": model.opset_import[0].version if model.opset_import else "unknown",
"total_nodes": len(node_details),
"node_type_counts": dict(sorted(node_types.items())),
"external_data_files": list(external_files),
"nodes": node_details
}
# Save to JSON file
output_path = artifacts_dir / "node_info.json"
with open(output_path, 'w') as f:
json.dump(summary, f, indent=2)
print(f"✓ Node information extracted to {output_path}")
print(f" Opset version: {summary['opset_version']}")
print(f" Total nodes: {summary['total_nodes']}")
print(f" Unique node types: {len(node_types)}")
print(f" Node type distribution:")
for op_type, count in sorted(node_types.items(), key=lambda x: x[1], reverse=True):
print(f" - {op_type}: {count}")
return summary
def download_test_data():
"""Download test audio file from silero-vad repository."""
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
test_wav_path = artifacts_dir / "test.wav"
if test_wav_path.exists():
print(f"✓ Test audio already exists at {test_wav_path}")
return test_wav_path
# Download the test.wav file from silero-vad tests
test_url = "https://github.com/snakers4/silero-vad/raw/refs/heads/master/tests/data/test.wav"
print(f"Downloading test audio from:")
print(f" {test_url}")
print(f"Saving to: {test_wav_path}")
try:
urllib.request.urlretrieve(test_url, test_wav_path)
file_size = test_wav_path.stat().st_size / 1024
print(f"✓ Download complete! File size: {file_size:.1f} KB")
except Exception as e:
print(f"✗ Error downloading test audio: {e}")
raise
return test_wav_path
def load_wav(wav_path):
"""Load a WAV file and return audio samples as float32 array normalized to [-1, 1]."""
with wave.open(str(wav_path), 'rb') as wav_file:
n_channels = wav_file.getnchannels()
sample_width = wav_file.getsampwidth()
frame_rate = wav_file.getframerate()
n_frames = wav_file.getnframes()
# Read raw bytes
raw_data = wav_file.readframes(n_frames)
# Convert to numpy array based on sample width
if sample_width == 2: # 16-bit
samples = np.frombuffer(raw_data, dtype=np.int16)
# Normalize to [-1, 1]
samples = samples.astype(np.float32) / 32768.0
elif sample_width == 4: # 32-bit
samples = np.frombuffer(raw_data, dtype=np.int32)
samples = samples.astype(np.float32) / 2147483648.0
else:
raise ValueError(f"Unsupported sample width: {sample_width}")
# If multi-channel, convert to mono by averaging the interleaved channels
if n_channels > 1:
samples = samples.reshape(-1, n_channels).mean(axis=1)
return samples, frame_rate
def generate_reference_outputs(model_path, test_wav_path, artifacts_dir):
"""Generate reference outputs using ONNX Runtime for testing."""
print(f" Loading test audio from {test_wav_path}...")
audio_samples, sample_rate = load_wav(test_wav_path)
print(f" Sample rate: {sample_rate} Hz")
print(f" Audio length: {len(audio_samples)} samples ({len(audio_samples)/sample_rate:.2f} seconds)")
# Create ONNX Runtime session
print(f" Creating ONNX Runtime session...")
session = ort.InferenceSession(str(model_path), providers=['CPUExecutionProvider'])
# Silero VAD parameters
# For 16kHz: chunk_size = 512 (32ms)
# For 8kHz: chunk_size = 256 (32ms)
chunk_size = 512 if sample_rate == 16000 else 256
batch_size = 1
# Initialize state
state = np.zeros((2, batch_size, 128), dtype=np.float32)
# Process audio in chunks and collect outputs
results = {
"sample_rate": sample_rate,
"chunk_size": chunk_size,
"audio_length_samples": len(audio_samples),
"test_cases": []
}
# Test Case 1: First few chunks from the beginning of the audio
print(f" Running inference on test chunks...")
num_test_chunks = min(10, len(audio_samples) // chunk_size)
state = np.zeros((2, batch_size, 128), dtype=np.float32)
for i in range(num_test_chunks):
start_idx = i * chunk_size
end_idx = start_idx + chunk_size
chunk = audio_samples[start_idx:end_idx]
if len(chunk) < chunk_size:
# Pad with zeros if needed
chunk = np.pad(chunk, (0, chunk_size - len(chunk)), mode='constant')
# Prepare inputs
input_tensor = chunk.reshape(1, -1).astype(np.float32)
sr_tensor = np.array(sample_rate, dtype=np.int64)
# Run inference
outputs = session.run(None, {
'input': input_tensor,
'sr': sr_tensor,
'state': state
})
output_prob = float(outputs[0].flatten()[0])
state = outputs[1]
results["test_cases"].append({
"test_name": f"chunk_{i}",
"chunk_index": i,
"start_sample": start_idx,
"input_samples": chunk.tolist(),
"expected_output": output_prob,
"state_after": state.tolist()
})
# Test Case 2: Random input (for reproducibility test)
print(f" Generating random input test case...")
np.random.seed(42)
random_input = np.random.randn(chunk_size).astype(np.float32) * 0.1
state = np.zeros((2, batch_size, 128), dtype=np.float32)
outputs = session.run(None, {
'input': random_input.reshape(1, -1),
'sr': np.array(sample_rate, dtype=np.int64),
'state': state
})
results["test_cases"].append({
"test_name": "random_seed_42",
"chunk_index": -1,
"start_sample": -1,
"input_samples": random_input.tolist(),
"expected_output": float(outputs[0].flatten()[0]),
"state_after": outputs[1].tolist()
})
# Test Case 3: Zero input (silence)
print(f" Generating silence test case...")
zero_input = np.zeros(chunk_size, dtype=np.float32)
state = np.zeros((2, batch_size, 128), dtype=np.float32)
outputs = session.run(None, {
'input': zero_input.reshape(1, -1),
'sr': np.array(sample_rate, dtype=np.int64),
'state': state
})
results["test_cases"].append({
"test_name": "silence",
"chunk_index": -1,
"start_sample": -1,
"input_samples": zero_input.tolist(),
"expected_output": float(outputs[0].flatten()[0]),
"state_after": outputs[1].tolist()
})
# Save results
output_path = artifacts_dir / "reference_outputs.json"
with open(output_path, 'w') as f:
json.dump(results, f, indent=2)
print(f" ✓ Reference outputs saved to {output_path}")
print(f" Total test cases: {len(results['test_cases'])}")
def download_model():
"""Download the Silero VAD ONNX model (opset 18, if-less version)."""
# Create artifacts directory if it doesn't exist
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
model_path = artifacts_dir / "silero_vad.onnx"
model_existed = model_path.exists()
# Skip download if model already exists
if model_existed:
print(f"✓ Model already exists at {model_path}")
print(f" File size: {model_path.stat().st_size / 1024:.1f} KB")
print()
else:
# Download the opset 18 if-less model
# Note: This model has external data that is missing from the repo
model_url = "https://github.com/snakers4/silero-vad/raw/refs/heads/master/src/silero_vad/data/silero_vad_op18_ifless.onnx"
print(f"Downloading Silero VAD model (opset 18, if-less) from:")
print(f" {model_url}")
print(f"Saving to: {model_path}")
print()
try:
urllib.request.urlretrieve(model_url, model_path)
file_size = model_path.stat().st_size / 1024
print(f"✓ Download complete! File size: {file_size:.1f} KB")
print()
except Exception as e:
print(f"✗ Error downloading model: {e}")
raise
# Extract node information from the ONNX model
try:
extract_node_info(model_path, artifacts_dir)
except Exception as e:
print(f"✗ Error extracting node information: {e}")
raise
print()
# Download test data
print("Downloading test data...")
try:
test_wav_path = download_test_data()
except Exception as e:
print(f"✗ Error downloading test data: {e}")
raise
print()
# Generate reference outputs
print("Generating reference outputs...")
try:
generate_reference_outputs(model_path, test_wav_path, artifacts_dir)
except Exception as e:
print(f"✗ Error generating reference outputs: {e}")
raise
print()
print("="*80)
print("Model preparation complete!")
print("="*80)
print()
print("Silero VAD (opset 18, if-less) has only 1 If node for sample rate selection.")
print("This makes it compatible with burn-onnx's static type inference.")
print()
print("See: https://github.com/snakers4/silero-vad/issues/728")
print()
print("Generated files:")
print(f" - {model_path} (ONNX model)")
print(f" - {artifacts_dir / 'node_info.json'} (node analysis)")
print(f" - {test_wav_path} (test audio)")
print(f" - {artifacts_dir / 'reference_outputs.json'} (reference test outputs)")
print()
print("Next steps:")
print(" 1. Build the model: cargo build")
print(" 2. Run the test: cargo run")
print()
if __name__ == "__main__":
download_model()
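The chunking step in `generate_reference_outputs` can be sketched on its own, without ONNX Runtime or NumPy. This is a minimal illustration rather than the script's actual code: `chunk_size_for` and `split_into_chunks` are hypothetical helper names, and unlike the script (which only takes whole chunks) this variant also emits a zero-padded final chunk.

```python
def chunk_size_for(sample_rate):
    """Return the 32 ms window used above: 512 samples at 16 kHz, otherwise 256."""
    return 512 if sample_rate == 16000 else 256


def split_into_chunks(samples, chunk_size, max_chunks=None):
    """Split samples into fixed-size chunks, zero-padding the last partial chunk.

    Hypothetical helper: the original script stops at the last full chunk.
    """
    chunks = []
    for start in range(0, len(samples), chunk_size):
        if max_chunks is not None and len(chunks) >= max_chunks:
            break
        chunk = list(samples[start:start + chunk_size])
        # Pad with silence so every chunk has exactly chunk_size samples.
        chunk.extend([0.0] * (chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks


if __name__ == "__main__":
    size = chunk_size_for(16000)
    parts = split_into_chunks([1.0] * 1100, size)
    print(len(parts), [len(p) for p in parts])
```

Each chunk would then be reshaped to `[1, chunk_size]` and passed to the session together with `sr` and the current `state`, as the script does.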


@@ -1,164 +0,0 @@
extern crate alloc;
use burn::prelude::*;
use serde::Deserialize;
use std::fs;
use std::path::Path;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Include the generated model
include!(concat!(env!("OUT_DIR"), "/model/silero_vad.rs"));
/// Test case from reference outputs
#[derive(Debug, Deserialize)]
struct TestCase {
test_name: String,
#[allow(dead_code)]
chunk_index: i32,
#[allow(dead_code)]
start_sample: i32,
input_samples: Vec<f32>,
expected_output: f32,
#[allow(dead_code)]
state_after: Vec<Vec<Vec<f32>>>,
}
/// Reference outputs structure
#[derive(Debug, Deserialize)]
struct ReferenceOutputs {
sample_rate: i64,
#[allow(dead_code)]
chunk_size: usize,
#[allow(dead_code)]
audio_length_samples: usize,
test_cases: Vec<TestCase>,
}
/// Run a single test case and return (passed, actual_output, expected_output)
fn run_test_case(
model: &Model<MyBackend>,
device: &<MyBackend as Backend>::Device,
test_case: &TestCase,
sample_rate: i64,
) -> (bool, f32, f32) {
// Create input tensor from test case samples
let input_data: Vec<f32> = test_case.input_samples.clone();
let input = Tensor::<MyBackend, 1>::from_floats(input_data.as_slice(), device)
.reshape([1, test_case.input_samples.len()]);
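// Initialize state to zeros. Note: get_model.py carries the state across the
// sequential chunk_N cases, so only chunk_0, random_seed_42 and silence were
// generated from a zero state.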
let state = Tensor::<MyBackend, 3>::zeros([2, 1, 128], device);
// Run inference
let (output, _state_out) = model.forward(input, sample_rate, state);
// Get the output probability
let actual_output: f32 = output.into_scalar();
let expected_output = test_case.expected_output;
// Compare with tolerance (neural networks have small floating point differences)
let tolerance = 0.01; // absolute tolerance on the output probability
let passed = (actual_output - expected_output).abs() < tolerance;
(passed, actual_output, expected_output)
}
fn main() {
println!("========================================");
println!("Silero VAD Model Test Suite");
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model.");
eprintln!("Example: uv run get_model.py");
std::process::exit(1);
}
// Check if reference outputs exist
let reference_path = artifacts_dir.join("reference_outputs.json");
if !reference_path.exists() {
eprintln!("Error: reference_outputs.json not found!");
eprintln!("Please run: uv run get_model.py");
std::process::exit(1);
}
// Load reference outputs
println!("Loading reference outputs...");
let reference_json = fs::read_to_string(&reference_path).expect("Failed to read reference outputs");
let reference: ReferenceOutputs =
serde_json::from_str(&reference_json).expect("Failed to parse reference outputs");
println!(
" Loaded {} test cases (sample rate: {} Hz)\n",
reference.test_cases.len(),
reference.sample_rate
);
// Initialize the model
println!("Initializing Silero VAD model...");
let device = Default::default();
let model: Model<MyBackend> = Model::default();
println!(" Model initialized\n");
// Run tests
println!("Running test cases...");
println!("{:-<60}", "");
let mut passed_count = 0;
let mut failed_count = 0;
for test_case in &reference.test_cases {
let (passed, actual, expected) = run_test_case(&model, &device, test_case, reference.sample_rate);
if passed {
println!(
" [PASS] {}: output={:.6} (expected={:.6})",
test_case.test_name, actual, expected
);
passed_count += 1;
} else {
println!(
" [FAIL] {}: output={:.6} (expected={:.6}, diff={:.6})",
test_case.test_name,
actual,
expected,
(actual - expected).abs()
);
failed_count += 1;
}
}
println!("{:-<60}", "");
println!();
// Summary
println!("========================================");
println!("Test Summary");
println!("========================================");
println!(" Total tests: {}", passed_count + failed_count);
println!(" Passed: {}", passed_count);
println!(" Failed: {}", failed_count);
println!();
if failed_count == 0 {
println!("All tests passed!");
println!("The Burn model produces outputs matching ONNX Runtime.");
} else {
println!("Some tests failed!");
println!("The Burn model outputs differ from ONNX Runtime.");
std::process::exit(1);
}
}
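One detail worth keeping in mind when reading the results: `get_model.py` generates the `chunk_N` references while carrying the recurrent state from one chunk to the next, whereas `run_test_case` starts every case from a zero state. The sketch below uses a toy stateful function (`toy_step`, a hypothetical stand-in, not the Silero VAD model) to show why the two procedures agree on the first chunk only.

```python
def toy_step(chunk, state):
    """Hypothetical stand-in for one stateful inference call.

    Returns (output, new_state); the output depends on the previous state.
    """
    new_state = 0.5 * state + sum(chunk) / len(chunk)
    return new_state, new_state


def run_sequential(chunks):
    """Carry the state across chunks, as the reference generator does."""
    state = 0.0
    outputs = []
    for chunk in chunks:
        out, state = toy_step(chunk, state)
        outputs.append(out)
    return outputs


def run_with_reset(chunks):
    """Start every chunk from a zero state, as the Rust test runner does."""
    return [toy_step(chunk, 0.0)[0] for chunk in chunks]


if __name__ == "__main__":
    chunks = [[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
    print("sequential:", run_sequential(chunks))
    print("reset:     ", run_with_reset(chunks))
```

How much the real model's output changes with the carried state is not something this toy can say; it only illustrates that a sequential reference is best replayed sequentially.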


@@ -1,166 +0,0 @@
#!/usr/bin/env python3
"""
Validate the Silero VAD ONNX model independently.
This script:
1. Checks if the model is valid using onnx.checker
2. Runs inference using ONNX Runtime to verify it works
"""
import numpy as np
from pathlib import Path
try:
import onnx
from onnx import checker
except ImportError:
print("Error: onnx package not found. Please install it with:")
print(" pip install onnx")
exit(1)
try:
import onnxruntime as ort
except ImportError:
print("Error: onnxruntime package not found. Please install it with:")
print(" pip install onnxruntime")
exit(1)
def validate_model():
"""Validate the Silero VAD ONNX model."""
model_path = Path("artifacts/silero_vad.onnx")
if not model_path.exists():
print(f"Error: Model not found at {model_path}")
print("Please run 'python get_model.py' first to download the model.")
return False
print("=" * 80)
print("Silero VAD ONNX Model Validation")
print("=" * 80)
print()
# Step 1: Load and check the model structure
print("Step 1: Loading ONNX model...")
try:
model = onnx.load(str(model_path))
print(f" ✓ Model loaded successfully")
print(f" - Opset version: {model.opset_import[0].version}")
print(f" - Graph name: {model.graph.name}")
print(f" - Number of nodes: {len(model.graph.node)}")
print(f" - Number of initializers: {len(model.graph.initializer)}")
except Exception as e:
print(f" ✗ Failed to load model: {e}")
return False
print()
# Step 2: Validate with ONNX checker
print("Step 2: Validating model with onnx.checker...")
try:
checker.check_model(model, full_check=True)
print(" ✓ Model passed ONNX validation (full check)")
except Exception as e:
print(f" ✗ Model validation failed: {e}")
return False
print()
# Step 3: Print model inputs/outputs
print("Step 3: Model inputs and outputs...")
print(" Inputs:")
for inp in model.graph.input:
shape = []
for dim in inp.type.tensor_type.shape.dim:
if dim.dim_param:
shape.append(dim.dim_param)
else:
shape.append(dim.dim_value)
dtype = onnx.TensorProto.DataType.Name(inp.type.tensor_type.elem_type)
print(f" - {inp.name}: {dtype} {shape}")
print(" Outputs:")
for out in model.graph.output:
shape = []
for dim in out.type.tensor_type.shape.dim:
if dim.dim_param:
shape.append(dim.dim_param)
else:
shape.append(dim.dim_value)
dtype = onnx.TensorProto.DataType.Name(out.type.tensor_type.elem_type)
print(f" - {out.name}: {dtype} {shape}")
print()
# Step 4: Run inference with ONNX Runtime
print("Step 4: Running inference with ONNX Runtime...")
try:
# Create inference session
session = ort.InferenceSession(str(model_path), providers=['CPUExecutionProvider'])
print(" ✓ ONNX Runtime session created successfully")
# Get input details
input_details = session.get_inputs()
output_details = session.get_outputs()
print(f" - Session inputs: {[i.name for i in input_details]}")
print(f" - Session outputs: {[o.name for o in output_details]}")
# Prepare sample inputs based on model signature
# Silero VAD expects:
# - input: audio chunk [batch, samples] - typically 512 samples for 16kHz
# - sr: sample rate (int64)
# - state: combined LSTM state [2, batch, 128]
batch_size = 1
# Silero VAD chunk sizes:
# - 16kHz: 512 samples (32ms)
# - 8kHz: 256 samples (32ms)
# The model also supports larger chunks: 768, 1024, 1536 for 16kHz
chunk_size = 512 # 512 samples for 16kHz (32ms)
# Prepare inputs based on Silero VAD documentation
# https://github.com/snakers4/silero-vad
inputs = {
'input': np.random.randn(batch_size, chunk_size).astype(np.float32),
'sr': np.array(16000, dtype=np.int64), # 16kHz sample rate
'state': np.zeros((2, batch_size, 128), dtype=np.float32), # LSTM hidden state
}
for name, value in inputs.items():
print(f" - Input '{name}': shape={value.shape}, dtype={value.dtype}")
# Run inference
print()
print(" Running inference...")
outputs = session.run(None, inputs)
print(" ✓ Inference completed successfully!")
print()
print(" Output values:")
for i, (out_detail, out_value) in enumerate(zip(output_details, outputs)):
print(f" - {out_detail.name}: shape={out_value.shape}, dtype={out_value.dtype}")
if out_value.size <= 10:
print(f" values: {out_value}")
else:
print(f" sample: {out_value.flat[:5]}...")
except Exception as e:
print(f" ✗ ONNX Runtime inference failed: {e}")
import traceback
traceback.print_exc()
return False
print()
print("=" * 80)
print("✓ All validation checks passed!")
print("=" * 80)
print()
print("The ONNX model is valid and can run inference successfully.")
print("The issue with burn-onnx is in the code generation, not the model itself.")
return True
if __name__ == "__main__":
success = validate_model()
exit(0 if success else 1)


@@ -1,26 +0,0 @@
[package]
name = "burn-onnx-model-checks-yolo"
version = "0.1.0"
edition = "2024"
publish = false
[workspace]
[features]
default = ["tch"]
ndarray = []
tch = []
wgpu = []
metal = []
[dependencies]
burn = { path = "../../../../crates/burn", features = [
"ndarray",
"tch",
"wgpu",
"metal",
] }
burn-store = { path = "../../../../crates/burn-store", features = ["burnpack", "pytorch"] }
[build-dependencies]
burn-onnx = { path = "../../../burn-onnx" }


@@ -1,56 +0,0 @@
# YOLO Model Checks
This crate provides a unified interface for testing multiple YOLO model variants with Burn.
## Supported Models
- `yolov5s` - YOLOv5 small variant
- `yolov8n` - YOLOv8 nano variant
- `yolov8s` - YOLOv8 small variant
- `yolov10n` - YOLOv10 nano variant (Note: Currently fails due to TopK operator issue)
- `yolo11x` - YOLO11 extra-large variant
- `yolo12x` - YOLO12 extra-large variant
## Usage
### 1. Download and prepare a model
```bash
# Using Python directly
python get_model.py --model yolov8n
# Or using uv
uv run get_model.py --model yolov8n
# List available models
uv run get_model.py --list
```
### 2. Run the model test
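Build and run the test in one step. The model named by `YOLO_MODEL` is selected at build time and compiled into the binary, so it must match a model you have already downloaded:
```bash
YOLO_MODEL=yolov8n cargo run --release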
```
## Directory Structure
```
yolo/
├── artifacts/ # Downloaded ONNX models and test data
│ ├── yolov8n_opset16.onnx
│ ├── yolov8n_test_data.pt
│ └── ...
├── src/
│ └── main.rs # Test runner
├── build.rs # Build script that generates model code
├── get_model.py # Model download and preparation script
└── Cargo.toml
```
## Notes
- All YOLO models (except v10) output shape `[1, 84, 8400]` for standard object detection
- YOLOv10n has a different architecture with output shape `[1, 300, 6]` and uses TopK operator
- The crate requires explicit model selection at build time (no default model)
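
The `[1, 84, 8400]` shape can be derived by hand. As a rough sketch, assuming the usual 640x640 input, detection strides of 8, 16 and 32, and the 80 COCO classes (none of these values is stated in this README, and `yolo_output_shape` is a hypothetical helper), the arithmetic looks like this:

```python
# Sketch only: input size, strides and class count are assumed defaults,
# not values read from the exported ONNX models.
INPUT_SIZE = 640
STRIDES = (8, 16, 32)
NUM_CLASSES = 80
BOX_COORDS = 4


def yolo_output_shape(input_size=INPUT_SIZE, strides=STRIDES,
                      num_classes=NUM_CLASSES, batch=1):
    """Return [batch, channels, predictions] for an anchor-free detection head."""
    # One prediction per grid cell at each stride: 80*80 + 40*40 + 20*20.
    predictions = sum((input_size // s) ** 2 for s in strides)
    # Four box coordinates followed by one score per class.
    channels = BOX_COORDS + num_classes
    return [batch, channels, predictions]


if __name__ == "__main__":
    print(yolo_output_shape())
```

YOLOv10n does not follow this layout, which is consistent with the separate `[1, 300, 6]` note above.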


@@ -1,87 +0,0 @@
use burn_onnx::ModelGen;
use std::env;
use std::fs;
use std::path::Path;
fn main() {
// Supported models
let supported_models = vec![
"yolov5s", "yolov8n", "yolov8s", "yolov10n", "yolo11x", "yolo12x",
];
// Get the model name from environment variable (required)
let model_name = env::var("YOLO_MODEL").unwrap_or_else(|_| {
eprintln!("Error: YOLO_MODEL environment variable is not set.");
eprintln!();
eprintln!("Please specify which YOLO model to build:");
eprintln!(" YOLO_MODEL=yolov8n cargo build");
eprintln!();
eprintln!("Available models: {}", supported_models.join(", "));
std::process::exit(1);
});
if !supported_models.contains(&model_name.as_str()) {
eprintln!(
"Error: Unsupported model '{}'. Supported models: {:?}",
model_name, supported_models
);
std::process::exit(1);
}
let onnx_path = format!("artifacts/{}_opset16.onnx", model_name);
let test_data_path = format!("artifacts/{}_test_data.pt", model_name);
// Tell Cargo to only rebuild if these files change
println!("cargo:rerun-if-changed={}", onnx_path);
println!("cargo:rerun-if-changed={}", test_data_path);
println!("cargo:rerun-if-changed=build.rs");
println!("cargo:rerun-if-env-changed=YOLO_MODEL");
// Check if the ONNX model file exists
if !Path::new(&onnx_path).exists() {
eprintln!("Error: ONNX model file not found at '{}'", onnx_path);
eprintln!();
eprintln!(
"Please run the following command to download and prepare the {} model:",
model_name
);
eprintln!(" python get_model.py --model {}", model_name);
eprintln!();
eprintln!("Or if you prefer using uv:");
eprintln!(" uv run get_model.py --model {}", model_name);
eprintln!();
eprintln!("Available models: {}", supported_models.join(", "));
std::process::exit(1);
}
// Generate the model code from the ONNX file
ModelGen::new()
.input(&onnx_path)
.out_dir("model/")
.run_from_script();
// Write the model name to a file so main.rs can access it
let out_dir = env::var("OUT_DIR").unwrap();
let model_info_path = Path::new(&out_dir).join("model_info.rs");
// Generate the include path for the model
let model_include = format!(
"include!(concat!(env!(\"OUT_DIR\"), \"/model/{}_opset16.rs\"));",
model_name
);
fs::write(
model_info_path,
format!(
r#"pub const MODEL_NAME: &str = "{}";
pub const TEST_DATA_FILE: &str = "{}_test_data.pt";
// Include the generated model
pub mod yolo_model {{
{}
}}"#,
model_name, model_name, model_include
),
)
.expect("Failed to write model info");
}
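To make the generated `model_info.rs` easier to picture, the sketch below renders the same template for a given model name. It is written in Python purely for illustration: `render_model_info` is a hypothetical helper, and the exact whitespace of the real output may differ from what is shown here.

```python
def render_model_info(model_name):
    """Mirror the format! template in build.rs for one model name (illustrative)."""
    include_line = (
        'include!(concat!(env!("OUT_DIR"), "/model/%s_opset16.rs"));' % model_name
    )
    # Adjacent string literals are joined before the % operator is applied.
    return (
        'pub const MODEL_NAME: &str = "%s";\n'
        'pub const TEST_DATA_FILE: &str = "%s_test_data.pt";\n'
        '// Include the generated model\n'
        'pub mod yolo_model {\n'
        '%s\n'
        '}' % (model_name, model_name, include_line)
    )


if __name__ == "__main__":
    print(render_model_info("yolov8n"))
```

Because the model name is baked into constants at build time, switching models means rebuilding with a different `YOLO_MODEL`, which the `rerun-if-env-changed` directive above is there to trigger.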


@@ -1,207 +0,0 @@
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "onnx>=1.17.0",
# "onnxruntime>=1.18.0",
# "ultralytics>=8.3.0",
# "numpy",
# "pillow",
# "torch",
# ]
# ///
import os
import sys
import onnx
from onnx import shape_inference, version_converter
import numpy as np
from pathlib import Path
import argparse
# Supported YOLO models configuration
SUPPORTED_MODELS = {
'yolov5s': {'download_name': 'yolov5s.pt', 'display_name': 'YOLOv5s'},
'yolov8n': {'download_name': 'yolov8n.pt', 'display_name': 'YOLOv8n'},
'yolov8s': {'download_name': 'yolov8s.pt', 'display_name': 'YOLOv8s'},
'yolov10n': {'download_name': 'yolov10n.pt', 'display_name': 'YOLOv10n'},
'yolo11x': {'download_name': 'yolo11x.pt', 'display_name': 'YOLO11x'},
'yolo12x': {'download_name': 'yolo12x.pt', 'display_name': 'YOLO12x'},
}
def get_input_shape(model):
"""Extract input shape from ONNX model."""
input_info = model.graph.input[0]
shape = []
for dim in input_info.type.tensor_type.shape.dim:
if dim.HasField('dim_value'):
shape.append(dim.dim_value)
else:
shape.append(1) # Default to 1 for dynamic dimensions
# Ensure valid YOLO input shape
if len(shape) != 4 or shape[2] == 0 or shape[2] > 2000:
return [1, 3, 640, 640]
return shape
def download_and_convert_model(model_name, output_path):
"""Download YOLO model and export to ONNX format."""
from ultralytics import YOLO
model_config = SUPPORTED_MODELS[model_name]
display_name = model_config['display_name']
download_name = model_config['download_name']
print(f"Downloading {display_name} model...")
model = YOLO(download_name)
print("Exporting to ONNX format...")
model.export(format="onnx", simplify=True)
# Move exported file to artifacts
base_name = download_name.replace('.pt', '')
exported_file = Path(f"{base_name}.onnx")
if exported_file.exists():
exported_file.rename(output_path)
# Clean up PyTorch file
pt_file = Path(download_name)
if pt_file.exists():
pt_file.unlink()
if not output_path.exists():
raise FileNotFoundError(f"Failed to create ONNX file at {output_path}")
def process_model(input_path, output_path, target_opset=16):
"""Load, upgrade opset, and apply shape inference to model."""
print(f"Loading model from {input_path}...")
model = onnx.load(input_path)
# Check and upgrade opset if needed
current_opset = model.opset_import[0].version
if current_opset < target_opset:
print(f"Upgrading opset from {current_opset} to {target_opset}...")
model = version_converter.convert_version(model, target_opset)
# Apply shape inference
print("Applying shape inference...")
model = shape_inference.infer_shapes(model)
# Save processed model
onnx.save(model, output_path)
print(f"✓ Processed model saved to: {output_path}")
return model
def generate_test_data(model_path, output_path, model_name):
"""Generate test input/output data and save as PyTorch tensors."""
import torch
import onnxruntime as ort
print("\nGenerating test data...")
# Load model to get input shape
model = onnx.load(model_path)
input_shape = get_input_shape(model)
print(f" Input shape: {input_shape}")
# Create reproducible test input
np.random.seed(42)
test_input = np.random.rand(*input_shape).astype(np.float32)
# Run inference to get output
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: test_input})
# Save as PyTorch tensors
test_data = {
'input': torch.from_numpy(test_input),
'output': torch.from_numpy(outputs[0])
}
torch.save(test_data, output_path)
print(f" ✓ Test data saved to: {output_path}")
print(f" Input shape: {test_input.shape}, Output shape: {outputs[0].shape}")
def main():
parser = argparse.ArgumentParser(description='YOLO Model Preparation Tool')
parser.add_argument('--model', type=str, default='yolov8n',
choices=list(SUPPORTED_MODELS.keys()),
help=f'YOLO model to download and prepare (default: yolov8n). Choices: {", ".join(SUPPORTED_MODELS.keys())}')
parser.add_argument('--list', action='store_true',
help='List all supported models')
args = parser.parse_args()
if args.list:
print("Supported YOLO models:")
for model_id, config in SUPPORTED_MODELS.items():
print(f" - {model_id:10s} ({config['display_name']})")
return
model_name = args.model
display_name = SUPPORTED_MODELS[model_name]['display_name']
print("=" * 60)
print(f"{display_name} Model Preparation Tool")
print("=" * 60)
# Setup paths
artifacts_dir = Path("artifacts")
artifacts_dir.mkdir(exist_ok=True)
original_path = artifacts_dir / f"{model_name}.onnx"
processed_path = artifacts_dir / f"{model_name}_opset16.onnx"
test_data_path = artifacts_dir / f"{model_name}_test_data.pt"
# Check if we already have everything
if processed_path.exists() and test_data_path.exists():
print(f"\n✓ All files already exist for {display_name}:")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print("\nNothing to do!")
return
# Download and convert if needed
if not original_path.exists() and not processed_path.exists():
print(f"\nStep 1: Downloading and converting {display_name} model...")
download_and_convert_model(model_name, original_path)
# Process model if needed
if not processed_path.exists():
print("\nStep 2: Processing model...")
process_model(original_path, processed_path, target_opset=16)
# Clean up original if we have the processed version
if original_path.exists():
original_path.unlink()
# Generate test data if needed
if not test_data_path.exists():
print("\nStep 3: Generating test data...")
generate_test_data(processed_path, test_data_path, model_name)
print("\n" + "=" * 60)
print(f"{display_name} model preparation completed!")
print(f" Model: {processed_path}")
print(f" Test data: {test_data_path}")
print("=" * 60)
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n⚠ Operation cancelled by user.")
sys.exit(1)
except Exception as e:
print(f"\n✗ Error: {str(e)}")
sys.exit(1)
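The fallback logic in `get_input_shape` can be exercised without an ONNX file. The sketch below is an assumption-laden restatement, not the script's code: `resolve_input_shape` is a hypothetical name, and dynamic dimensions are represented by `None` instead of a protobuf field check.

```python
DEFAULT_SHAPE = [1, 3, 640, 640]


def resolve_input_shape(dims):
    """Resolve a model input shape.

    dims is a list of ints, with None marking a dynamic (symbolic) dimension.
    """
    # Dynamic dimensions default to 1, as in get_input_shape.
    shape = [1 if d is None else d for d in dims]
    # Fall back to the standard YOLO input when the shape looks invalid.
    if len(shape) != 4 or shape[2] == 0 or shape[2] > 2000:
        return list(DEFAULT_SHAPE)
    return shape


if __name__ == "__main__":
    for dims in ([1, 3, 640, 640], [None, 3, None, None], [1, 3, 0, 0], [3, 224]):
        print(dims, "->", resolve_input_shape(dims))
```

Note that a fully dynamic spatial size resolves to `[1, 3, 1, 1]` under these rules instead of triggering the 640x640 fallback, because the defaulted height of 1 passes the validity check.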


@@ -1,196 +0,0 @@
extern crate alloc;
use burn::module::{Initializer, Param};
use burn::prelude::*;
use burn_store::{ModuleSnapshot, PytorchStore};
use std::path::Path;
use std::time::Instant;
#[cfg(feature = "wgpu")]
pub type MyBackend = burn::backend::Wgpu;
#[cfg(feature = "ndarray")]
pub type MyBackend = burn::backend::NdArray<f32>;
#[cfg(feature = "tch")]
pub type MyBackend = burn::backend::LibTorch<f32>;
#[cfg(feature = "metal")]
pub type MyBackend = burn::backend::Metal;
// Import model info generated by build.rs (includes the yolo_model module)
include!(concat!(env!("OUT_DIR"), "/model_info.rs"));
// Use the yolo_model module from model_info.rs
use yolo_model::Model;
#[derive(Debug, Module)]
struct TestData<B: Backend> {
input: Param<Tensor<B, 4>>,
output: Param<Tensor<B, 3>>,
}
impl<B: Backend> TestData<B> {
fn new(device: &B::Device) -> Self {
// YOLO: input 640x640, output [1, 84, 8400]
Self {
input: Initializer::Zeros.init([1, 3, 640, 640], device),
output: Initializer::Zeros.init([1, 84, 8400], device),
}
}
}
fn get_model_display_name(model_name: &str) -> &str {
match model_name {
"yolov5s" => "YOLOv5s",
"yolov8n" => "YOLOv8n",
"yolov8s" => "YOLOv8s",
"yolov10n" => "YOLOv10n",
"yolo11x" => "YOLO11x",
"yolo12x" => "YOLO12x",
_ => model_name,
}
}
fn main() {
// MODEL_NAME is set at build time from YOLO_MODEL env var
let model_name = MODEL_NAME;
let display_name = get_model_display_name(model_name);
println!("========================================");
println!("{} Burn Model Test", display_name);
println!("========================================\n");
// Check if artifacts exist
let artifacts_dir = Path::new("artifacts");
if !artifacts_dir.exists() {
eprintln!("Error: artifacts directory not found!");
eprintln!("Please run get_model.py first to download the model and test data.");
eprintln!("Example: uv run get_model.py --model {}", model_name);
std::process::exit(1);
}
// Check if model files exist for this specific model
let model_file = artifacts_dir.join(format!("{}_opset16.onnx", model_name));
let test_data_file = artifacts_dir.join(format!("{}_test_data.pt", model_name));
if !model_file.exists() || !test_data_file.exists() {
eprintln!("Error: Model files not found for {}!", display_name);
eprintln!("Please run: uv run get_model.py --model {}", model_name);
eprintln!();
eprintln!("Available models:");
eprintln!(" - yolov5s");
eprintln!(" - yolov8n");
eprintln!(" - yolov8s");
eprintln!(" - yolov10n");
eprintln!(" - yolo11x");
eprintln!(" - yolo12x");
std::process::exit(1);
}
// Initialize the model (without weights for now)
println!("Initializing {} model...", display_name);
let start = Instant::now();
let device = Default::default();
let model: Model<MyBackend> = Model::default();
let init_time = start.elapsed();
println!(" Model initialized in {:.2?}", init_time);
// Save model structure to file
let model_txt_path = artifacts_dir.join(format!("{}_model.txt", model_name));
println!(
"\nSaving model structure to {}...",
model_txt_path.display()
);
let model_str = format!("{}", model);
std::fs::write(&model_txt_path, &model_str).expect("Failed to write model structure to file");
println!(" Model structure saved");
// Load test data from PyTorch file
println!("\nLoading test data from {}...", test_data_file.display());
let start = Instant::now();
let mut test_data = TestData::<MyBackend>::new(&device);
let mut store = PytorchStore::from_file(&test_data_file);
test_data.load_from(&mut store).expect("Failed to load test data");
let load_time = start.elapsed();
println!(" Data loaded in {:.2?}", load_time);
// Get the input tensor from test data
let input = test_data.input.val();
let input_shape = input.shape();
println!(" Loaded input tensor with shape: {:?}", input_shape.dims);
// Get the reference output from test data
let reference_output = test_data.output.val();
let reference_shape = reference_output.shape();
println!(
" Loaded reference output with shape: {:?}",
reference_shape.dims
);
// Run inference with the loaded input
println!("\nRunning model inference with test input...");
let start = Instant::now();
let output = model.forward(input);
let inference_time = start.elapsed();
println!(" Inference completed in {:.2?}", inference_time);
// Display output shape
let shape = output.shape();
println!("\n Model output shape: {:?}", shape.dims);
// Verify expected output shape (most YOLO models use [1, 84, 8400])
let expected_shape = [1, 84, 8400];
if shape.dims == expected_shape {
println!(" ✓ Output shape matches expected: {:?}", expected_shape);
} else {
println!(
" ⚠ Note: Shape is {:?} (expected {:?} for most YOLO models)",
shape.dims, expected_shape
);
}
// Compare outputs
println!("\nComparing model output with reference data...");
// Check if outputs are close
if output
.clone()
.all_close(reference_output.clone(), Some(1e-4), Some(1e-4))
{
println!(" ✓ Model output matches reference data within tolerance (1e-4)!");
} else {
println!(" ⚠ Model output differs from reference data!");
// Calculate and display the difference statistics
let diff = output.clone() - reference_output.clone();
let abs_diff = diff.abs();
let max_diff = abs_diff.clone().max().into_scalar();
let mean_diff = abs_diff.mean().into_scalar();
println!(" Maximum absolute difference: {:.6}", max_diff);
println!(" Mean absolute difference: {:.6}", mean_diff);
// Show some sample values for debugging
println!("\n Sample values comparison (first 5 elements):");
let output_flat = output.clone().flatten::<1>(0, 2);
let reference_flat = reference_output.clone().flatten::<1>(0, 2);
for i in 0..5.min(output_flat.dims()[0]) {
let model_val: f32 = output_flat.clone().slice(s![i..i + 1]).into_scalar();
let ref_val: f32 = reference_flat.clone().slice(s![i..i + 1]).into_scalar();
println!(
" [{}] Model: {:.6}, Reference: {:.6}, Diff: {:.6}",
i,
model_val,
ref_val,
(model_val - ref_val).abs()
);
}
}
println!("\n========================================");
println!("Model test completed!");
println!("========================================");
}


@@ -1,10 +0,0 @@
# python generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# venv
.venv


@@ -1,22 +0,0 @@
[package]
name = "onnx-tests"
version.workspace = true
edition.workspace = true
license.workspace = true
[features]
default = ["test-ndarray"]
test-wgpu = ["burn/wgpu"]
test-ndarray = ["burn/ndarray"]
test-tch = ["burn/tch"]
test-metal = ["burn/metal"]
test-candle = ["burn/candle"]
[dev-dependencies]
burn = { path = "../../burn" }
burn-store = { path = "../../burn-store", features = ["burnpack"] }
serde = { workspace = true }
float-cmp = { workspace = true }
[build-dependencies]
burn-onnx = { path = "../" }


@@ -1,312 +0,0 @@
# ONNX Tests
This crate contains ONNX models used for testing the conversion process from ONNX to Burn source
code through the `burn-onnx` crate. These tests are designed as end-to-end tests, ensuring that
ONNX models are accurately converted into Burn source code that compiles without errors and produces
the same outputs as the original ONNX model.
## Directory Structure
- `tests/<op_name>/`: Each operator or model has its own directory
- `tests/<op_name>/<op_name>.py`: Python script that generates the ONNX model
- `tests/<op_name>/<op_name>.onnx`: Generated ONNX model
- `tests/<op_name>/mod.rs`: Test implementation for the specific operator
- `tests/test_mod.rs`: Main test file that integrates all operator tests
- `build.rs`: Build script that generates ONNX models before running tests
## Setting Up Your Python Environment
### Using uv (Recommended)
You can use [`uv`](https://docs.astral.sh/uv/) to set up a Python environment with the necessary
dependencies:
```sh
cd crates/burn-onnx/onnx-tests
uv sync # or uv sync -f
```
This will create a `.venv` directory with all the required dependencies.
### Manual Setup
If you prefer to set up manually, you need to install the following packages:
```sh
pip install "onnx>=1.16.1" "torch>=2.3.1" "onnxruntime>=1.18.0"
```
Additional dependencies are specified in `requirements.lock`.
## Creating a Test for a New Operator
There are two main approaches to generating ONNX files for testing:
1. **Exporting a model from PyTorch** (most common)
2. **Constructing an ONNX graph directly** (for specific cases)
### 1. Create the Python Script
Create a new directory and Python script:
```sh
mkdir -p tests/my_new_op
touch tests/my_new_op/my_new_op.py
```
#### Approach 1: Exporting a PyTorch Model to ONNX
Your script should:
- Import the necessary PyTorch modules
- Define a model that uses your operator
- Generate test inputs
- Export the model to ONNX format
- Run the model on test inputs and print the output
Example structure:
```python
import torch
import torch.nn as nn
import torch.onnx

# Define a simple model that uses your operator
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ...

    def forward(self, x):
        # Use your operator here
        return my_operation(x)

# Create an instance of the model
model = MyModel()

# Generate test input
input_tensor = torch.randn(1, 3, 224, 224)

# Export the model to ONNX
torch.onnx.export(
    model,
    input_tensor,
    "tests/my_new_op/my_new_op.onnx",
    opset_version=16,
    input_names=["input"],
    output_names=["output"],
    do_constant_folding=False  # Set to False if you want to preserve specific operators
)

# Run the model with the test input and print output for test verification
output = model(input_tensor)
print("Input:", input_tensor)
print("Output:", output)
```
#### Approach 2: Constructing an ONNX Graph Directly
For some test cases, you may want to construct the ONNX graph directly using the ONNX Python API.
This is particularly useful when:
- You need precise control over operator attributes
- You're testing operators that are difficult to trigger through PyTorch models
- You want to test specific graph structures
Example (see `tests/gather/gather_1d_idx.py` for a complete example):
```python
import numpy as np
import onnx
from onnx import TensorProto, helper
# Create inputs
data = np.random.randn(5, 5, 5).astype(np.float32)
indices = np.array([0, 2, 4], dtype=np.int64)
# Create node
node = helper.make_node(
"Gather",
inputs=["data", "indices"],
outputs=["output"],
axis=1
)
# Create input tensors
data_tensor = helper.make_tensor_value_info("data", TensorProto.FLOAT, data.shape)
indices_tensor = helper.make_tensor_value_info("indices", TensorProto.INT64, indices.shape)
# Create output tensor
output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [5, 3, 5])
# Create graph and model
graph = helper.make_graph(
[node],
"gather-model",
[data_tensor, indices_tensor],
[output_tensor],
initializer=[]
)
model = helper.make_model(graph)
onnx.save(model, "tests/my_new_op/my_new_op.onnx")
# For test verification, print input and expected output
print("Data:", data)
print("Indices:", indices)
print("Expected output:", np.take(data, indices, axis=1))
```
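To see why the declared output shape is `[5, 3, 5]`, the sketch below reproduces the `Gather` semantics for `axis=1` on plain nested lists, without NumPy. The helper names `gather_axis1` and `shape_of` are hypothetical and exist only for illustration; the result corresponds to what `np.take(data, indices, axis=1)` computes in the script above.

```python
def gather_axis1(data, indices):
    """Select entries along axis 1 of a 3-D nested list (illustrative helper)."""
    return [[block[i] for i in indices] for block in data]


def shape_of(nested):
    """Return the shape of a regular, non-empty nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


# A 5x5x5 tensor whose values encode their own position (a, b, c).
data = [[[100 * a + 10 * b + c for c in range(5)] for b in range(5)] for a in range(5)]

# Picking indices 0, 2 and 4 along axis 1 keeps axes 0 and 2 untouched.
result = gather_axis1(data, [0, 2, 4])
print(shape_of(result))
```

Axis 1 shrinks from 5 to the number of indices (3), while the other two axes keep their size, which is the shape passed to `make_tensor_value_info` for the output.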
### 2. Add the Build Step
Update `build.rs` to include your new model.
### 3. Create a mod.rs Test File
Create a test module file in your operator directory:
```sh
touch tests/my_new_op/mod.rs
```
Implement the test for your operator in this file:
```rust
// Import the shared macro (see `tests/abs/mod.rs` for a complete test)
use crate::backend::TestBackend;
use crate::include_models;
include_models!(my_new_op);

#[test]
fn test_my_new_op() {
    let device = Default::default();
    let model: my_new_op::Model<TestBackend> = my_new_op::Model::new(&device);
    // Implement test logic and assertions here
}
```
Your test will be automatically included in the main test suite through `tests/test_mod.rs`.
## Best Practices for ONNX Testing
### Model Generation
1. **Keep Models Simple**: Focus on testing a single operator or a small group of related operators.
2. **Control Randomness**: Use fixed seeds in your Python scripts to ensure reproducible results:
```python
torch.manual_seed(42)
```
3. **Print Test Values**: Always print your input and output tensors in the Python script for
reference.
4. **Verify Operators**: Use [Netron](https://github.com/lutzroeder/netron) to verify your ONNX
model contains the expected operators.
5. **Handle Constant Folding**: If PyTorch is optimizing away your operators, use:
```python
torch.onnx.export(..., do_constant_folding=False)
```
### Test Implementation
1. **Test Multiple Cases**: Include tests for different input shapes, data types, and parameter
combinations.
2. **Edge Cases**: Test edge cases like empty tensors, single-element tensors, or very large
tensors.
3. **Parameter Variations**: If your operator has configurable parameters, test different parameter
values.
4. **Numerical Precision**: Use appropriate tolerance levels based on operation sensitivity.
5. **Error Cases**: Test that invalid inputs are properly handled and appropriate errors are raised.
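As a rough illustration of the numerical precision point, the sketch below checks two flat lists of values against a combined absolute and relative tolerance. It assumes the common NumPy-style rule `|actual - expected| <= atol + rtol * |expected|`; the function name `all_close` and the default tolerances are chosen for this example, and the exact formula used by a particular backend or assertion helper may differ.

```python
def all_close(actual, expected, rtol=1e-4, atol=1e-4):
    """Return True if every pair satisfies |a - e| <= atol + rtol * |e|."""
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# A small deviation stays inside the tolerance band.
print(all_close([1.0, 2.00005], [1.0, 2.0]))

# A larger deviation does not.
print(all_close([1.0, 2.01], [1.0, 2.0]))
```

Operators that accumulate rounding error (reductions, normalization, attention) may need a looser tolerance than element-wise operators such as `Abs`.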
## Running Tests
### Default Backend
Run all tests with:
```sh
cargo test
```
This command runs all tests using the default backend: `burn::backend::NdArray<f32>`.
### Testing with Different Backends
You can test with different Burn backends by using feature flags:
#### WGPU Backend
```sh
cargo test --features test-wgpu
```
Uses `burn::backend::Wgpu` for GPU-accelerated computation.
#### LibTorch Backend
```sh
cargo test --features test-tch
```
Uses `burn::backend::LibTorch<f32>` for Torch backend integration.
#### NdArray Backend (Explicit)
```sh
cargo test --features test-ndarray
```
Explicitly uses `burn::backend::NdArray<f32>` (same as default).
### Running Specific Tests
Run tests for a specific operator:
```sh
cargo test --test test_mod my_new_op::test_my_new_op
```
Run a specific test with a selected backend:
```sh
cargo test --test test_mod my_new_op::test_my_new_op --features test-wgpu
```
### Supported Backends
- `burn::backend::NdArray<f32>` (default) - CPU-based computation using ndarray
- `burn::backend::Wgpu` - GPU-accelerated computation using WebGPU
- `burn::backend::LibTorch<f32>` - Torch backend integration
**Note:** Only one backend feature should be enabled at a time. The backend selection uses
conditional compilation with the following priority:
1. `test-wgpu` (highest priority)
2. `test-tch`
3. `test-ndarray` (default when no other backend is selected)
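The priority rule above can be sketched as a small decision function. This is only an illustration in Python of how the selection is described to resolve; the function name `select_backend` is hypothetical, and the real choice is made through conditional compilation in the Rust test crate.

```python
def select_backend(enabled_features):
    """Pick the backend according to the documented priority (illustrative)."""
    if "test-wgpu" in enabled_features:
        return "burn::backend::Wgpu"
    if "test-tch" in enabled_features:
        return "burn::backend::LibTorch<f32>"
    # Falls through to the default backend.
    return "burn::backend::NdArray<f32>"


# `test-ndarray` is a default feature, so adding `test-wgpu` enables both.
print(select_backend({"test-ndarray", "test-wgpu"}))
print(select_backend({"test-ndarray"}))
```

Because `test-ndarray` is in the crate's default feature set, a command such as `cargo test --features test-wgpu` enables both features, and the priority order is what makes Wgpu the selected backend.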
## Debugging Failed Tests
If a test fails, you can:
1. **Inspect ONNX Model**: Use Netron to visualize the model structure.
2. **Check Intermediate Values**: Add print statements in your Python script to see intermediate
tensor values.
3. **Generate Rust Code**: Use the `burn-onnx` CLI to generate Rust code and inspect it:
```sh
cargo run -p burn-onnx -- tests/my_new_op/my_new_op.onnx ./out
```
4. **Trace Through Conversion**: Add debug logging in your implementation to see where things might
be going wrong.
5. **Numerical Issues**: If values are close but not equal, it might be a numerical precision issue.
Try adjusting tolerance.
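Before loosening a tolerance, it can help to look at how far the values actually diverge. The sketch below computes the maximum and mean absolute difference between the Rust output and the values printed by the Python script, both pasted in as flat lists. The function name `diff_stats` is hypothetical and the numbers are made up for illustration.

```python
def diff_stats(actual, expected):
    """Return (max, mean) of the element-wise absolute differences."""
    if len(actual) != len(expected) or not actual:
        raise ValueError("inputs must be non-empty and of equal length")
    abs_diffs = [abs(a - e) for a, e in zip(actual, expected)]
    return max(abs_diffs), sum(abs_diffs) / len(abs_diffs)


# Illustrative values only: one element is off by 0.5.
max_diff, mean_diff = diff_stats([1.0, 2.5, 3.0], [1.0, 2.0, 3.0])
print(f"Maximum absolute difference: {max_diff:.6f}")
print(f"Mean absolute difference: {mean_diff:.6f}")
```

A large maximum with a small mean usually points at a few wrong elements (for example an indexing or padding issue), whereas a uniformly small difference is more consistent with floating-point rounding.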


@@ -1,387 +0,0 @@
use burn_onnx::ModelGen;
fn main() {
// Re-run this build script if the onnx-tests directory changes.
println!("cargo:rerun-if-changed=tests");
// Add onnx models.
// All models are now saved in burnpack format (.bpk files)
ModelGen::new()
.input("tests/abs/abs.onnx")
.input("tests/add/add.onnx")
.input("tests/add/add_shape.onnx")
.input("tests/add/add_broadcast.onnx")
.input("tests/initializer_to_const/initializer_to_const.onnx")
.input("tests/and/and.onnx")
.input("tests/and/and_broadcast.onnx")
.input("tests/and/and_scalar.onnx")
.input("tests/add/add_int.onnx")
.input("tests/add/add_shape_tensor.onnx")
.input("tests/add/add_argmax_with_shape.onnx")
.input("tests/argmax/argmax.onnx")
.input("tests/argmax/argmax_both_keepdims.onnx")
.input("tests/argmax/argmax_1d.onnx")
.input("tests/argmin/argmin.onnx")
.input("tests/argmin/argmin_both_keepdims.onnx")
.input("tests/argmin/argmin_1d.onnx")
.input("tests/attention/attention_4d.onnx")
.input("tests/attention/attention_3d.onnx")
.input("tests/attention/attention_attn_mask_bool.onnx")
.input("tests/attention/attention_attn_mask_int.onnx")
.input("tests/attention/attention_attn_mask_float.onnx")
.input("tests/attention/attention_softcap.onnx")
.input("tests/attention/attention_cache.onnx")
.input("tests/attention/attention_custom_scale.onnx")
.input("tests/attention/attention_is_causal.onnx")
.input("tests/attention/attention_qk_output_0.onnx")
.input("tests/attention/attention_qk_output_1.onnx")
.input("tests/attention/attention_qk_output_2.onnx")
.input("tests/attention/attention_qk_output_3.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool1d_ceil_mode/avg_pool1d_ceil_mode.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
.input("tests/avg_pool2d_ceil_mode/avg_pool2d_ceil_mode.onnx")
.input("tests/batch_norm/batch_norm.onnx")
.input("tests/bitshift/bitshift_left.onnx")
.input("tests/bitshift/bitshift_left_scalar.onnx")
.input("tests/bitshift/scalar_bitshift_left.onnx")
.input("tests/bitshift/scalar_bitshift_left_scalar.onnx")
.input("tests/bitshift/bitshift_right.onnx")
.input("tests/bitshift/bitshift_right_scalar.onnx")
.input("tests/bitshift/scalar_bitshift_right.onnx")
.input("tests/bitshift/scalar_bitshift_right_scalar.onnx")
.input("tests/bitwise_and/bitwise_and.onnx")
.input("tests/bitwise_and/bitwise_and_scalar.onnx")
.input("tests/bitwise_and/scalar_bitwise_and.onnx")
.input("tests/bitwise_and/scalar_bitwise_and_scalar.onnx")
.input("tests/bitwise_not/bitwise_not.onnx")
.input("tests/bitwise_or/bitwise_or.onnx")
.input("tests/bitwise_or/bitwise_or_scalar.onnx")
.input("tests/bitwise_or/scalar_bitwise_or.onnx")
.input("tests/bitwise_or/scalar_bitwise_or_scalar.onnx")
.input("tests/bitwise_xor/bitwise_xor.onnx")
.input("tests/bitwise_xor/bitwise_xor_scalar.onnx")
.input("tests/bitwise_xor/scalar_bitwise_xor.onnx")
.input("tests/bitwise_xor/scalar_bitwise_xor_scalar.onnx")
.input("tests/bernoulli/bernoulli.onnx")
.input("tests/cast/cast.onnx")
.input("tests/cast/cast_shape.onnx")
.input("tests/cast/cast_shape_to_float.onnx")
.input("tests/cast/cast_shape_to_bool.onnx")
.input("tests/ceil/ceil.onnx")
.input("tests/clip/clip.onnx")
.input("tests/concat/concat.onnx")
.input("tests/concat/concat_shape.onnx")
.input("tests/concat/concat_shape_with_constant.onnx")
.input("tests/concat/concat_mixed_single_element.onnx")
.input("tests/concat/concat_mixed_three_elements.onnx")
.input("tests/concat/concat_multiple_mixed.onnx")
.input("tests/concat/concat_with_constants.onnx")
.input("tests/constant/constant_f32.onnx")
.input("tests/constant/constant_f64.onnx")
.input("tests/constant/constant_i32.onnx")
.input("tests/constant/constant_i64.onnx")
.input("tests/constant/constant_bool.onnx")
.input("tests/constant/constant_shape.onnx")
.input("tests/constant/constant_tensor_f32.onnx")
.input("tests/constant/constant_tensor_i32.onnx")
.input("tests/constant/constant_tensor_bool.onnx")
.input("tests/constant/constant_empty_tensor_f32.onnx")
.input("tests/constant/rank_inference_propagation.onnx")
.input("tests/constant/shape_binary_ops_with_constant.onnx")
.input("tests/constant_of_shape/constant_of_shape.onnx")
.input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
.input("tests/constant_of_shape/constant_of_shape_scalar.onnx")
.input("tests/constant_of_shape/constant_of_shape_scalar_custom_value.onnx")
.input("tests/constant_of_shape/constant_of_shape_tensor.onnx")
.input("tests/constant_of_shape/constant_of_shape_shape_optimization.onnx")
.input("tests/constant_of_shape/constant_of_shape_with_constant_input.onnx")
.input("tests/constant_lifting_multiple/constant_lifting_multiple.onnx")
.input("tests/constant_lifting_multiple/constant_reused.onnx")
.input("tests/conv1d/conv1d.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/conv3d/conv3d.onnx")
.input("tests/conv_transpose1d/conv_transpose1d.onnx")
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
.input("tests/cos/cos.onnx")
.input("tests/cosh/cosh.onnx")
.input("tests/cumsum/cumsum.onnx")
.input("tests/cumsum/cumsum_exclusive.onnx")
.input("tests/cumsum/cumsum_reverse.onnx")
.input("tests/cumsum/cumsum_exclusive_reverse.onnx")
.input("tests/cumsum/cumsum_2d.onnx")
.input("tests/cumsum/cumsum_runtime_axis.onnx")
.input("tests/cumsum/cumsum_single_element.onnx")
.input("tests/cumsum/cumsum_exclusive_single.onnx")
.input("tests/depth_to_space/depth_to_space_dcr.onnx")
.input("tests/depth_to_space/depth_to_space_crd.onnx")
.input("tests/div/div.onnx")
.input("tests/div/div_shape.onnx")
.input("tests/div/div_shape_tensor.onnx")
.input("tests/div/div_broadcast.onnx")
.input("tests/mod/modulo.onnx")
.input("tests/mod/mod_scalar.onnx")
.input("tests/mod/mod_remainder.onnx")
.input("tests/mod/mod_fmod.onnx")
.input("tests/mod/mod_broadcast_fixed.onnx")
.input("tests/mod/mod_broadcast_remainder_fixed.onnx")
.input("tests/dropout/dropout.onnx")
.input("tests/empty_graph/empty_graph_scalar.onnx")
.input("tests/empty_graph/empty_graph_scalar_int.onnx")
.input("tests/empty_graph/empty_graph_shape.onnx")
.input("tests/empty_graph/empty_graph_tensor.onnx")
.input("tests/empty_graph/empty_graph_multiple.onnx")
.input("tests/equal/equal.onnx")
.input("tests/equal/equal_shape.onnx")
.input("tests/equal/equal_two_shapes.onnx")
.input("tests/erf/erf.onnx")
.input("tests/exp/exp.onnx")
.input("tests/expand/expand.onnx")
.input("tests/expand/expand_tensor.onnx")
.input("tests/expand/expand_shape.onnx")
.input("tests/expand/expand_with_where_shape.onnx")
.input("tests/expand/expand_max_semantics.onnx")
.input("tests/eye_like/eye_like.onnx")
.input("tests/eye_like/eye_like_k1.onnx")
.input("tests/eye_like/eye_like_int.onnx")
.input("tests/eye_like/eye_like_k_minus1.onnx")
.input("tests/eye_like/eye_like_float64.onnx")
.input("tests/eye_like/eye_like_int32.onnx")
.input("tests/eye_like/eye_like_bool.onnx")
.input("tests/eye_like/eye_like_large_k.onnx")
.input("tests/eye_like/eye_like_1x1.onnx")
.input("tests/eye_like/eye_like_wide.onnx")
.input("tests/eye_like/eye_like_neg_large_k.onnx")
.input("tests/nonzero/nonzero_float32.onnx")
.input("tests/nonzero/nonzero_int64.onnx")
.input("tests/nonzero/nonzero_bool.onnx")
.input("tests/nonzero/nonzero_1d.onnx")
.input("tests/nonzero/nonzero_3d.onnx")
.input("tests/nonzero/nonzero_empty.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/flatten/flatten_2d.onnx")
.input("tests/floor/floor.onnx")
.input("tests/gather/gather_1d_idx.onnx")
.input("tests/gather/gather_2d_idx.onnx")
.input("tests/gather/gather_scalar.onnx")
.input("tests/gather/gather_constant_2d_indices.onnx")
.input("tests/gather/gather_static_shape_indices.onnx")
.input("tests/gather/gather_shape.onnx")
.input("tests/gather/gather_with_shape_indices.onnx")
.input("tests/gather/gather_scalar_out.onnx")
.input("tests/gather/gather_scalar_input.onnx")
.input("tests/gather/gather_negative_idx.onnx")
.input("tests/gather_elements/gather_elements.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/gemm/gemm.onnx")
.input("tests/gemm/gemm_non_unit_alpha_beta.onnx")
.input("tests/gemm/gemm_no_c.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/graph_multiple_output_tracking/graph_multiple_output_tracking.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater/greater_scalar.onnx")
.input("tests/greater/greater_broadcast.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
.input("tests/greater_or_equal/greater_or_equal_scalar.onnx")
.input("tests/greater_or_equal/greater_or_equal_broadcast.onnx")
.input("tests/grid_sample/grid_sample.onnx")
.input("tests/grid_sample/grid_sample_nearest.onnx")
.input("tests/group_norm/group_norm.onnx")
.input("tests/hard_sigmoid/hard_sigmoid.onnx")
.input("tests/hard_swish/hard_swish.onnx")
.input("tests/identity/identity_constant.onnx")
.input("tests/identity/identity_passthrough.onnx")
.input("tests/identity/identity_chain.onnx")
.input("tests/identity/identity_only.onnx")
.input("tests/instance_norm1d/instance_norm1d.onnx")
.input("tests/instance_norm2d/instance_norm2d.onnx")
.input("tests/instance_norm3d/instance_norm3d.onnx")
.input("tests/is_inf/is_inf.onnx")
.input("tests/is_inf/is_inf_scalar.onnx")
.input("tests/is_inf/is_inf_neg_only.onnx")
.input("tests/is_inf/is_inf_pos_only.onnx")
.input("tests/is_inf/is_inf_none.onnx")
.input("tests/is_nan/is_nan.onnx")
.input("tests/is_nan/is_nan_scalar.onnx")
.input("tests/layer_norm/layer_norm.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/less/less.onnx")
.input("tests/less/less_scalar.onnx")
.input("tests/less/less_broadcast.onnx")
.input("tests/less_or_equal/less_or_equal.onnx")
.input("tests/less_or_equal/less_or_equal_scalar.onnx")
.input("tests/less_or_equal/less_or_equal_broadcast.onnx")
.input("tests/linear/linear.onnx")
.input("tests/log/log.onnx")
.input("tests/lstm/lstm.onnx")
.input("tests/lstm/lstm_bidirectional.onnx")
.input("tests/lstm/lstm_reverse.onnx")
.input("tests/lstm/lstm_with_initial_state.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/where_op/where_op.onnx")
.input("tests/where_op/where_op_broadcast.onnx")
.input("tests/where_op/where_op_scalar_x.onnx")
.input("tests/where_op/where_op_scalar_y.onnx")
.input("tests/where_op/where_op_all_scalar.onnx")
.input("tests/where_op/where_shape_all_shapes.onnx")
.input("tests/where_op/where_shape_scalar_cond.onnx")
.input("tests/where_op/where_shapes_from_inputs.onnx")
.input("tests/where_op/where_static_shape.onnx")
.input("tests/matmul/matmul.onnx")
.input("tests/matmulinteger/matmulinteger.onnx")
.input("tests/matmulinteger/matmulinteger_ranks.onnx")
.input("tests/matmul/matmul_ranks.onnx")
.input("tests/max/max.onnx")
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool1d_ceil_mode/maxpool1d_ceil_mode.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/maxpool2d_ceil_mode/maxpool2d_ceil_mode.onnx")
.input("tests/min/min.onnx")
.input("tests/mean/mean.onnx")
.input("tests/mul/mul.onnx")
.input("tests/mul/mul_shape.onnx")
.input("tests/mul/mul_shape_tensor.onnx")
.input("tests/mul/mul_broadcast.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/one_hot/one_hot.onnx")
.input("tests/or/or.onnx")
.input("tests/or/or_scalar.onnx")
.input("tests/or/or_broadcast.onnx")
.input("tests/pad/pad.onnx")
.input("tests/pad/pad_reflect.onnx")
.input("tests/pad/pad_edge.onnx")
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/prelu/prelu.onnx")
.input("tests/prelu/prelu_with_channel_slope.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/random_normal_like/random_normal_like.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/random_uniform_like/random_uniform_like.onnx")
.input("tests/range/range.onnx")
.input("tests/range/range_static.onnx")
.input("tests/range/range_mixed.onnx")
.input("tests/range/range_runtime.onnx")
.input("tests/recip/recip.onnx")
.input("tests/reduce/reduce_max.onnx")
.input("tests/reduce/reduce_max_bool.onnx")
.input("tests/reduce/reduce_mean.onnx")
.input("tests/reduce/reduce_mean_partial_shape.onnx")
.input("tests/reduce/reduce_min.onnx")
.input("tests/reduce/reduce_min_bool.onnx")
.input("tests/reduce/reduce_prod.onnx")
.input("tests/reduce/reduce_sum.onnx")
.input("tests/reduce/reduce_sum_square.onnx")
.input("tests/reduce/reduce_l1.onnx")
.input("tests/reduce/reduce_l2.onnx")
.input("tests/reduce/reduce_log_sum.onnx")
.input("tests/reduce/reduce_log_sum_exp.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/reshape/reshape_with_1d_tensor.onnx")
.input("tests/reshape/reshape_with_shape.onnx")
.input("tests/reshape/reshape_to_scalar.onnx")
.input("tests/reshape/reshape_3d_to_scalar.onnx")
.input("tests/reshape/reshape_shape_to_shape.onnx")
.input("tests/reshape/reshape_shape_with_neg.onnx")
.input("tests/reshape/reshape_shape_partial.onnx")
.input("tests/reshape/reshape_scalar_to_scalar.onnx")
.input("tests/resize/resize_with_sizes.onnx")
.input("tests/resize/resize_1d_linear_scale.onnx")
.input("tests/resize/resize_1d_nearest_scale.onnx")
.input("tests/resize/resize_2d_bicubic_scale.onnx")
.input("tests/resize/resize_2d_bilinear_scale.onnx")
.input("tests/resize/resize_2d_nearest_scale.onnx")
.input("tests/resize/resize_with_shape.onnx")
.input("tests/resize/resize_with_sizes_tensor.onnx")
.input("tests/round/round.onnx")
.input("tests/shape/shape.onnx")
.input("tests/shape/shape_of_shape.onnx")
.input("tests/shape/shape_slice.onnx")
.input("tests/shape/shape_chain.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/sign/sign.onnx")
.input("tests/sin/sin.onnx")
.input("tests/sinh/sinh.onnx")
.input("tests/size/size.onnx")
.input("tests/slice/slice.onnx")
.input("tests/slice/slice_shape.onnx")
.input("tests/slice/slice_scalar.onnx")
.input("tests/slice/slice_mixed.onnx")
.input("tests/slice/slice_shape_gather.onnx")
.input("tests/slice/slice_shape_runtime.onnx")
.input("tests/slice/slice_shape_multi.onnx")
.input("tests/slice/slice_shape_negative.onnx")
.input("tests/slice/slice_shape_negative_range.onnx")
.input("tests/slice/slice_1d_tensor.onnx")
.input("tests/slice/slice_shape_start_tensor_end.onnx")
.input("tests/slice/slice_tensor_start_shape_end.onnx")
.input("tests/slice/slice_axes.onnx")
.input("tests/slice/slice_with_steps.onnx")
.input("tests/slice/slice_shape_with_steps.onnx")
.input("tests/slice/slice_empty.onnx")
.input("tests/softmax/softmax.onnx")
.input("tests/space_to_depth/space_to_depth.onnx")
.input("tests/sqrt/sqrt.onnx")
.input("tests/squeeze/squeeze_multiple.onnx")
.input("tests/squeeze/squeeze.onnx")
.input("tests/squeeze/squeeze_shape.onnx")
.input("tests/squeeze/squeeze_shape_noop.onnx")
.input("tests/squeeze/squeeze_scalar.onnx")
.input("tests/squeeze/squeeze_float.onnx")
.input("tests/squeeze/squeeze_tensor_to_scalar.onnx")
.input("tests/squeeze/squeeze_opset13_axes_input.onnx")
.input("tests/squeeze/squeeze_no_axes.onnx")
.input("tests/sub/sub.onnx")
.input("tests/sub/sub_shape.onnx")
.input("tests/sub/sub_broadcast.onnx")
.input("tests/sub/sub_int.onnx")
.input("tests/sub/sub_shape_tensor.onnx")
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/tan/tan.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/tile/tile.onnx")
.input("tests/topk/topk.onnx")
.input("tests/trilu/trilu_upper.onnx")
.input("tests/trilu/trilu_lower.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze_runtime_axes.onnx")
.input("tests/unsqueeze/unsqueeze_like.onnx")
.input("tests/unsqueeze/unsqueeze_int_to_shape.onnx")
.input("tests/unsqueeze/squeeze_unsqueeze_roundtrip.onnx")
.input("tests/split/split.onnx")
.input("tests/xor/xor.onnx")
.input("tests/xor/xor_scalar.onnx")
.input("tests/xor/xor_broadcast.onnx")
// If operator tests
.input("tests/if_op/if_conv2d.onnx")
.input("tests/if_op/if_linear.onnx")
.input("tests/if_op/nested_if.onnx")
// Loop operator tests
.input("tests/loop/loop_simple.onnx")
.input("tests/loop/loop_dynamic_cond.onnx")
.input("tests/loop/loop_multi_deps.onnx")
.input("tests/loop/loop_nested.onnx")
.input("tests/loop/loop_scan_outputs.onnx")
// Scan operator tests
.input("tests/scan/scan_cumsum.onnx")
.input("tests/scan/scan_reverse.onnx")
.input("tests/scan/scan_multi_state.onnx")
.input("tests/scan/scan_axis1.onnx")
// Subgraph tests: nested control flow and outer-scope references
.input("tests/subgraph/nested_if_loop_if.onnx")
.input("tests/subgraph/nested_if_loop_if_scan.onnx")
.input("tests/subgraph/outer_scope_ref.onnx")
.input("tests/subgraph/outer_scope_multi_var.onnx")
.input("tests/subgraph/outer_scope_loop.onnx")
.input("tests/subgraph/outer_scope_scan.onnx")
.input("tests/subgraph/outer_scope_constant.onnx")
.out_dir("model/")
.run_from_script();
// Note: Previous record type variants (NamedMpk, PrettyJson, Bincode, etc.)
// have been removed. All models now use burnpack format exclusively.
}


@@ -1,26 +0,0 @@
[project]
name = "onnx-tests"
version = "0.1.0"
description = "project for testing ONNX support"
authors = []
dependencies = [
"torch>=2.3.1",
"onnx>=1.16.1",
"onnxruntime>=1.18.0",
]
readme = "README.md"
requires-python = ">= 3.8"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = []
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/onnx_tests"]


@@ -1,75 +0,0 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
# with-sources: false
-e file:.
coloredlogs==15.0.1
# via onnxruntime
filelock==3.15.1
# via torch
flatbuffers==24.3.25
# via onnxruntime
fsspec==2024.6.0
# via torch
humanfriendly==10.0
# via coloredlogs
jinja2==3.1.4
# via torch
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.3
# via torch
numpy==1.26.4
# via onnx
# via onnxruntime
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.5.40
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
onnx==1.16.1
# via onnx-tests
onnxruntime==1.18.0
# via onnx-tests
packaging==24.1
# via onnxruntime
protobuf==5.27.1
# via onnx
# via onnxruntime
sympy==1.12.1
# via onnxruntime
# via torch
torch==2.3.1
# via onnx-tests
typing-extensions==4.12.2
# via torch


@@ -1,75 +0,0 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
# with-sources: false
-e file:.
coloredlogs==15.0.1
# via onnxruntime
filelock==3.15.1
# via torch
flatbuffers==24.3.25
# via onnxruntime
fsspec==2024.6.0
# via torch
humanfriendly==10.0
# via coloredlogs
jinja2==3.1.4
# via torch
markupsafe==2.1.5
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.3
# via torch
numpy==1.26.4
# via onnx
# via onnxruntime
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.5.40
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
onnx==1.16.1
# via onnx-tests
onnxruntime==1.18.0
# via onnx-tests
packaging==24.1
# via onnxruntime
protobuf==5.27.1
# via onnx
# via onnxruntime
sympy==1.12.1
# via onnxruntime
# via torch
torch==2.3.1
# via onnx-tests
typing-extensions==4.12.2
# via torch


@@ -1 +0,0 @@
#![no_std]


@@ -1,16 +0,0 @@
Binary file not shown (serialized ONNX model produced by pytorch 2.8.0: graph `main_graph` with a single `Abs` node, input `onnx::Abs_0`, output `1`).


@@ -1,40 +0,0 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/abs/abs.onnx
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return torch.abs(x)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)
    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "abs.onnx"
    test_input = torch.tensor([[[[-1.0, -4.0, 9.0, -25.0]]]], device=device)
    torch.onnx.export(model, (test_input), onnx_name,
                      verbose=False, opset_version=16)
    print("Finished exporting model to {}".format(onnx_name))
    # Output some test data for use in the test
    print("Test input data: {}".format(test_input))
    output = model.forward(test_input)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()


@@ -1,27 +0,0 @@
// Import the shared macro
use crate::include_models;
include_models!(abs);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Tensor, TensorData, Tolerance, ops::FloatElem};
    use crate::backend::TestBackend;

    type FT = FloatElem<TestBackend>;

    #[test]
    fn abs() {
        let device = Default::default();
        let model: abs::Model<TestBackend> = abs::Model::new(&device);
        let input = Tensor::<TestBackend, 4>::from_floats([[[[-1.0, -4.0, 9.0, -25.0]]]], &device);
        let output = model.forward(input);
        let expected = TensorData::from([[[[1.0f32, 4.0, 9.0, 25.0]]]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }
}


@@ -1,57 +0,0 @@
#!/usr/bin/env python3
# used to generate model: onnx-tests/tests/add/add.onnx
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        # Initialize the module before assigning any attributes
        super(Model, self).__init__()
        # Declare a constant float tensor with ones
        self.a = torch.ones(1, 1, 1, 4)
        # Declare a scalar
        self.b = 5.0

    def forward(self, x, k):
        # Add a tensor input and a constant tensor
        x = x + self.a
        # Add a scalar constant and a scalar input
        d = self.b + k
        # Add a tensor and a scalar
        x = x + d
        return x


def main():
    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "add.onnx"
    dummy_input = torch.randn(1, 2, 3, 4, device=device)
    scalar = 2.0
    torch.onnx.export(model, (dummy_input, scalar), onnx_name,
                      verbose=False, opset_version=16)
    print("Finished exporting model to {}".format(onnx_name))
    # Output some test data for use in the test
    test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
    print("Test input data: {}, {}".format(test_input, scalar))
    output = model.forward(test_input, scalar)
    print("Test output data: {}".format(output))


if __name__ == '__main__':
    main()
