mirror of
https://github.com/openmm/openmm-torch.git
synced 2026-03-10 19:14:16 +09:00
* Add version number as a member to TorchForceProxy * Encode the model file contents when serializing TorchForce * Add tests for new TorchForce serialization * Fix test not finding Python executable * Format include directives correctly * Hardcode TorchForceProxy version number * Fix formatting issues * Move Python serialization test to the correct place * Make function encodeFromFileName static * Update Python serialization test to correctly remove temporary files after executing * Use the base64 encoding capabilities of OpenSSL to serialize model file * Update TorchForce serializer * Add a constructor to TorchForce that takes a torch::jit::Module. TorchForce(string fileName) is implemented by delegating to the new constructor. Update serialization test accordingly to compare the module file name and the module itself. * Remove unnecessary include * Change i_file to file in TorchForce constructor * Add SWIG typemaps to new TorchForce constructor * Add setup.py as a dependency for the PythonInstall CMake rule * Fix SWIG out typemap for torch::jit::Module. Now it is possible to call getModule() on a TorchForce object from Python, which will return a module of the same type as that returned by, for instance, torch.jit.load(). * Remove commented line in CMakeLists.txt * Remove unnecessary dependency in setup.py * Add more tests for new constructor * Add some comments for the new constructor * Updates to TorchForce serialization * Use hex encoding instead of base64 for serialization. SSL is no longer a direct dependency. * Remove unnecessary header * Update Python serialization test * Minor changes * Improve temporary path handling in Python serialization tests * More informative exception when failing to serialize TorchForce * Remove unnecessary check in TorchForce serialization * Changes to C++ serialization tests * Changes to C++ serialization tests
49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
import torch
|
|
import shutil
|
|
import pytest
|
|
from openmm import XmlSerializer, OpenMMException
|
|
from openmmtorch import TorchForce
|
|
import os
|
|
import tempfile
|
|
|
|
class ForceModule(torch.nn.Module):
    """A simple module that can be serialized.

    Its forward pass reduces the input tensor to a scalar: the sum of the
    squares of every element.
    """

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        squared = positions.pow(2)
        return squared.sum()
|
|
|
|
|
|
class ForceModule2(torch.nn.Module):
    """A dummy module distinct from ForceModule.

    Its forward pass returns the sum of the cubes of every element of the
    input tensor, so its output differs from ForceModule's for most inputs.
    """

    def forward(self, positions):
        return torch.sum(positions**3)
|
|
|
|
|
|
def createAndSerialize(model_filename, serialized_filename):
    """Build a TorchForce from a freshly saved model and write its XML form.

    A scripted ForceModule is saved to ``model_filename``. A TorchForce is
    constructed from that file, and the string returned by
    ``XmlSerializer.serialize`` is written to ``serialized_filename``.
    """
    scripted_model = torch.jit.script(ForceModule())
    scripted_model.save(model_filename)

    xml_text = XmlSerializer.serialize(TorchForce(model_filename))
    with open(serialized_filename, 'w') as out_file:
        out_file.write(xml_text)
|
|
|
|
def readXML(filename):
    """Return the entire text content of ``filename``."""
    with open(filename, 'r') as handle:
        return handle.read()
|
|
|
|
def deserialize(filename):
    """Read the XML in ``filename`` and reconstruct the object it describes.

    Returns:
        The object produced by ``XmlSerializer.deserialize``, so callers can
        inspect it. Callers that ignore the return value behave as before.

    Any exception raised while reading the file or deserializing it
    propagates to the caller.
    """
    return XmlSerializer.deserialize(readXML(filename))
|
|
|
|
def test_serialize():
    """Serializing a TorchForce built from a saved model writes non-empty XML."""
    with tempfile.TemporaryDirectory() as tempdir:
        model_filename = os.path.join(tempdir, 'model.pt')
        serialized_filename = os.path.join(tempdir, 'stored.xml')
        createAndSerialize(model_filename, serialized_filename)
        # Check that serialization actually produced output, not merely that
        # no exception was raised along the way.
        assert os.path.getsize(serialized_filename) > 0
|
|
|
|
def test_deserialize():
    """A serialized TorchForce can be deserialized.

    Deserialization must also succeed after the model file the force was
    built from has been deleted.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        model_filename = os.path.join(tempdir, 'model.pt')
        serialized_filename = os.path.join(tempdir, 'stored.xml')
        createAndSerialize(model_filename, serialized_filename)
        deserialize(serialized_filename)
        # The serialized XML is expected to embed the encoded model contents.
        # Deserialization should therefore not depend on the original model
        # file still existing.
        os.remove(model_filename)
        deserialize(serialized_filename)
|