mirror of
https://github.com/openmm/openmm
synced 2026-06-03 06:39:48 +09:00
Make sure contexts are deselected before evaluation (#5279)
* Context deselection before energy evaluation * Check that the correct context is popped by popAsCurrent()
This commit is contained in:
@@ -49,6 +49,25 @@ private:
|
||||
ComputeContext& context;
|
||||
};
|
||||
|
||||
/**
|
||||
* This class deselects a ComputeContext by calling popAsCurrent() on the
|
||||
* context when it is created and pushAsCurrent() when it goes out of scope.
|
||||
* This can be useful to temporarily undo the effect of a ContextSelector and
|
||||
* must only be used when the context is already selected.
|
||||
*/
|
||||
|
||||
class OPENMM_EXPORT_COMMON ContextDeselector {
|
||||
public:
|
||||
ContextDeselector(ComputeContext& context) : context(context) {
|
||||
context.popAsCurrent();
|
||||
}
|
||||
~ContextDeselector() {
|
||||
context.pushAsCurrent();
|
||||
}
|
||||
private:
|
||||
ComputeContext& context;
|
||||
};
|
||||
|
||||
} // namespace OpenMM
|
||||
|
||||
#endif /*OPENMM_CONTEXTSELECTOR_H_*/
|
||||
|
||||
@@ -683,7 +683,10 @@ void CommonIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
|
||||
}
|
||||
else {
|
||||
recordChangedParameters(context);
|
||||
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
|
||||
{
|
||||
ContextDeselector deselector(cc);
|
||||
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
|
||||
}
|
||||
savedEnergy[forceGroups] = energy;
|
||||
if (needsEnergyParamDerivs) {
|
||||
context.getEnergyParameterDerivatives(energyParamDerivs);
|
||||
|
||||
@@ -96,17 +96,17 @@ void CommonIntegrateNoseHooverStepKernel::initialize(const System& system, const
|
||||
}
|
||||
|
||||
void CommonIntegrateNoseHooverStepKernel::execute(ContextImpl& context, const NoseHooverIntegrator& integrator) {
|
||||
// If the atom reordering has occured, the forces from the previous step are permuted and thus invalid.
|
||||
// They need to be either sorted or recomputed; here we choose the latter.
|
||||
if (cc.getAtomsWereReordered())
|
||||
context.calcForcesAndEnergy(true, false, integrator.getIntegrationForceGroups());
|
||||
|
||||
ContextSelector selector(cc);
|
||||
IntegrationUtilities& integration = cc.getIntegrationUtilities();
|
||||
int paddedNumAtoms = cc.getPaddedNumAtoms();
|
||||
double dt = integrator.getStepSize();
|
||||
cc.getIntegrationUtilities().setNextStepSize(dt);
|
||||
|
||||
// If the atom reordering has occured, the forces from the previous step are permuted and thus invalid.
|
||||
// They need to be either sorted or recomputed; here we choose the latter.
|
||||
if (cc.getAtomsWereReordered())
|
||||
context.calcForcesAndEnergy(true, false, integrator.getIntegrationForceGroups());
|
||||
|
||||
const auto& atomList = integrator.getAllThermostatedIndividualParticles();
|
||||
const auto& pairList = integrator.getAllThermostatedPairs();
|
||||
int numAtoms = atomList.size();
|
||||
|
||||
@@ -4709,7 +4709,7 @@ void CommonCalcCustomCPPForceKernel::initialize(const ContextImpl& context, Cust
|
||||
forceGroupFlag = (1<<force.getOwner().getForceGroup());
|
||||
useWorkerThread = (cc.getNumContexts() == 1);
|
||||
for (const ForceImpl* impl : context.getForceImpls())
|
||||
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL)
|
||||
if (impl != &force && (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL))
|
||||
useWorkerThread = false;
|
||||
if (useWorkerThread) {
|
||||
cc.addPreComputation(new StartCalculationPreComputation(*this));
|
||||
@@ -4871,7 +4871,7 @@ void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const P
|
||||
forceGroupFlag = (1<<force.getForceGroup());
|
||||
useWorkerThread = (cc.getNumContexts() == 1);
|
||||
for (const ForceImpl* impl : context.getForceImpls())
|
||||
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL)
|
||||
if (&impl->getOwner() != &force && (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL))
|
||||
useWorkerThread = false;
|
||||
if (useWorkerThread) {
|
||||
cc.addPreComputation(new StartCalculationPreComputation(*this));
|
||||
|
||||
@@ -547,7 +547,11 @@ void CommonMinimizeKernel::evaluateGpu(ContextImpl& context) {
|
||||
// Evaluate the forces and energy for the desired interactions as well as
|
||||
// harmonic restraints to emulate the constraints.
|
||||
|
||||
energy = context.calcForcesAndEnergy(true, true, forceGroups);
|
||||
{
|
||||
ContextDeselector deselector(cc);
|
||||
energy = context.calcForcesAndEnergy(true, true, forceGroups);
|
||||
}
|
||||
|
||||
if (numConstraints) {
|
||||
if (mixedIsDouble) {
|
||||
getConstraintEnergyForcesKernel->setArg(8, kRestraint);
|
||||
@@ -592,7 +596,11 @@ double CommonMinimizeKernel::evaluateCpu(ContextImpl& context) {
|
||||
cpuContext->setState(context.getOwner().getState(State::Parameters));
|
||||
cpuContext->setPositions(hostPositions);
|
||||
cpuContext->computeVirtualSites();
|
||||
State state = cpuContext->getState(State::Energy | State::Forces, false, forceGroups);
|
||||
State state;
|
||||
{
|
||||
ContextDeselector deselector(cc);
|
||||
state = cpuContext->getState(State::Energy | State::Forces, false, forceGroups);
|
||||
}
|
||||
double hostEnergy = state.getPotentialEnergy();
|
||||
const vector<Vec3>& hostForces = state.getForces();
|
||||
|
||||
@@ -676,7 +684,10 @@ bool CommonMinimizeKernel::report(ContextImpl& context, int iteration) {
|
||||
args["system energy"] = energy - restraintEnergy;
|
||||
args["restraint strength"] = kRestraint;
|
||||
args["max constraint error"] = maxError;
|
||||
return reporter->report(iteration - 1, hostX, hostGrad, args);
|
||||
{
|
||||
ContextDeselector deselector(cc);
|
||||
return reporter->report(iteration - 1, hostX, hostGrad, args);
|
||||
}
|
||||
}
|
||||
|
||||
void CommonMinimizeKernel::downloadReturnFlagStart() {
|
||||
|
||||
@@ -461,8 +461,11 @@ void CudaContext::pushAsCurrent() {
|
||||
|
||||
void CudaContext::popAsCurrent() {
|
||||
CUcontext popped;
|
||||
if (contextIsValid)
|
||||
if (contextIsValid) {
|
||||
cuCtxPopCurrent(&popped);
|
||||
if (popped != context)
|
||||
throw OpenMMException("Called popAsCurrent() on a context that is not current");
|
||||
}
|
||||
}
|
||||
|
||||
CUmodule CudaContext::createModule(const string source, const char* optimizationFlags) {
|
||||
@@ -886,4 +889,4 @@ void CudaContext::ensureCudaInitialized() {
|
||||
CHECK_RESULT2(cuInit(0), "Error initializing CUDA");
|
||||
hasInitializedCuda = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1399,8 +1399,10 @@ void CommonCalcAmoebaMultipoleForceKernel::ensureMultipolesValid(ContextImpl& co
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!multipolesAreValid)
|
||||
if (!multipolesAreValid) {
|
||||
ContextDeselector deselector(cc);
|
||||
context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups());
|
||||
}
|
||||
}
|
||||
|
||||
void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
|
||||
@@ -3487,8 +3489,10 @@ void CommonCalcHippoNonbondedForceKernel::ensureMultipolesValid(ContextImpl& con
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!multipolesAreValid)
|
||||
if (!multipolesAreValid) {
|
||||
ContextDeselector deselector(cc);
|
||||
context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups());
|
||||
}
|
||||
}
|
||||
|
||||
void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
|
||||
|
||||
@@ -502,7 +502,10 @@ void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double to
|
||||
int numDrude = drudeParams.getSize();
|
||||
int paddedNumAtoms = cc.getPaddedNumAtoms();
|
||||
for (int iteration = 0; iteration < 50; iteration++) {
|
||||
context.calcForcesAndEnergy(true, false, context.getIntegrator().getIntegrationForceGroups());
|
||||
{
|
||||
ContextDeselector deselector(cc);
|
||||
context.calcForcesAndEnergy(true, false, context.getIntegrator().getIntegrationForceGroups());
|
||||
}
|
||||
minimizeKernel->execute(drudeParams.getSize());
|
||||
cc.getLongForceBuffer().download(forces);
|
||||
double totalForce = 0;
|
||||
|
||||
@@ -297,7 +297,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) {
|
||||
context.getPeriodicBoxVectors(finalBox[0], finalBox[1], finalBox[2]);
|
||||
if (initialBox[0] != finalBox[0] || initialBox[1] != finalBox[1] || initialBox[2] != finalBox[2])
|
||||
throw OpenMMException("Standard barostats cannot be used with RPMDIntegrator. Use RPMDMonteCarloBarostat instead.");
|
||||
context.calcForcesAndEnergy(true, false, groupsNotContracted);
|
||||
{
|
||||
ContextDeselector deselector(cc);
|
||||
context.calcForcesAndEnergy(true, false, groupsNotContracted);
|
||||
}
|
||||
copyFromContextKernel->setArg(7, i);
|
||||
copyFromContextKernel->execute(cc.getNumAtoms());
|
||||
}
|
||||
@@ -322,7 +325,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) {
|
||||
copyToContextKernel->setArg(5, i);
|
||||
copyToContextKernel->execute(cc.getNumAtoms());
|
||||
context.computeVirtualSites();
|
||||
context.calcForcesAndEnergy(true, false, groupFlags);
|
||||
{
|
||||
ContextDeselector deselector(cc);
|
||||
context.calcForcesAndEnergy(true, false, groupFlags);
|
||||
}
|
||||
copyFromContextKernel->setArg(7, i);
|
||||
copyFromContextKernel->execute(cc.getNumAtoms());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user