openmm-torch/python/tests/TestSerializeTorchForce.py
Raul 769302afd3 Add a constructor to TorchForce that takes a torch::jit::Module (#97)
* Add version number as a member to TorchForceProxy

* Encode the model file contents when serializing TorchForce

* Add tests for new TorchForce serialization

* Fix test not finding Python executable

* Format include directives correctly

* Hardcode TorchForceProxy version number

* Fix formatting issues

* Move Python serialization test to the correct place

* Make function encodeFromFileName static

* Update serialization python test to correctly remove temporary files after executing

* Use the base64 encoding capabilities of openssl to serialize model file

* Update TorchForce serializer

* Add a constructor to TorchForce that takes a torch::jit::Module.
 TorchForce(string fileName) is implemented by delegating to the new
 constructor.
 Update serialization test accordingly to compare the module file name
 and the module itself.

* Remove unnecessary include

* Change i_file to file in TorchForce constructor

* Add swig typemaps to new TorchForce constructor

* Add setup.py as a dependency for the PythonInstall CMake rule

* Fix swig out typemap for torch::jit::Module
 Now it is possible to call getModule() on a TorchForce object from
 Python, which will return a module of the same type as, for instance, torch.jit.load()

* Remove commented line in CMakeLists.txt

* Remove unnecessary dependency in setup.py

* Add more tests for new constructor

* Add some comments for the new constructor

* Updates to TorchForce serialization

* Use hex encoding instead of base64 for serialization.
SSL no longer a direct dependency.

* Remove unnecessary header

* Update Python serialization test

* Minor changes

* Improve temporary path handling in python serialization tests

* More informative exception when failing to serialize TorchForce

* Remove unnecessary check in TorchForce serialization

* Changes to C++ serialization tests
2023-02-10 14:58:03 -08:00
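The commit history shows the serializer moving from OpenSSL's base64 routines to plain hex encoding, which drops SSL as a direct dependency. Hex needs no library support, but it costs two output characters per input byte, compared with roughly four characters per three bytes for base64. The following is a minimal Python sketch of that trade-off using only the standard library. It is an illustration, not the C++ code used by TorchForceProxy, and the 1024-byte payload is an arbitrary stand-in for a TorchScript model file.

```python
import base64


def hex_encode(data: bytes) -> str:
    """Encode raw bytes as lowercase hexadecimal text (two characters per byte)."""
    return data.hex()


def hex_decode(text: str) -> bytes:
    """Invert hex_encode; bytes.fromhex raises ValueError on malformed input."""
    return bytes.fromhex(text)


def encoded_sizes(data: bytes) -> dict:
    """Compare the text length of the hex and base64 encodings of one payload."""
    return {
        "raw": len(data),
        "hex": len(hex_encode(data)),
        "base64": len(base64.b64encode(data)),
    }


if __name__ == "__main__":
    # Arbitrary 1024-byte stand-in for the contents of a saved model file.
    payload = bytes(range(256)) * 4
    assert hex_decode(hex_encode(payload)) == payload
    print(encoded_sizes(payload))  # {'raw': 1024, 'hex': 2048, 'base64': 1368}
```

For a 1024-byte file, hex produces 2048 characters and base64 produces 1368. The larger XML is the price paid for removing the OpenSSL dependency.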


import torch
import shutil
import pytest
from openmm import XmlSerializer, OpenMMException
from openmmtorch import TorchForce
import os
import tempfile


class ForceModule(torch.nn.Module):
    """A simple module that can be serialized"""

    def forward(self, positions):
        return torch.sum(positions**2)


class ForceModule2(torch.nn.Module):
    """A dummy module distinct from ForceModule"""

    def forward(self, positions):
        return torch.sum(positions**3)


def createAndSerialize(model_filename, serialized_filename):
    # Script and save the module, wrap it in a TorchForce, and write the XML to disk.
    module = torch.jit.script(ForceModule())
    module.save(model_filename)
    torch_force = TorchForce(model_filename)
    stored = XmlSerializer.serialize(torch_force)
    with open(serialized_filename, 'w') as f:
        f.write(stored)


def readXML(filename):
    with open(filename, 'r') as f:
        fileContents = f.read()
    return fileContents


def deserialize(filename):
    # Rebuild the force from its XML representation and hand it back to the caller.
    return XmlSerializer.deserialize(readXML(filename))


def test_serialize():
    with tempfile.TemporaryDirectory() as tempdir:
        model_filename = os.path.join(tempdir, 'model.pt')
        serialized_filename = os.path.join(tempdir, 'stored.xml')
        createAndSerialize(model_filename, serialized_filename)


def test_deserialize():
    with tempfile.TemporaryDirectory() as tempdir:
        model_filename = os.path.join(tempdir, 'model.pt')
        serialized_filename = os.path.join(tempdir, 'stored.xml')
        createAndSerialize(model_filename, serialized_filename)
        deserialize(serialized_filename)
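These tests only check that serialization and deserialization complete without raising. The serialized XML embeds the model file's contents, so a stricter property is that the payload survives even after the original model file has been deleted. The following standard-library sketch demonstrates that property using the same temporary-directory pattern as the tests. The `Force` element and the `file` and `encodedFileContents` attribute names are hypothetical illustrations, not the actual TorchForceProxy schema, and `serialize_model_file`, `deserialize_model_bytes`, and `round_trip` are hypothetical helpers.

```python
import os
import tempfile
import xml.etree.ElementTree as ET


# Hypothetical XML layout for illustration only; the real TorchForceProxy
# schema is not shown in this file.
def serialize_model_file(model_filename: str) -> str:
    """Embed the bytes of model_filename as hex text in a small XML document."""
    with open(model_filename, "rb") as f:
        contents = f.read()
    root = ET.Element("Force", type="TorchForce")
    root.set("file", os.path.basename(model_filename))
    root.set("encodedFileContents", contents.hex())
    return ET.tostring(root, encoding="unicode")


def deserialize_model_bytes(xml_text: str) -> bytes:
    """Recover the original model-file bytes from the XML produced above."""
    root = ET.fromstring(xml_text)
    return bytes.fromhex(root.get("encodedFileContents"))


def round_trip(payload: bytes) -> bytes:
    """Write payload to a temp file, serialize it, delete the file, then decode."""
    with tempfile.TemporaryDirectory() as tempdir:
        model_filename = os.path.join(tempdir, "model.pt")
        with open(model_filename, "wb") as f:
            f.write(payload)
        stored = serialize_model_file(model_filename)
        # The XML must not depend on the model file still existing.
        os.remove(model_filename)
    return deserialize_model_bytes(stored)
```

Calling `round_trip(data)` should return bytes equal to `data`. An analogous assertion against a real `TorchForce` would compare the deserialized force with the original, as the C++ tests in this change do with the module file name and the module itself.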