mirror of
https://github.com/openmm/openmm-torch.git
synced 2026-03-10 19:14:16 +09:00
* Update Mac build to newer versions * Fixed filename * Debugging * Try a more tolerant check
130 lines
5.0 KiB
Python
130 lines
5.0 KiB
Python
import openmm as mm
|
|
import openmm.unit as unit
|
|
import openmmtorch as ot
|
|
import numpy as np
|
|
import pytest
|
|
import torch as pt
|
|
from tempfile import NamedTemporaryFile
|
|
|
|
@pytest.mark.parametrize('model_file,',
                         ['../../tests/central.pt',
                          '../../tests/forces.pt'])
def testConstructors(model_file):
    """Exercise every way of building a TorchForce for the given model file.

    The test passes as long as none of the constructors raises:
    a file path, a module freshly loaded with ``torch.jit.load``, and a
    module handed back by ``TorchForce.getModule()``.
    """
    # 1. Construct directly from the path of the serialized model.
    from_path = ot.TorchForce(model_file)

    # 2. Construct from an in-memory module loaded from the same file.
    loaded_module = pt.jit.load(model_file)
    from_module = ot.TorchForce(pt.jit.load(model_file))

    # 3. Round trip: retrieve the module from a force and build a new force from it.
    recovered_module = from_module.getModule()
    from_recovered = ot.TorchForce(recovered_module)
|
|
|
|
@pytest.mark.parametrize('model_file, output_forces, use_module_constructor',
                         [('../../tests/central.pt', False, False,),
                          ('../../tests/forces.pt', True, False),
                          ('../../tests/forces.pt', True, True)])
@pytest.mark.parametrize('use_cv_force', [True, False])
@pytest.mark.parametrize('platform', [mm.Platform.getPlatform(i).getName() for i in range(mm.Platform.getNumPlatforms())])
def testForce(model_file, output_forces, use_module_constructor, use_cv_force, platform):
    """Compare the energy and forces reported through TorchForce with E(r) = |r|^2.

    Parameters
    ----------
    model_file
        Path of the serialized TorchScript model to evaluate.
    output_forces
        Value passed to ``TorchForce.setOutputsForces``.
    use_module_constructor
        If true, build the force from a module loaded with ``torch.jit.load``;
        otherwise build it from ``model_file`` directly.
    use_cv_force
        If true, wrap the TorchForce in a ``CustomCVForce`` before adding it to
        the System.
    platform
        Name of the OpenMM platform on which to create the Context.

    The test is skipped when the CUDA platform is requested without a CUDA
    device visible to PyTorch, or when the Context cannot be created on the
    requested platform.
    """
    if pt.cuda.device_count() < 1 and platform == 'CUDA':
        pytest.skip('A CUDA device is not available')

    # Create a random cloud of particles.
    numParticles = 10
    system = mm.System()
    positions = np.random.rand(numParticles, 3)
    for _ in range(numParticles):
        system.addParticle(1.0)

    # Create a force
    if use_module_constructor:
        model = pt.jit.load(model_file)
        force = ot.TorchForce(model, {'useCUDAGraphs': 'false'})
    else:
        force = ot.TorchForce(model_file, {'useCUDAGraphs': 'false'})
    assert not force.getOutputsForces()  # Check the default
    force.setOutputsForces(output_forces)
    assert force.getOutputsForces() == output_forces
    assert force.getProperties()['useCUDAGraphs'] == 'false'
    if use_cv_force:
        # Wrap TorchForce into CustomCVForce
        cv_force = mm.CustomCVForce('force')
        cv_force.addCollectiveVariable('force', force)
        system.addForce(cv_force)
    else:
        system.addForce(force)

    # Compute the forces and energy.
    integ = mm.VerletIntegrator(1.0)
    try:
        context = mm.Context(system, integ, mm.Platform.getPlatformByName(platform))
    except Exception:
        # Catch ordinary errors only: a bare ``except`` would also turn
        # KeyboardInterrupt and SystemExit into a skipped test.
        pytest.skip(f'Unable to create Context with {platform}')
    context.setPositions(positions)
    state = context.getState(getEnergy=True, getForces=True)

    # See if the energy and forces are correct.  The network defines a potential of the form E(r) = |r|^2
    expectedEnergy = np.sum(positions*positions)
    assert np.allclose(expectedEnergy, state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole))
    # The force is the negative gradient of the energy: -d(|r|^2)/dr = -2r.
    assert np.allclose(-2*positions, state.getForces(asNumpy=True))
|
|
|
|
@pytest.mark.parametrize('deviceString', ['cpu', 'cuda:0', 'cuda:1'])
@pytest.mark.parametrize('precision', ['single', 'mixed', 'double'])
def testModuleArguments(deviceString, precision):
    """Check the device and dtype of the tensors a TorchForce module receives.

    A small module is scripted, saved to a temporary file and evaluated through
    an OpenMM Context. Its ``forward`` asserts that both its registered buffer
    and the incoming positions are on the requested device, and that the
    positions have the expected dtype and values.
    """
    visible_gpus = pt.cuda.device_count()
    if deviceString == 'cuda:0' and visible_gpus < 1:
        pytest.skip('A CUDA device is not available')
    if deviceString == 'cuda:1' and visible_gpus < 2:
        pytest.skip('Two CUDA devices are not available')

    class TestModule(pt.nn.Module):
        """Module that validates its inputs and returns the sum of the positions."""

        def __init__(self, device, dtype, positions):
            super().__init__()
            self.device = device
            self.dtype = dtype
            # Reference copy of the positions, stored as a buffer of the module.
            self.register_buffer('positions', pt.tensor(positions).to(dtype))

        def forward(self, positions):
            # The buffer and the input must both be on the requested device.
            assert self.positions.device == self.device
            assert positions.device == self.device
            # The input must have the expected dtype and match the reference values.
            assert positions.dtype == self.dtype
            assert pt.allclose(positions, self.positions)
            return pt.sum(positions)

    # Random particle cloud; every particle has unit mass.
    numParticles = 10
    system = mm.System()
    positions = np.random.rand(numParticles, 3)
    for _ in range(numParticles):
        system.addParticle(1.0)

    # float64 is expected on the CPU or with double precision, float32 otherwise.
    device = pt.device(deviceString)
    dtype = pt.float64 if (device.type == 'cpu' or precision == 'double') else pt.float32
    module = TestModule(device, dtype, positions)

    with NamedTemporaryFile() as model_fd:
        # Serialize the scripted module and load it back through TorchForce.
        pt.jit.script(module).save(model_fd.name)
        system.addForce(ot.TorchForce(model_fd.name))

        # The OpenMM platform name is the upper-cased torch device type.
        platform = mm.Platform.getPlatformByName(device.type.upper())
        properties = {}
        if device.type == 'cuda':
            properties = {'DeviceIndex': str(device.index),
                          'Precision': precision}
        integrator = mm.VerletIntegrator(1.0)
        context = mm.Context(system, integrator, platform, properties)

        # Evaluating the state runs forward(), which performs the checks.
        context.setPositions(positions)
        context.getState(getEnergy=True, getForces=True)
|
|
|
|
|
|
def testProperties():
    """Check that a property set with setProperty is returned by getProperties."""
    force = ot.TorchForce('../../tests/central.pt')
    # Set the property to each value in turn and read it back.
    for expected in ('true', 'false'):
        force.setProperty('useCUDAGraphs', expected)
        assert force.getProperties()['useCUDAGraphs'] == expected
|