Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to build models inside R #172

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,6 @@ jobs:
path: ./test_models/
key: ${{ hashFiles('**/*.stan', 'src/*', 'stan/src/stan/version.hpp', 'Makefile') }}-${{ matrix.os }}-v${{ env.CACHE_VERSION }}

# needed for R tests until they have compilation utilities and can set this themselves.
- name: Set up TBB
if: matrix.os == 'windows-latest'
run: |
Add-Content $env:GITHUB_PATH "$(pwd)/stan/lib/stan_math/lib/tbb"

- name: Run tests
if: matrix.os != 'windows-latest'
run: |
Expand All @@ -231,6 +225,8 @@ jobs:
Rscript -e "devtools::test(reporter = c(\"summary\", \"fail\"))"
Rscript -e "install.packages(getwd(), repos=NULL, type=\"source\")"
Rscript example.R
env:
BRIDGESTAN: ${{ github.workspace }}

- name: Run tests (windows)
if: matrix.os == 'windows-latest'
Expand All @@ -241,6 +237,8 @@ jobs:
Rscript -e 'devtools::test(reporter = c(\"summary\", \"fail\"))'
Rscript -e 'install.packages(getwd(), repos=NULL, type=\"source\")'
Rscript example.R
env:
BRIDGESTAN: ${{ github.workspace }}

rust:
needs: [build]
Expand Down
2 changes: 2 additions & 0 deletions R/NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Generated by roxygen2: do not edit by hand

