mirror of
https://github.com/openmm/openmm
synced 2026-06-03 06:39:48 +09:00
74 lines
2.9 KiB
Python
74 lines
2.9 KiB
Python
import os
|
|
import unittest
|
|
import tempfile
|
|
from io import BytesIO, StringIO
|
|
from openmm import app
|
|
import openmm as mm
|
|
from openmm import unit
|
|
|
|
|
|
class TestCheckpointReporter(unittest.TestCase):
|
|
def setUp(self):
|
|
with open('systems/alanine-dipeptide-implicit.pdb') as f:
|
|
pdb = app.PDBFile(f)
|
|
forcefield = app.ForceField('amber99sbildn.xml')
|
|
system = forcefield.createSystem(pdb.topology,
|
|
nonbondedMethod=app.CutoffNonPeriodic, nonbondedCutoff=1.0*unit.nanometers,
|
|
constraints=app.HBonds)
|
|
self.simulation = app.Simulation(pdb.topology, system, mm.VerletIntegrator(0.002*unit.picoseconds))
|
|
self.simulation.context.setPositions(pdb.positions)
|
|
|
|
def test_1(self):
|
|
"""Test CheckpointReporter."""
|
|
for writeState in [True, False]:
|
|
with tempfile.TemporaryDirectory() as tempdir:
|
|
filename = os.path.join(tempdir, 'checkpoint')
|
|
self.simulation.reporters.clear()
|
|
self.simulation.reporters.append(app.CheckpointReporter(filename, 1, writeState=writeState))
|
|
self.simulation.step(1)
|
|
|
|
# get the current positions
|
|
positions = self.simulation.context.getState(getPositions=True).getPositions()
|
|
|
|
# now set the positions into junk...
|
|
self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions))
|
|
|
|
# then reload the right positions from the checkpoint
|
|
if writeState:
|
|
self.simulation.loadState(filename)
|
|
else:
|
|
self.simulation.loadCheckpoint(filename)
|
|
|
|
newPositions = self.simulation.context.getState(getPositions=True).getPositions()
|
|
self.assertSequenceEqual(positions, newPositions)
|
|
|
|
def testFileObj(self):
|
|
"""Test writing to a file object. This should truncate so that only the most recent frame is present in the output."""
|
|
|
|
# Test checkpoint saving.
|
|
|
|
checkpointBuffer = BytesIO()
|
|
self.simulation.reporters.clear()
|
|
self.simulation.reporters.append(app.CheckpointReporter(checkpointBuffer, 1, writeState=False))
|
|
self.simulation.step(5)
|
|
checkpointData = checkpointBuffer.getvalue()
|
|
|
|
checkpointBuffer = BytesIO()
|
|
self.simulation.saveCheckpoint(checkpointBuffer)
|
|
self.assertSequenceEqual(checkpointData, checkpointBuffer.getvalue())
|
|
|
|
# Test state saving.
|
|
|
|
stateBuffer = StringIO()
|
|
self.simulation.reporters.clear()
|
|
self.simulation.reporters.append(app.CheckpointReporter(stateBuffer, 1, writeState=True))
|
|
self.simulation.step(5)
|
|
stateData = stateBuffer.getvalue()
|
|
|
|
stateBuffer = StringIO()
|
|
self.simulation.saveState(stateBuffer)
|
|
self.assertSequenceEqual(stateData, stateBuffer.getvalue())
|
|
|
|
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|