Skip to content

Commit

Permalink
[RF] Implement the EvalBackend("codegen_no_grad") option
Browse files Browse the repository at this point in the history
This is done by adding a new constructor flag to the RooFuncWrapper
whether the gradient should be compiled or not.
  • Loading branch information
guitargeek committed Sep 22, 2023
1 parent ece77ec commit dae23db
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 23 deletions.
9 changes: 5 additions & 4 deletions roofit/roofitcore/inc/RooFuncWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@ class RooSimultaneous;
class RooFuncWrapper final : public RooAbsReal {
public:
RooFuncWrapper(const char *name, const char *title, std::string const &funcBody, RooArgSet const &paramSet,
const RooAbsData *data = nullptr, RooSimultaneous const *simPdf = nullptr);
const RooAbsData *data, RooSimultaneous const *simPdf, bool createGradient);

RooFuncWrapper(const char *name, const char *title, RooAbsReal const &obj, RooArgSet const &normSet,
const RooAbsData *data = nullptr, RooSimultaneous const *simPdf = nullptr);
const RooAbsData *data, RooSimultaneous const *simPdf, bool createGradient);

RooFuncWrapper(const RooFuncWrapper &other, const char *name = nullptr);

TObject *clone(const char *newname) const override { return new RooFuncWrapper(*this, newname); }

double defaultErrorLevel() const override { return 0.5; }

bool hasGradient() const override { return true; }
bool hasGradient() const override { return _hasGradient; }
void gradient(double *out) const override;

