Make sure contexts are deselected before evaluation (#5279)

* Context deselection before energy evaluation

* Check that the correct context is popped by popAsCurrent()
This commit is contained in:
Evan Pretti
2026-05-04 16:26:18 -07:00
committed by GitHub
parent 0aee805025
commit 14f8b06118
9 changed files with 67 additions and 18 deletions

View File

@@ -49,6 +49,25 @@ private:
ComputeContext& context;
};
/**
* This class deselects a ComputeContext by calling popAsCurrent() on the
* context when it is created and pushAsCurrent() when it goes out of scope.
* This can be useful to temporarily undo the effect of a ContextSelector and
* must only be used when the context is already selected.
*/
class OPENMM_EXPORT_COMMON ContextDeselector {
public:
ContextDeselector(ComputeContext& context) : context(context) {
context.popAsCurrent();
}
~ContextDeselector() {
context.pushAsCurrent();
}
private:
ComputeContext& context;
};
} // namespace OpenMM
#endif /*OPENMM_CONTEXTSELECTOR_H_*/

View File

@@ -683,7 +683,10 @@ void CommonIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
}
else {
recordChangedParameters(context);
{
ContextDeselector deselector(cc);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
}
savedEnergy[forceGroups] = energy;
if (needsEnergyParamDerivs) {
context.getEnergyParameterDerivatives(energyParamDerivs);

View File

@@ -96,17 +96,17 @@ void CommonIntegrateNoseHooverStepKernel::initialize(const System& system, const
}
void CommonIntegrateNoseHooverStepKernel::execute(ContextImpl& context, const NoseHooverIntegrator& integrator) {
// If the atom reordering has occured, the forces from the previous step are permuted and thus invalid.
// They need to be either sorted or recomputed; here we choose the latter.
if (cc.getAtomsWereReordered())
context.calcForcesAndEnergy(true, false, integrator.getIntegrationForceGroups());
ContextSelector selector(cc);
IntegrationUtilities& integration = cc.getIntegrationUtilities();
int paddedNumAtoms = cc.getPaddedNumAtoms();
double dt = integrator.getStepSize();
cc.getIntegrationUtilities().setNextStepSize(dt);
// If the atom reordering has occured, the forces from the previous step are permuted and thus invalid.
// They need to be either sorted or recomputed; here we choose the latter.
if (cc.getAtomsWereReordered())
context.calcForcesAndEnergy(true, false, integrator.getIntegrationForceGroups());
const auto& atomList = integrator.getAllThermostatedIndividualParticles();
const auto& pairList = integrator.getAllThermostatedPairs();
int numAtoms = atomList.size();

View File

@@ -4709,7 +4709,7 @@ void CommonCalcCustomCPPForceKernel::initialize(const ContextImpl& context, Cust
forceGroupFlag = (1<<force.getOwner().getForceGroup());
useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls())
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL)
if (impl != &force && (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL))
useWorkerThread = false;
if (useWorkerThread) {
cc.addPreComputation(new StartCalculationPreComputation(*this));
@@ -4871,7 +4871,7 @@ void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const P
forceGroupFlag = (1<<force.getForceGroup());
useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls())
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL)
if (&impl->getOwner() != &force && (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL))
useWorkerThread = false;
if (useWorkerThread) {
cc.addPreComputation(new StartCalculationPreComputation(*this));

View File

@@ -547,7 +547,11 @@ void CommonMinimizeKernel::evaluateGpu(ContextImpl& context) {
// Evaluate the forces and energy for the desired interactions as well as
// harmonic restraints to emulate the constraints.
{
ContextDeselector deselector(cc);
energy = context.calcForcesAndEnergy(true, true, forceGroups);
}
if (numConstraints) {
if (mixedIsDouble) {
getConstraintEnergyForcesKernel->setArg(8, kRestraint);
@@ -592,7 +596,11 @@ double CommonMinimizeKernel::evaluateCpu(ContextImpl& context) {
cpuContext->setState(context.getOwner().getState(State::Parameters));
cpuContext->setPositions(hostPositions);
cpuContext->computeVirtualSites();
State state = cpuContext->getState(State::Energy | State::Forces, false, forceGroups);
State state;
{
ContextDeselector deselector(cc);
state = cpuContext->getState(State::Energy | State::Forces, false, forceGroups);
}
double hostEnergy = state.getPotentialEnergy();
const vector<Vec3>& hostForces = state.getForces();
@@ -676,8 +684,11 @@ bool CommonMinimizeKernel::report(ContextImpl& context, int iteration) {
args["system energy"] = energy - restraintEnergy;
args["restraint strength"] = kRestraint;
args["max constraint error"] = maxError;
{
ContextDeselector deselector(cc);
return reporter->report(iteration - 1, hostX, hostGrad, args);
}
}
void CommonMinimizeKernel::downloadReturnFlagStart() {
downloadStartEvent->enqueue();

View File

@@ -461,8 +461,11 @@ void CudaContext::pushAsCurrent() {
void CudaContext::popAsCurrent() {
CUcontext popped;
if (contextIsValid)
if (contextIsValid) {
cuCtxPopCurrent(&popped);
if (popped != context)
throw OpenMMException("Called popAsCurrent() on a context that is not current");
}
}
CUmodule CudaContext::createModule(const string source, const char* optimizationFlags) {

View File

@@ -1399,9 +1399,11 @@ void CommonCalcAmoebaMultipoleForceKernel::ensureMultipolesValid(ContextImpl& co
}
}
}
if (!multipolesAreValid)
if (!multipolesAreValid) {
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups());
}
}
void CommonCalcAmoebaMultipoleForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
ContextSelector selector(cc);
@@ -3487,9 +3489,11 @@ void CommonCalcHippoNonbondedForceKernel::ensureMultipolesValid(ContextImpl& con
}
}
}
if (!multipolesAreValid)
if (!multipolesAreValid) {
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(false, false, context.getIntegrator().getIntegrationForceGroups());
}
}
void CommonCalcHippoNonbondedForceKernel::getLabFramePermanentDipoles(ContextImpl& context, vector<Vec3>& dipoles) {
ContextSelector selector(cc);

View File

@@ -502,7 +502,10 @@ void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double to
int numDrude = drudeParams.getSize();
int paddedNumAtoms = cc.getPaddedNumAtoms();
for (int iteration = 0; iteration < 50; iteration++) {
{
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(true, false, context.getIntegrator().getIntegrationForceGroups());
}
minimizeKernel->execute(drudeParams.getSize());
cc.getLongForceBuffer().download(forces);
double totalForce = 0;

View File

@@ -297,7 +297,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) {
context.getPeriodicBoxVectors(finalBox[0], finalBox[1], finalBox[2]);
if (initialBox[0] != finalBox[0] || initialBox[1] != finalBox[1] || initialBox[2] != finalBox[2])
throw OpenMMException("Standard barostats cannot be used with RPMDIntegrator. Use RPMDMonteCarloBarostat instead.");
{
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(true, false, groupsNotContracted);
}
copyFromContextKernel->setArg(7, i);
copyFromContextKernel->execute(cc.getNumAtoms());
}
@@ -322,7 +325,10 @@ void CommonIntegrateRPMDStepKernel::computeForces(ContextImpl& context) {
copyToContextKernel->setArg(5, i);
copyToContextKernel->execute(cc.getNumAtoms());
context.computeVirtualSites();
{
ContextDeselector deselector(cc);
context.calcForcesAndEnergy(true, false, groupFlags);
}
copyFromContextKernel->setArg(7, i);
copyFromContextKernel->execute(cc.getNumAtoms());
}