Files
openmm/tests/TestParser.cpp
Peter Eastman aafb8b5b0c Vectorize custom expressions on CPU (#3552)
* Began implementing vectorization of Lepton expressions

* Tests for vector expressions

* Implemented CompiledVectorExpression for x86

* Bug fix

* Optimized select() on ARM

* Optimized select() on x86

* CompiledVectorExpression supports AVX

* Bug fix

* Updated docs

* Use VEX encoded instructions for CompiledExpression

* Optimized min() and max() on x86

* Optimized min() and max() on ARM

* Fixed compilation error

* Upgrade AsmJit
2022-04-10 13:51:38 -07:00

388 lines
16 KiB
C++

#include "../libraries/lepton/include/Lepton.h"
#include "openmm/internal/AssertionUtilities.h"
#include "lepton/CompiledVectorExpression.h"
#include <iostream>
#include <limits>
#include <map>
using namespace Lepton;
using namespace OpenMM;
using namespace std;
/**
* This is a custom function equal to f(x,y) = 2*x*y.
*/
class ExampleFunction : public CustomFunction {
int getNumArguments() const {
return 2;
}
double evaluate(const double* arguments) const {
return 2.0*arguments[0]*arguments[1];
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
if (derivOrder[0] == 1) {
if (derivOrder[1] == 0)
return 2.0*arguments[1];
else if (derivOrder[1] == 1)
return 2.0;
}
if (derivOrder[1] == 1 && derivOrder[0] == 0)
return 2.0*arguments[0];
return 0.0;
}
CustomFunction* clone() const {
return new ExampleFunction();
}
};
/**
* Verify that an expression gives the correct value.
*/
void verifyEvaluation(const string& expression, double expectedValue) {
map<string, CustomFunction*> customFunctions;
ParsedExpression parsed = Parser::parse(expression, customFunctions);
double value = parsed.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Try optimizing it and make sure the result is still correct.
value = parsed.optimize().evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Create an ExpressionProgram and see if that also gives the same result.
ExpressionProgram program = parsed.createProgram();
value = program.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Create a CompiledExpression and see if that also gives the same result.
CompiledExpression compiled = parsed.createCompiledExpression();
value = compiled.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
}
/**
* Verify that an expression with variables gives the correct value.
*/
void verifyEvaluation(const string& expression, double x, double y, double expectedValue) {
map<string, double> variables;
variables["x"] = x;
variables["y"] = y;
ParsedExpression parsed = Parser::parse(expression);
double value = parsed.evaluate(variables);
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Try optimizing it and make sure the result is still correct.
value = parsed.optimize().evaluate(variables);
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Try optimizing with predefined values for the variables.
value = parsed.optimize(variables).evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Create an ExpressionProgram and see if that also gives the same result.
ExpressionProgram program = parsed.createProgram();
value = program.evaluate(variables);
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Create a CompiledExpression and see if that also gives the same result.
CompiledExpression compiled = parsed.createCompiledExpression();
if (compiled.getVariables().find("x") != compiled.getVariables().end())
compiled.getVariableReference("x") = x;
if (compiled.getVariables().find("y") != compiled.getVariables().end())
compiled.getVariableReference("y") = y;
value = compiled.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Try specifying memory locations for the compiled expression.
map<string, double*> variablePointers;
variablePointers["x"] = &x;
variablePointers["y"] = &y;
CompiledExpression compiled2 = parsed.createCompiledExpression();
compiled2.setVariableLocations(variablePointers);
value = compiled2.evaluate();
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
ASSERT_EQUAL(&x, &compiled2.getVariableReference("x"));
ASSERT_EQUAL(&y, &compiled2.getVariableReference("y"));
// Try evaluating it as a vector.
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector = parsed.createCompiledVectorExpression(width);
for (int i = 0; i < width; i++) {
if (vector.getVariables().find("x") != vector.getVariables().end())
for (int j = 0; j < width; j++)
vector.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector.getVariables().find("y") != vector.getVariables().end())
for (int j = 0; j < width; j++)
vector.getVariablePointer("y")[j] = (i == j ? y : -100.0);
const float* result = vector.evaluate();
ASSERT_EQUAL_TOL(expectedValue, result[i], 1e-6);
}
}
// Specify memory locations for the vector expression.
float xvec[8], yvec[8];
map<string, float*> vecVariablePointers;
vecVariablePointers["x"] = xvec;
vecVariablePointers["y"] = yvec;
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector2 = parsed.createCompiledVectorExpression(width);
vector2.setVariableLocations(vecVariablePointers);
for (int i = 0; i < width; i++) {
for (int j = 0; j < width; j++) {
xvec[j] = (i == j ? x : -100.0);
yvec[j] = (i == j ? y : -100.0);
}
const float* result = vector2.evaluate();
ASSERT_EQUAL_TOL(expectedValue, result[i], 1e-6);
}
}
// Make sure that variable renaming works.
variables.clear();
variables["w"] = x;
variables["y"] = y;
map<string, string> replacements;
replacements["x"] = "w";
value = parsed.renameVariables(replacements).evaluate(variables);
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
}
/**
* Confirm that a parse error gets thrown.
*/
void verifyInvalidExpression(const string& expression) {
try {
Parser::parse(expression);
}
catch (const exception& ex) {
return;
}
throw exception();
}
/**
* Verify that two numbers have the same value.
*/
void assertNumbersEqual(double val1, double val2, double tol=1e-10) {
const double inf = numeric_limits<double>::infinity();
if (val1 == val1 || val2 == val2) // If both are NaN, that's fine.
if (val1 != inf || val2 != inf) // Both infinity is also fine.
if (val1 != -inf || val2 != -inf) // Same for -infinity.
ASSERT_EQUAL_TOL(val1, val2, tol);
}
/**
* Verify that two expressions give the same value.
*/
void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2, double x, double y) {
map<string, double> variables;
variables["x"] = x;
variables["y"] = y;
double val1 = exp1.evaluate(variables);
double val2 = exp2.evaluate(variables);
assertNumbersEqual(val1, val2);
// Now create CompiledExpressions from them and see if those also match.
CompiledExpression compiled1 = exp1.createCompiledExpression();
CompiledExpression compiled2 = exp2.createCompiledExpression();
if (compiled1.getVariables().find("x") != compiled1.getVariables().end())
compiled1.getVariableReference("x") = x;
if (compiled1.getVariables().find("y") != compiled1.getVariables().end())
compiled1.getVariableReference("y") = y;
if (compiled2.getVariables().find("x") != compiled2.getVariables().end())
compiled2.getVariableReference("x") = x;
if (compiled2.getVariables().find("y") != compiled2.getVariables().end())
compiled2.getVariableReference("y") = y;
assertNumbersEqual(val1, compiled1.evaluate());
assertNumbersEqual(val2, compiled2.evaluate());
// Now check CompiledVectorizedExpressions.
for (int width : CompiledVectorExpression::getAllowedWidths()) {
CompiledVectorExpression vector1 = exp1.createCompiledVectorExpression(width);
CompiledVectorExpression vector2 = exp2.createCompiledVectorExpression(width);
for (int i = 0; i < width; i++) {
if (vector1.getVariables().find("x") != vector1.getVariables().end())
for (int j = 0; j < width; j++)
vector1.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector1.getVariables().find("y") != vector1.getVariables().end())
for (int j = 0; j < width; j++)
vector1.getVariablePointer("y")[j] = (i == j ? y : -100.0);
if (vector2.getVariables().find("x") != vector2.getVariables().end())
for (int j = 0; j < width; j++)
vector2.getVariablePointer("x")[j] = (i == j ? x : -100.0);
if (vector2.getVariables().find("y") != vector2.getVariables().end())
for (int j = 0; j < width; j++)
vector2.getVariablePointer("y")[j] = (i == j ? y : -100.0);
const float* result1 = vector1.evaluate();
const float* result2 = vector2.evaluate();
assertNumbersEqual(val1, result1[i], 1e-6);
assertNumbersEqual(val2, result2[i], 1e-6);
}
}
}
/**
* Verify that the derivative of an expression is calculated correctly.
*/
void verifyDerivative(const string& expression, const string& expectedDeriv) {
ParsedExpression computed = Parser::parse(expression).differentiate("x").optimize();
ParsedExpression expected = Parser::parse(expectedDeriv);
verifySameValue(computed, expected, 1.0, 2.0);
verifySameValue(computed, expected, 2.0, 3.0);
verifySameValue(computed, expected, -2.0, 3.0);
verifySameValue(computed, expected, 2.0, -3.0);
verifySameValue(computed, expected, 0.0, -3.0);
verifySameValue(computed, expected, 2.0, 0.0);
}
/**
* Test the use of a custom function.
*/
void testCustomFunction(const string& expression, const string& equivalent) {
map<string, CustomFunction*> functions;
ExampleFunction exp;
functions["custom"] = &exp;
ParsedExpression exp1 = Parser::parse(expression, functions);
ParsedExpression exp2 = Parser::parse(equivalent);
verifySameValue(exp1, exp2, 1.0, 2.0);
verifySameValue(exp1, exp2, 2.0, 3.0);
verifySameValue(exp1, exp2, -2.0, 3.0);
verifySameValue(exp1, exp2, 2.0, -3.0);
ParsedExpression deriv1 = exp1.differentiate("x").optimize();
ParsedExpression deriv2 = exp2.differentiate("x").optimize();
verifySameValue(deriv1, deriv2, 1.0, 2.0);
verifySameValue(deriv1, deriv2, 2.0, 3.0);
verifySameValue(deriv1, deriv2, -2.0, 3.0);
verifySameValue(deriv1, deriv2, 2.0, -3.0);
ParsedExpression deriv3 = deriv1.differentiate("y").optimize();
ParsedExpression deriv4 = deriv2.differentiate("y").optimize();
verifySameValue(deriv3, deriv4, 1.0, 2.0);
verifySameValue(deriv3, deriv4, 2.0, 3.0);
verifySameValue(deriv3, deriv4, -2.0, 3.0);
verifySameValue(deriv3, deriv4, 2.0, -3.0);
}
int main() {
try {
verifyEvaluation("5", 5.0);
verifyEvaluation("5*2", 10.0);
verifyEvaluation("2*3+4*5", 26.0);
verifyEvaluation("2^-3", 0.125);
verifyEvaluation("1e+2", 100.0);
verifyEvaluation("-x", 2.0, 3.0, -2.0);
verifyEvaluation("y^-x", 3.0, 2.0, 0.125);
verifyEvaluation("1/-x", 3.0, 2.0, -1.0/3.0);
verifyEvaluation("2.1e-4*x*(y+1)", 3.0, 1.0, 1.26e-3);
verifyEvaluation("sin(2.5)", std::sin(2.5));
verifyEvaluation("cot(x)", 3.0, 1.0, 1.0/std::tan(3.0));
verifyEvaluation("log(x)", 3.0, 1.0, std::log(3.0));
verifyEvaluation("x^2+y^3+x^-1+y^(1/2)", 1.0, 1.0, 4.0);
verifyEvaluation("(2*x)*3", 4.0, 4.0, 24.0);
verifyEvaluation("(x*2)*3", 4.0, 4.0, 24.0);
verifyEvaluation("2*(x*3)", 4.0, 4.0, 24.0);
verifyEvaluation("2*(3*x)", 4.0, 4.0, 24.0);
verifyEvaluation("2*x/3", 1.0, 4.0, 2.0/3.0);
verifyEvaluation("x*2/3", 1.0, 4.0, 2.0/3.0);
verifyEvaluation("5*(-x)*(-y)", 1.0, 4.0, 20.0);
verifyEvaluation("5*(-x)*(y)", 1.0, 4.0, -20.0);
verifyEvaluation("5*(x)*(-y)", 1.0, 4.0, -20.0);
verifyEvaluation("5*(-x)/(-y)", 1.0, 4.0, 1.25);
verifyEvaluation("5*(-x)/(y)", 1.0, 4.0, -1.25);
verifyEvaluation("5*(x)/(-y)", 1.0, 4.0, -1.25);
verifyEvaluation("x+(-y)", 1.0, 4.0, -3.0);
verifyEvaluation("(-x)+y", 1.0, 4.0, 3.0);
verifyEvaluation("x/(1/y)", 1.0, 4.0, 4.0);
verifyEvaluation("x*w; w = 5", 3.0, 1.0, 15.0);
verifyEvaluation("a+b^2;a=x-b;b=3*y", 2.0, 3.0, 74.0);
verifyEvaluation("erf(x)+erfc(x)", 2.0, 3.0, 1.0);
verifyEvaluation("min(3, x)", 2.0, 3.0, 2.0);
verifyEvaluation("min(y, 5)", 2.0, 3.0, 3.0);
verifyEvaluation("max(x, y)", 2.0, 3.0, 3.0);
verifyEvaluation("max(x, -1)", 2.0, 3.0, 2.0);
verifyEvaluation("abs(x-y)", 2.0, 3.0, 1.0);
verifyEvaluation("delta(x)+3*delta(y-1.5)", 2.0, 1.5, 3.0);
verifyEvaluation("step(x-3)+y*step(x)", 2.0, 3.0, 3.0);
verifyEvaluation("floor(x)", -2.1, 3.0, -3.0);
verifyEvaluation("ceil(x)", -2.1, 3.0, -2.0);
verifyEvaluation("select(x, 1.0, y)", 0.3, 2.0, 1.0);
verifyEvaluation("select(x, 1.0, y)", 0.0, 2.0, 2.0);
verifyEvaluation("atan2(x, y)", 3.0, 1.5, std::atan(2.0));
verifyEvaluation("sqrt(x^2)", -2.2, 0.0, 2.2);
verifyEvaluation("sqrt(x)^2", 2.2, 0.0, 2.2);
verifyEvaluation("x^2+x^4", 2.0, 0.0, 20.0);
verifyEvaluation("x^-2+x^-3", 2.0, 0.0, 0.375);
verifyEvaluation("x^1.8", 2.2, 0.0, std::pow(2.2, 1.8));
verifyInvalidExpression("1..2");
verifyInvalidExpression("1*(2+3");
verifyInvalidExpression("5++4");
verifyInvalidExpression("1+2)");
verifyInvalidExpression("cos(2,3)");
verifyDerivative("x", "1");
verifyDerivative("x^2+x", "2*x+1");
verifyDerivative("y^x-x", "log(y)*(y^x)-1");
verifyDerivative("sin(x)", "cos(x)");
verifyDerivative("cos(x)", "-sin(x)");
verifyDerivative("tan(x)", "square(sec(x))");
verifyDerivative("cot(x)", "-square(csc(x))");
verifyDerivative("sec(x)", "sec(x)*tan(x)");
verifyDerivative("csc(x)", "-csc(x)*cot(x)");
verifyDerivative("exp(2*x)", "2*exp(2*x)");
verifyDerivative("log(x)", "1/x");
verifyDerivative("sqrt(x)", "0.5/sqrt(x)");
verifyDerivative("asin(x)", "1/sqrt(1-x^2)");
verifyDerivative("acos(x)", "-1/sqrt(1-x^2)");
verifyDerivative("atan(x)", "1/(1+x^2)");
verifyDerivative("atan2(2*x,y)", "2*y/(4*x^2+y^2)");
verifyDerivative("sinh(x)", "cosh(x)");
verifyDerivative("cosh(x)", "sinh(x)");
verifyDerivative("tanh(x)", "1/(cosh(x)^2)");
verifyDerivative("erf(x)", "1.12837916709551*exp(-x^2)");
verifyDerivative("erfc(x)", "-1.12837916709551*exp(-x^2)");
verifyDerivative("step(x)*x+step(1-x)*2*x", "step(x)+step(1-x)*2");
verifyDerivative("recip(x)", "-1/x^2");
verifyDerivative("square(x)", "2*x");
verifyDerivative("cube(x)", "3*x^2");
verifyDerivative("min(x, 2*x)", "step(x-2*x)*2+(1-step(x-2*x))*1");
verifyDerivative("max(5, x^2)", "(1-step(5-x^2))*2*x");
verifyDerivative("abs(3*x)", "step(3*x)*3+(1-step(3*x))*-3");
verifyDerivative("floor(x)+0.5*x*ceil(x)", "0.5*ceil(x)");
verifyDerivative("select(x, x^2, 3*x)", "select(x, 2*x, 3)");
testCustomFunction("custom(x, y)/2", "x*y");
testCustomFunction("custom(x^2, 1)+custom(2, y-1)", "2*x^2+4*(y-1)");
cout << Parser::parse("x*x").optimize() << endl;
cout << Parser::parse("x*(x*x)").optimize() << endl;
cout << Parser::parse("(x*x)*x").optimize() << endl;
cout << Parser::parse("2*3*x").optimize() << endl;
cout << Parser::parse("1/(1+x)").optimize() << endl;
cout << Parser::parse("x^(1/2)").optimize() << endl;
cout << Parser::parse("log(3*cos(x))^(sqrt(4)-2)").optimize() << endl;
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}