Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input size checks #241

Merged
merged 3 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions R/R/bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ StanModel <- R6::R6Class("StanModel",
}
private$model <- ret$ptr_out

# pre-compute to avoid repeated work in bounds checks
private$unc_dims <- self$param_unc_num()

model_version <- self$model_version()
if (packageVersion("bridgestan") != paste(model_version$major, model_version$minor, model_version$patch, sep = ".")) {
warning(paste0("The version of the compiled model does not match the version of the R library. ",
Expand Down Expand Up @@ -167,6 +170,9 @@ StanModel <- R6::R6Class("StanModel",
} else {
rng_ptr <- as.raw(rng$ptr)
}
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
vars <- .C("bs_param_constrain_R", as.raw(private$model),
as.logical(include_tp), as.logical(include_gq), as.double(theta_unc),
theta = double(self$param_num(include_tp = include_tp, include_gq = include_gq)),
Expand Down Expand Up @@ -202,7 +208,7 @@ StanModel <- R6::R6Class("StanModel",
param_unconstrain = function(theta) {
vars <- .C("bs_param_unconstrain_R", as.raw(private$model),
as.double(theta),
theta_unc = double(self$param_unc_num()),
theta_unc = double(private$unc_dims),
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
Expand All @@ -223,7 +229,7 @@ StanModel <- R6::R6Class("StanModel",
param_unconstrain_json = function(json) {
vars <- .C("bs_param_unconstrain_json_R", as.raw(private$model),
as.character(json),
theta_unc = double(self$param_unc_num()),
theta_unc = double(private$unc_dims),
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
Expand All @@ -241,6 +247,9 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return The log density.
log_density = function(theta_unc, propto = TRUE, jacobian = TRUE) {
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
vars <- .C("bs_log_density_R", as.raw(private$model),
as.logical(propto), as.logical(jacobian), as.double(theta_unc),
val = double(1),
Expand All @@ -262,7 +271,10 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return List containing entries `val` (the log density) and `gradient` (the gradient).
log_density_gradient = function(theta_unc, propto = TRUE, jacobian = TRUE) {
dims <- self$param_unc_num()
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
dims <- private$unc_dims
vars <- .C("bs_log_density_gradient_R", as.raw(private$model),
as.logical(propto), as.logical(jacobian), as.double(theta_unc),
val = double(1), gradient = double(dims),
Expand All @@ -284,7 +296,10 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return List containing entries `val` (the log density), `gradient` (the gradient), and `hessian` (the Hessian).
log_density_hessian = function(theta_unc, propto = TRUE, jacobian = TRUE) {
dims <- self$param_unc_num()
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
dims <- private$unc_dims
vars <- .C("bs_log_density_hessian_R", as.raw(private$model),
as.logical(propto), as.logical(jacobian), as.double(theta_unc),
val = double(1), gradient = double(dims), hess = double(dims * dims),
Expand All @@ -308,7 +323,12 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return List containing entries `val` (the log density) and `Hvp` (the hessian-vector product).
log_density_hessian_vector_product = function(theta_unc, v, propto = TRUE, jacobian = TRUE){
dims <- self$param_unc_num()
dims <- private$unc_dims
if (length(theta_unc) != dims) {
stop("Incorrect number of unconstrained parameters.")
} else if (length(v) != dims) {
stop("Incorrect number of vector elements.")
}
vars <- .C("bs_log_density_hessian_vector_product_R",
as.raw(private$model), as.logical(propto), as.logical(jacobian),
as.double(theta_unc),
Expand All @@ -331,6 +351,7 @@ StanModel <- R6::R6Class("StanModel",
lib_name = NA,
model = NA,
seed = NA,
unc_dims = NA,
finalize = function() {
.C("bs_model_destruct_R",
as.raw(private$model),
Expand Down
2 changes: 2 additions & 0 deletions R/tests/testthat/test_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ test_that("param_constrain handles rng arguments", {

# require at least one present
expect_error(full$param_constrain(c(1.2), include_gq = TRUE), "rng must be provided")

expect_error(full$param_constrain(c(1.2, 1.2)), "Incorrect number of unconstrained parameters")
})


Expand Down
57 changes: 40 additions & 17 deletions julia/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ mutable struct StanModel
stanmodel::Ptr{StanModelStruct}
@const data::String
@const seed::UInt32
@const param_unc_num::Int

function StanModel(
lib::String,
Expand Down Expand Up @@ -85,7 +86,11 @@ mutable struct StanModel
error(handle_error(lib, err, "bs_model_construct"))
end

sm = new(lib, stanmodel, data, seed)
# compute now to avoid re-computing in bounds checks later
param_unc_num =
@ccall $(dlsym(lib, :bs_param_unc_num))(stanmodel::Ptr{StanModelStruct})::Cint

sm = new(lib, stanmodel, data, seed, param_unc_num)

function f(sm)
@ccall $(dlsym(sm.lib, :bs_model_destruct))(
Expand Down Expand Up @@ -279,7 +284,13 @@ function param_constrain!(
rng::Union{StanRNG,Nothing} = nothing,
)
dims = param_num(sm; include_tp = include_tp, include_gq = include_gq)
if length(out) != dims
if length(theta_unc) != sm.param_unc_num
throw(
DimensionMismatch(
"theta_unc must be same size as number of unconstrained parameters",
),
)
elseif length(out) != dims
throw(
DimensionMismatch("out must be same size as number of constrained parameters"),
)
Expand Down Expand Up @@ -359,8 +370,7 @@ The result is stored in the vector `out`, and a reference is returned. See
This is the inverse of [`param_constrain!`](@ref).
"""
function param_unconstrain!(sm::StanModel, theta::Vector{Float64}, out::Vector{Float64})
dims = param_unc_num(sm)
if length(out) != dims
if length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -396,7 +406,7 @@ re-using existing memory.
This is the inverse of [`param_constrain`](@ref).
"""
function param_unconstrain(sm::StanModel, theta::Vector{Float64})
out = zeros(param_unc_num(sm))
out = zeros(sm.param_unc_num)
param_unconstrain!(sm, theta, out)
end

Expand All @@ -411,8 +421,7 @@ The result is stored in the vector `out`, and a reference is returned. See
[`param_unconstrain_json`](@ref) for a version which allocates fresh memory.
"""
function param_unconstrain_json!(sm::StanModel, theta::String, out::Vector{Float64})
dims = param_unc_num(sm)
if length(out) != dims
if length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -445,7 +454,7 @@ See [`param_unconstrain_json!`](@ref) for a version which allows
re-using existing memory.
"""
function param_unconstrain_json(sm::StanModel, theta::String)
out = zeros(param_unc_num(sm))
out = zeros(sm.param_unc_num)
param_unconstrain_json!(sm, theta, out)
end

Expand Down Expand Up @@ -498,8 +507,11 @@ function log_density_gradient!(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
if length(out) != dims
if length(q) != sm.param_unc_num
throw(
DimensionMismatch("q must be same size as number of unconstrained parameters"),
)
elseif length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -541,7 +553,7 @@ function log_density_gradient(
propto::Bool = true,
jacobian::Bool = true,
)
grad = zeros(param_unc_num(sm))
grad = zeros(sm.param_unc_num)
log_density_gradient!(sm, q, grad; propto = propto, jacobian = jacobian)
end

Expand All @@ -565,8 +577,12 @@ function log_density_hessian!(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
if length(out_grad) != dims
dims = sm.param_unc_num
if length(q) != dims
throw(
DimensionMismatch("q must be same size as number of unconstrained parameters"),
)
elseif length(out_grad) != dims
throw(
DimensionMismatch(
"out_grad must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -615,7 +631,7 @@ function log_density_hessian(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
dims = sm.param_unc_num
grad = zeros(dims)
hess = zeros(dims * dims)
log_density_hessian!(sm, q, grad, hess; propto = propto, jacobian = jacobian)
Expand All @@ -641,8 +657,15 @@ function log_density_hessian_vector_product!(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
if length(out) != dims
if length(q) != sm.param_unc_num
throw(
DimensionMismatch("q must be same size as number of unconstrained parameters"),
)
elseif length(v) != sm.param_unc_num
throw(
DimensionMismatch("v must be same size as number of unconstrained parameters"),
)
elseif length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -687,7 +710,7 @@ function log_density_hessian_vector_product(
propto::Bool = true,
jacobian::Bool = true,
)
out = zeros(param_unc_num(sm))
out = zeros(sm.param_unc_num)
log_density_hessian_vector_product!(sm, q, v, out; propto = propto, jacobian = jacobian)
end

Expand Down
14 changes: 14 additions & 0 deletions julia/test/model_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ end


model2 = load_test_model("full", false)
a = randn(BridgeStan.param_unc_num(model2))
rng = StanRNG(model2, 1234)
@test 1 == length(BridgeStan.param_constrain(model2, a))
@test 2 == length(BridgeStan.param_constrain(model2, a; include_tp = true))
Expand Down Expand Up @@ -392,6 +393,12 @@ end
jacobian = true,
)

y_unc_bad = zeros(length(y_unc) + 1)
@test_throws DimensionMismatch BridgeStan.log_density_gradient(model, y_unc_bad)

y_unc_bad = zeros(length(y_unc) - 1)
@test_throws DimensionMismatch BridgeStan.log_density_gradient(model, y_unc_bad)

end

@testset "log_density_hessian" begin
Expand Down Expand Up @@ -473,6 +480,13 @@ end
jacobian = true,
)


y_unc_bad = zeros(length(y_unc) + 1)
@test_throws DimensionMismatch BridgeStan.log_density_hessian(model, y_unc_bad)

y_unc_bad = zeros(length(y_unc) - 1)
@test_throws DimensionMismatch BridgeStan.log_density_hessian(model, y_unc_bad)

end

end
Expand Down
18 changes: 12 additions & 6 deletions python/bridgestan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def __init__(

num_params = self._param_unc_num(self.model)

param_sized_array = array_ptr(
dtype=ctypes.c_double,
flags=("C_CONTIGUOUS",),
shape=(num_params,),
)

param_sized_out_array = array_ptr(
dtype=ctypes.c_double,
flags=("C_CONTIGUOUS", "WRITEABLE"),
Expand Down Expand Up @@ -227,7 +233,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
writeable_double_array,
ctypes.c_void_p,
star_star_char,
Expand Down Expand Up @@ -257,7 +263,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
star_star_char,
]
Expand All @@ -268,7 +274,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
param_sized_out_array,
star_star_char,
Expand All @@ -280,7 +286,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
param_sized_out_array,
param_sqrd_sized_out_array,
Expand All @@ -293,8 +299,8 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
double_array,
param_sized_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
param_sized_out_array,
star_star_char,
Expand Down
Loading
Loading