void gradient(const double *x, double *g) const;
Expand All @@ -65,7 +65,7 @@ class RooFuncWrapper final : public RooAbsReal {
void loadParamsAndData(std::string funcName, RooAbsArg const *head, RooArgSet const &paramSet,
const RooAbsData *data, RooSimultaneous const *simPdf);

void declareAndDiffFunction(std::string funcName, std::string const &funcBody);
void declareAndDiffFunction(std::string funcName, std::string const &funcBody, bool createGradient);

void buildFuncAndGradFunctors();

Expand All @@ -81,6 +81,7 @@ class RooFuncWrapper final : public RooAbsReal {
RooListProxy _params;
Func _func;
Grad _grad;
bool _hasGradient = false;
mutable std::vector<double> _gradientVarBuffer;
std::vector<double> _observables;
std::map<RooFit::Detail::DataKey, ObsInfo> _obsInfos;
Expand Down
12 changes: 6 additions & 6 deletions roofit/roofitcore/inc/RooGlobalFunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,12 @@ RooCmdArg Parallelize(int nWorkers) ;
RooCmdArg ModularL(bool flag=false) ;
RooCmdArg TimingAnalysis(bool timingAnalysis) ;

RooCmdArg BatchMode(std::string const& batchMode="cpu");
// The const char * overload is necessary, otherwise the compiler will cast a
// C-Style string to a bool and choose the BatchMode(bool) overload if one
// calls for example BatchMode("off").
inline RooCmdArg BatchMode(const char * batchMode) { return BatchMode(std::string(batchMode)); }
inline RooCmdArg BatchMode(bool batchModeOn) { return BatchMode(batchModeOn ? "cpu" : "off"); }
//RooCmdArg BatchMode(std::string const& batchMode="cpu");
//// The const char * overload is necessary, otherwise the compiler will cast a
//// C-Style string to a bool and choose the BatchMode(bool) overload if one
//// calls for example BatchMode("off").
//inline RooCmdArg BatchMode(const char * batchMode) { return BatchMode(std::string(batchMode)); }
//inline RooCmdArg BatchMode(bool batchModeOn) { return BatchMode(batchModeOn ? "cpu" : "off"); }

RooCmdArg IntegrateBins(double precision);

Expand Down
6 changes: 4 additions & 2 deletions roofit/roofitcore/src/RooAbsPdf.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1221,11 +1221,13 @@ std::unique_ptr<RooAbsReal> RooAbsPdf::createNLLImpl(RooAbsData& data, const Roo

std::unique_ptr<RooAbsReal> nllWrapper;

if(evalBackend == RooFit::EvalBackend::Value::Codegen) {
if(evalBackend == RooFit::EvalBackend::Value::Codegen || evalBackend == RooFit::EvalBackend::Value::CodegenNoGrad) {
static int iFuncWrapper = 0;
std::string wrapperName = "nll_func_wrapper_" + std::to_string(iFuncWrapper++);
bool createGradient = evalBackend == RooFit::EvalBackend::Value::Codegen;
auto simPdf = dynamic_cast<RooSimultaneous const *>(pdfClone.get());
nllWrapper = std::make_unique<RooFuncWrapper>(wrapperName.c_str(), wrapperName.c_str(), *nll, normSet, &data,
dynamic_cast<RooSimultaneous const *>(pdfClone.get()));
simPdf, createGradient);
} else {
auto evaluator = std::make_unique<RooFit::Evaluator>(*nll, evalBackend == RooFit::EvalBackend::Value::Cuda);
nllWrapper = std::make_unique<RooEvaluatorWrapper>(*nll,
Expand Down
20 changes: 12 additions & 8 deletions roofit/roofitcore/src/RooFuncWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@
#include <TSystem.h>

RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, std::string const &funcBody,
RooArgSet const &paramSet, const RooAbsData *data /*=nullptr*/,
RooSimultaneous const *simPdf)
: RooAbsReal{name, title}, _params{"!params", "List of parameters", this}
RooArgSet const &paramSet, const RooAbsData *data, RooSimultaneous const *simPdf,
bool createGradient)
: RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _hasGradient{createGradient}
{
// Declare the function and create its derivative.
declareAndDiffFunction(name, funcBody);
declareAndDiffFunction(name, funcBody, createGradient);

// Load the parameters and observables.
loadParamsAndData(name, nullptr, paramSet, data, simPdf);
}

RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal const &obj, RooArgSet const &normSet,
const RooAbsData *data /*=nullptr*/, RooSimultaneous const *simPdf)
: RooAbsReal{name, title}, _params{"!params", "List of parameters", this}
const RooAbsData *data, RooSimultaneous const *simPdf, bool createGradient)
: RooAbsReal{name, title}, _params{"!params", "List of parameters", this}, _hasGradient{createGradient}
{
std::string func;

Expand All @@ -60,14 +60,15 @@ RooFuncWrapper::RooFuncWrapper(const char *name, const char *title, RooAbsReal c
func = buildCode(*pdf);

// Declare the function and create its derivative.
declareAndDiffFunction(name, func);
declareAndDiffFunction(name, func, createGradient);
}

RooFuncWrapper::RooFuncWrapper(const RooFuncWrapper &other, const char *name)
: RooAbsReal(other, name),
_params("!params", this, other._params),
_func(other._func),
_grad(other._grad),
_hasGradient(other._hasGradient),
_gradientVarBuffer(other._gradientVarBuffer),
_observables(other._observables)
{
Expand Down Expand Up @@ -119,7 +120,7 @@ void RooFuncWrapper::loadParamsAndData(std::string funcName, RooAbsArg const *he
}
}

void RooFuncWrapper::declareAndDiffFunction(std::string funcName, std::string const &funcBody)
void RooFuncWrapper::declareAndDiffFunction(std::string funcName, std::string const &funcBody, bool createGradient)
{
std::string gradName = funcName + "_grad_0";
std::string requestName = funcName + "_req";
Expand All @@ -139,6 +140,9 @@ void RooFuncWrapper::declareAndDiffFunction(std::string funcName, std::string co
}
_func = reinterpret_cast<Func>(gInterpreter->ProcessLine((funcName + ";").c_str()));

if (!createGradient)
return;

// Calculate gradient
gInterpreter->ProcessLine("#include <Math/CladDerivator.h>");
// disable clang-format for making the following code unreadable.
Expand Down
6 changes: 3 additions & 3 deletions roofit/roofitcore/test/testRooFuncWrapper.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ TEST(RooFuncWrapper, GaussianNormalizedHardcoded)
"const double sig = params[2];"
"double out = std::exp(-0.5 * arg * arg / (sig * sig));"
"return 1. / (std::sqrt(TMath::TwoPi()) * sig) * out;";
RooFuncWrapper gaussFunc("myGauss1", "myGauss1", func, {x, mu, sigma}, {});
RooFuncWrapper gaussFunc("myGauss1", "myGauss1", func, {x, mu, sigma}, nullptr, nullptr, true);

// Check if functions results are the same even after changing parameters.
EXPECT_NEAR(gauss.getVal(normSet), gaussFunc.getVal(), 1e-8);
Expand Down Expand Up @@ -149,7 +149,7 @@ TEST(RooFuncWrapper, GaussianNormalized)

RooArgSet normSet{x};

RooFuncWrapper gaussFunc("myGauss3", "myGauss3", gauss, normSet);
RooFuncWrapper gaussFunc("myGauss3", "myGauss3", gauss, normSet, nullptr, nullptr, true);

RooArgSet paramsGauss;
gauss.getParameters(nullptr, paramsGauss);
Expand Down Expand Up @@ -189,7 +189,7 @@ TEST(RooFuncWrapper, Exponential)

RooArgSet normSet{x};

RooFuncWrapper expoFunc(name.c_str(), name.c_str(), expo, normSet);
RooFuncWrapper expoFunc(name.c_str(), name.c_str(), expo, normSet, nullptr, nullptr, true);

RooArgSet params;
expo.getParameters(nullptr, params);
Expand Down

0 comments on commit dae23db

Please sign in to comment.