mirror of
https://github.com/tracel-ai/burn.git
synced 2026-05-31 19:49:48 +09:00
* Move ONNX import into burn-onnx crate * Update publish * Update burn-import -> burn-onnx * Fix clippy warnings that are no longer allowed * Allow unused * Update contributor book references to burn-onnx Update all burn-import path references in the ONNX development guide to point to the new burn-onnx crate location. * Remove ONNX integration steps from burn op guide Deleted the section detailing how to add a new operation to burn-onnx, including ONNX IR and code generation mapping steps. This streamlines the guide and removes outdated or redundant ONNX-specific instructions. * Update onnx-ir references from burn-import to burn-onnx Update documentation and code comments to reference the new burn-onnx crate instead of burn-import. * Update ONNX test producer name to burn-onnx-test Update producer_name metadata in Python test scripts from "burn-import-test" to "burn-onnx-test" for consistency. * Undo ONNX file changes
40 lines
937 B
Python
40 lines
937 B
Python
#!/usr/bin/env python3
|
|
|
|
# used to generate model: onnx-tests/tests/abs/abs.onnx
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Model(nn.Module):
    """Minimal module whose forward pass is an element-wise absolute value."""

    def __init__(self):
        # Zero-argument super() is the Python 3 form of super(Model, self).
        super().__init__()

    def forward(self, x):
        """Return ``torch.abs(x)`` for the input tensor ``x``."""
        return torch.abs(x)
|
|
|
|
|
|
def main():
    """Export the abs test model to ONNX and print reference test data.

    Writes ``abs.onnx`` (opset 16) using a relative path, then prints the
    fixed input tensor and the model's output for it, for use in the test.
    """
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")
    onnx_name = "abs.onnx"
    test_input = torch.tensor([[[[-1.0, -4.0, 9.0, -25.0]]]], device=device)

    # Example inputs are passed as an explicit one-element tuple; a bare
    # parenthesised tensor is not a tuple.
    torch.onnx.export(
        model,
        (test_input,),
        onnx_name,
        verbose=False,
        opset_version=16,
    )

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    # Invoke the module itself (not ``forward`` directly) so that any
    # registered hooks are honoured.
    output = model(test_input)
    print(f"Test output data: {output}")
|
|
|
|
|
|
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()