-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add functionality to build models inside R #172
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(StanModel) | ||
export(compile_model) | ||
export(set_bridgestan_path) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
IS_WINDOWS <- isTRUE(.Platform$OS.type == "windows") | ||
MAKE <- Sys.getenv("MAKE", ifelse(IS_WINDOWS, "mingw32-make", "make")) | ||
|
||
|
||
#' Get the path to BridgeStan. | ||
#' | ||
#' By default this is set to the value of the environment | ||
#' variable `BRIDGESTAN`. | ||
#' | ||
#' If there is no path set, this function will download | ||
#' a matching version of BridgeStan to a folder called | ||
#' `.bridgestan` in the user's home directory. | ||
#' | ||
#' See also `set_bridgestan_path` | ||
verify_bridgestan_path <- function(path) { | ||
suppressWarnings({ | ||
folder <- normalizePath(path) | ||
}) | ||
if (!dir.exists(folder)) { | ||
stop(paste0("BridgeStan folder '", folder, "' does not exist!\n", "If you need to set a different location, call 'set_bridgestan_path()'")) | ||
} | ||
makefile <- file.path(folder, "Makefile") | ||
if (!file.exists(makefile)) { | ||
stop(paste0("BridgeStan folder '", folder, "' does not contain file 'Makefile',", | ||
" please ensure it is built properly!\n", "If you need to set a different location, call 'set_bridgestan_path()'")) | ||
} | ||
} | ||
|
||
#' Set the path to BridgeStan. | ||
#' | ||
#' This should point to the top-level folder of the repository. | ||
#' @export | ||
set_bridgestan_path <- function(path) { | ||
verify_bridgestan_path(path) | ||
Sys.setenv(BRIDGESTAN = normalizePath(path)) | ||
} | ||
|
||
get_bridgestan_path <- function() { | ||
# try to get from environment | ||
path <- Sys.getenv("BRIDGESTAN", unset = "") | ||
if (path == "") { | ||
path <- CURRENT_BRIDGESTAN | ||
tryCatch({ | ||
verify_bridgestan_path(path) | ||
}, error = function(e) { | ||
print(paste0("Bridgestan not found at location specified by $BRIDGESTAN ", | ||
"environment variable, downloading version ", packageVersion("bridgestan"), | ||
" to ", path)) | ||
get_bridgestan_src() | ||
}) | ||
} | ||
|
||
return(path) | ||
} | ||
|
||
|
||
#' Run BridgeStan's Makefile on a `.stan` file, creating the `.so` | ||
#' used by the StanModel class. | ||
#' This function checks that the path to BridgeStan is valid | ||
#' and will error if not. This can be set with `set_bridgestan_path`. | ||
#' | ||
#' @param stan_file A path to a Stan model file. | ||
#' @param stanc_arg A vector of arguments to pass to stanc3. | ||
#' For example, `c('--O1')` will enable compiler optimization level 1. | ||
#' @param make_args A vector of additional arguments to pass to Make. | ||
#' For example, `c('STAN_THREADS=True')` will enable | ||
#' threading for the compiled model. If the same flags are defined | ||
#' in `make/local`, the versions passed here will take precedent. | ||
#' @return Path to the compiled model. | ||
#' @export | ||
compile_model <- function(stan_file, stanc_args = NULL, make_args = NULL) { | ||
verify_bridgestan_path(get_bridgestan_path()) | ||
suppressWarnings({ | ||
file_path <- normalizePath(stan_file) | ||
}) | ||
if (tools::file_ext(file_path) != "stan") { | ||
stop(paste0("File '", file_path, "' does not end with '.stan'")) | ||
} | ||
if (!file.exists(file_path)) { | ||
stop(paste0("File '", file_path, "' does not exist!")) | ||
} | ||
|
||
output <- paste0(tools::file_path_sans_ext(file_path), "_model.so") | ||
stancflags <- paste("--include-paths=.", paste(stanc_args, collapse = " ")) | ||
|
||
flags <- c(paste("-C", get_bridgestan_path()), make_args, paste0("STANCFLAGS=\"", | ||
stancflags, "\""), output) | ||
|
||
suppressWarnings({ | ||
res <- system2(MAKE, args = flags, stdout = TRUE, stderr = TRUE) | ||
}) | ||
res_attrs <- attributes(res) | ||
if ("status" %in% names(res_attrs) && res_attrs$status != 0) { | ||
stop(paste0("Compilation failed with error code ", res_attrs$status, "\noutput:\n", | ||
paste(res, collapse = "\n"))) | ||
} | ||
|
||
return(output) | ||
} | ||
|
||
windows_path_setup <- function() { | ||
if (.Platform$OS.type == "windows") { | ||
suppressWarnings(out <- system2("where.exe", "tbb.dll", stdout = NULL, stderr = NULL)) | ||
if (out != 0) { | ||
tbb_path <- file.path(get_bridgestan_path(), "stan", "lib", "stan_math", | ||
"lib", "tbb") | ||
Sys.setenv(PATH = paste(tbb_path, Sys.getenv("PATH"), sep = ";")) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
current_version <- packageVersion("bridgestan") | ||
HOME_BRIDGESTAN <- path.expand(file.path("~", ".bridgestan")) | ||
CURRENT_BRIDGESTAN <- file.path(HOME_BRIDGESTAN, paste0("bridgestan-", current_version)) | ||
|
||
RETRIES <- 5 | ||
|
||
get_bridgestan_src <- function() { | ||
url <- paste0("https://github.com/roualdes/bridgestan/releases/download/", "v", | ||
current_version, "/bridgestan-", current_version, ".tar.gz") | ||
|
||
dir.create(HOME_BRIDGESTAN, showWarnings = FALSE, recursive = TRUE) | ||
temp <- tempfile() | ||
err_text <- paste("Failed to download Bridgestan", current_version, "from github.com.") | ||
for (i in 1:RETRIES) { | ||
tryCatch({ | ||
download.file(url, destfile = temp, mode = "wb", quiet = TRUE, method = "auto") | ||
}, error = function(e) { | ||
cat(err_text, "\n") | ||
if (i == RETRIES) { | ||
stop(err_text, call. = FALSE) | ||
} else { | ||
cat("Retrying (", i + 1, "/", RETRIES, ")...\n", sep = "") | ||
Sys.sleep(1) | ||
} | ||
}) | ||
} | ||
|
||
tryCatch({ | ||
untar(temp, exdir = HOME_BRIDGESTAN) | ||
}, error = function(e) { | ||
stop(paste("Failed to unpack", url, "during installation"), call. = FALSE) | ||
}) | ||
|
||
unlink(temp) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
base = "../../.." | ||
|
||
load_model <- function(name, include_data = TRUE) { | ||
if (include_data) { | ||
data = file.path(base, "test_models", name, paste0(name, ".data.json")) | ||
} else { | ||
data = "" | ||
} | ||
model <- StanModel$new(file.path(base, "test_models", name, paste0(name, "_model.so")), | ||
data, 1234) | ||
return(model) | ||
} | ||
|
||
simple <- load_model("simple") | ||
bernoulli <- load_model("bernoulli") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
|
||
test_that("loading another library didn't break prior ones", { | ||
if (.Platform$OS.type == "windows") { | ||
dll = "./test_collisions.dll" | ||
} else { | ||
dll = "./test_collisions.so" | ||
} | ||
if (file.exists(dll)) { | ||
dyn.load(dll) | ||
expect_equal(bernoulli$name(), "bernoulli_model") | ||
expect_equal(simple$name(), "simple_model") | ||
} | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed it. Where is this function used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's called in the constructor before we load the shared object (same as Python, looks like Julia calls it at startup which is not ideal and I will take a look at with #187)