export(StanModel)
export(compile_model)
export(set_bridgestan_path)
16 changes: 12 additions & 4 deletions R/R/bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@ StanModel <- R6::R6Class("StanModel",
public = list(
#' @description
#' Create a Stan Model instance.
#' @param lib A path to a compiled BridgeStan Shared Object file.
#' @param lib A path to a compiled BridgeStan Shared Object file or a .stan file (will be compiled).
#' @param data Either a JSON string literal, a path to a data file in JSON format ending in ".json", or the empty string.
#' @param seed Seed for the RNG used in constructing the model.
#' @param stanc_args A list of arguments to pass to stanc3 if the model is not already compiled.
#' @param make_args A list of additional arguments to pass to Make if the model is not already compiled.
#' @return A new StanModel.
initialize = function(lib, data, seed) {
initialize = function(lib, data, seed, stanc_args = NULL, make_args = NULL) {
if (tools::file_ext(lib) == "stan") {
lib <- compile_model(lib, stanc_args, make_args)
}

if (.Platform$OS.type == "windows"){
windows_path_setup()
lib_old <- lib
lib <- paste0(tools::file_path_sans_ext(lib), ".dll")
file.copy(from=lib_old, to=lib)
Expand Down Expand Up @@ -75,7 +82,8 @@ StanModel <- R6::R6Class("StanModel",
PACKAGE = private$lib_name
)$info_out
},

#' @description
#' Get the version of BridgeStan used in the compiled model.
model_version= function() {
.C("bs_version_R",
major = as.integer(0),
Expand Down Expand Up @@ -345,7 +353,7 @@ handle_error <- function(lib_name, err_msg, err_ptr, function_name) {
#' StanRNG
#'
#' RNG object for use with `StanModel$param_constrain()`
#' @field rng The pointer to the RNG object.
#' @field ptr The pointer to the RNG object.
#' @keywords internal
StanRNG <- R6::R6Class("StanRNG",
public = list(
Expand Down
110 changes: 110 additions & 0 deletions R/R/compile.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
IS_WINDOWS <- isTRUE(.Platform$OS.type == "windows")
MAKE <- Sys.getenv("MAKE", ifelse(IS_WINDOWS, "mingw32-make", "make"))


#' Get the path to BridgeStan.
#'
#' By default this is set to the value of the environment
#' variable `BRIDGESTAN`.
#'
#' If there is no path set, this function will download
#' a matching version of BridgeStan to a folder called
#' `.bridgestan` in the user's home directory.
#'
#' See also `set_bridgestan_path`
verify_bridgestan_path <- function(path) {
suppressWarnings({
folder <- normalizePath(path)
})
if (!dir.exists(folder)) {
stop(paste0("BridgeStan folder '", folder, "' does not exist!\n", "If you need to set a different location, call 'set_bridgestan_path()'"))
}
makefile <- file.path(folder, "Makefile")
if (!file.exists(makefile)) {
stop(paste0("BridgeStan folder '", folder, "' does not contain file 'Makefile',",
" please ensure it is built properly!\n", "If you need to set a different location, call 'set_bridgestan_path()'"))
}
}

#' Set the path to BridgeStan.
#'
#' This should point to the top-level folder of the repository.
#' @export
set_bridgestan_path <- function(path) {
verify_bridgestan_path(path)
Sys.setenv(BRIDGESTAN = normalizePath(path))
}

get_bridgestan_path <- function() {
# try to get from environment
path <- Sys.getenv("BRIDGESTAN", unset = "")
if (path == "") {
path <- CURRENT_BRIDGESTAN
tryCatch({
verify_bridgestan_path(path)
}, error = function(e) {
print(paste0("Bridgestan not found at location specified by $BRIDGESTAN ",
"environment variable, downloading version ", packageVersion("bridgestan"),
" to ", path))
get_bridgestan_src()
})
}

return(path)
}


#' Run BridgeStan's Makefile on a `.stan` file, creating the `.so`
#' used by the StanModel class.
#' This function assumes that the path to BridgeStan is valid.
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
#' This can be set with `set_bridgestan_path`.
#'
#' @param stan_file A path to a Stan model file.
#' @param stanc_arg A list of arguments to pass to stanc3.
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
#' For example, `c('--O1')` will enable compiler optimization level 1.
#' @param make_args A list of additional arguments to pass to Make.
#' For example, `c('STAN_THREADS=True')` will enable
#' threading for the compiled model. If the same flags are defined
#' in `make/local`, the versions passed here will take precedent.
#' @return Path to the compiled model.
#' @export
compile_model <- function(stan_file, stanc_args = NULL, make_args = NULL) {
verify_bridgestan_path(get_bridgestan_path())
suppressWarnings({
file_path <- normalizePath(stan_file)
})
if (tools::file_ext(file_path) != "stan") {
stop(paste0("File '", file_path, "' does not end with '.stan'"))
}
if (!file.exists(file_path)) {
stop(paste0("File '", file_path, "' does not exist!"))
}

output <- paste0(tools::file_path_sans_ext(file_path), "_model.so")
stancflags <- paste("--include-paths=.", paste(stanc_args, collapse = " "))

flags <- c(paste("-C", get_bridgestan_path()), make_args, paste0("STANCFLAGS=\"",
stancflags, "\""), output)

suppressWarnings({
res <- system2(MAKE, args = flags, stdout = TRUE, stderr = TRUE)
})
res_attrs <- attributes(res)
if ("status" %in% names(res_attrs) && res_attrs$status != 0) {
stop(paste0("Compilation failed with error code ", res_attrs$status, "\noutput:\n",
paste(res, collapse = "\n")))
}

return(output)
}

windows_path_setup <- function() {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed it. Where is this function used?

Copy link
Collaborator Author

@WardBrian WardBrian Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's called in the constructor before we load the shared object (same as Python, looks like Julia calls it at startup which is not ideal and I will take a look at with #187)

if (.Platform$OS.type == "windows") {
suppressWarnings(out <- system2("where.exe", "tbb.dll", stdout = NULL, stderr = NULL))
if (out != 0) {
tbb_path <- file.path(get_bridgestan_path(), "stan", "lib", "stan_math",
"lib", "tbb")
Sys.setenv(PATH = paste(tbb_path, Sys.getenv("PATH"), sep = ";"))
}
}
}
35 changes: 35 additions & 0 deletions R/R/download.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
current_version <- packageVersion("bridgestan")
HOME_BRIDGESTAN <- path.expand(file.path("~", ".bridgestan"))
CURRENT_BRIDGESTAN <- file.path(HOME_BRIDGESTAN, paste0("bridgestan-", current_version))

RETRIES <- 5

get_bridgestan_src <- function() {
url <- paste0("https://github.com/roualdes/bridgestan/releases/download/", "v",
current_version, "/bridgestan-", current_version, ".tar.gz")

dir.create(HOME_BRIDGESTAN, showWarnings = FALSE, recursive = TRUE)
temp <- tempfile()
err_text <- paste("Failed to download Bridgestan", current_version, "from github.com.")
for (i in 1:RETRIES) {
tryCatch({
download.file(url, destfile = temp, mode = "wb", quiet = TRUE, method = "auto")
}, error = function(e) {
cat(err_text, "\n")
if (i == RETRIES) {
stop(err_text, call. = FALSE)
} else {
cat("Retrying (", i + 1, "/", RETRIES, ")...\n", sep = "")
Sys.sleep(1)
}
})
}

tryCatch({
untar(temp, exdir = HOME_BRIDGESTAN)
}, error = function(e) {
stop(paste("Failed to unpack", url, "during installation"), call. = FALSE)
})

unlink(temp)
}
9 changes: 7 additions & 2 deletions R/man/StanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions R/man/StanRNG.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions R/man/compile_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions R/man/set_bridgestan_path.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions R/man/verify_bridgestan_path.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 0 additions & 8 deletions R/tests/testthat.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
# This file is part of the standard setup for testthat.
# It is recommended that you do not modify it.
#
# Where should you do additional test configuration?
# Learn more about the roles of various files in:
# * https://r-pkgs.org/tests.html
# * https://testthat.r-lib.org/reference/test_package.html#special-files

library(testthat)
library(bridgestan)

Expand Down
15 changes: 15 additions & 0 deletions R/tests/testthat/setup.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
base = "../../.."

load_model <- function(name, include_data = TRUE) {
if (include_data) {
data = file.path(base, "test_models", name, paste0(name, ".data.json"))
} else {
data = ""
}
model <- StanModel$new(file.path(base, "test_models", name, paste0(name, "_model.so")),
data, 1234)
return(model)
}

simple <- load_model("simple")
bernoulli <- load_model("bernoulli")
13 changes: 13 additions & 0 deletions R/tests/testthat/test_collisions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

test_that("loading another library didn't break prior ones", {
if (.Platform$OS.type == "windows") {
dll = "./test_collisions.dll"
} else {
dll = "./test_collisions.so"
}
if (file.exists(dll)) {
dyn.load(dll)
expect_equal(bernoulli$name(), "bernoulli_model")
expect_equal(simple$name(), "simple_model")
}
})
Loading