
Martin Modrák

Originally published at martinmodrak.cz

Optional Parameters/Data in Stan

Sometimes you are developing a Stan statistical model that has multiple variants: maybe you want to consider several different link functions somewhere deep in your model, or you want to switch between estimating a quantity and receiving it as data, or something completely different. In these cases, you may want optional parameters and/or data that apply only to some variants of your model. Sadly, Stan does not support this feature directly, but you can implement it yourself with just a bit of additional code. In this post I will show how.

The Base Model

Let’s start with a very simple model: just estimating the mean and standard deviation of a normal distribution:

library(rstan)
library(knitr)
library(tidyverse)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)
set.seed(3145678)

model_fixed_code <- "
data {
  int N;
  vector[N] X;
}

parameters {
  real mu;
  real<lower=0> sigma; 
}

model {
  X ~ normal(mu, sigma);

  //And some priors
  mu ~ normal(0, 10);
  sigma ~ student_t(3, 0, 1);
}

"

model_fixed <- stan_model(model_code = model_fixed_code)

And let’s simulate some data and see that it fits:

mu_true <- 8
sigma_true <- 2
N <- 10
X <- rnorm(N, mean = mu_true, sd = sigma_true)

data_fixed <- list(N = N, X = X)
fit_fixed <- sampling(model_fixed, data = data_fixed, iter = 500)
summary(fit_fixed, probs = c(0.1, 0.9))$summary %>% kable()
| | mean | se_mean | sd | 10% | 90% | n_eff | Rhat |
|:--|--:|--:|--:|--:|--:|--:|--:|
| mu | 7.855031 | 0.0256139 | 0.5632183 | 7.162485 | 8.548415 | 483.5059 | 1.007501 |
| sigma | 1.774158 | 0.0206974 | 0.4400573 | 1.302616 | 2.350727 | 452.0508 | 1.003409 |
| lp__ | -12.103350 | 0.0555738 | 1.1132479 | -13.664610 | -11.091775 | 401.2768 | 1.004955 |

Now With Optional Parameters

Let’s say we now want to handle the case where the standard deviation is known. Obviously we could write a new model. But what if the full model has several hundred lines and the only thing we want to change is to let the user specify the known standard deviation? The simplest solution is to just have all parameters/data that are needed in any of the variants lying around and use if conditions in the model block to ignore some of them, but that is a bit unsatisfactory (and also those unused parameters may in some cases hinder sampling).
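As a rough sketch of that simplest approach (this is not code from the original model, and the name `model_naive_code` is made up for illustration), both `sigma_data` and `sigma_param` are always declared and an `if` picks which one is used:

```r
# Illustrative only: the "keep everything around" variant described above.
# model_naive_code is a hypothetical name, not part of the original post.
model_naive_code <- "
data {
  int N;
  vector[N] X;
  int<lower = 0, upper = 1> sigma_known;

  //Must always be supplied, even when it is ignored
  real<lower=0> sigma_data;
}

parameters {
  real mu;

  //Always sampled, even when it is ignored
  real<lower=0> sigma_param;
}

model {
  if (sigma_known) {
    X ~ normal(mu, sigma_data);
  } else {
    X ~ normal(mu, sigma_param);
  }

  mu ~ normal(0, 10);
  //Kept in both cases: without it, the unused sigma_param would have
  //an improper flat posterior when sigma_known is 1
  sigma_param ~ student_t(3, 0, 1);
}
"
```

In this sketch the sampler still has to explore `sigma_param` when `sigma_known` is 1, and a placeholder value for `sigma_data` has to be passed when `sigma_known` is 0.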

For a better solution, we can take advantage of the fact that Stan allows zero-sized arrays/vectors and features the ternary operator ?:. The ternary operator has the syntax (condition) ? (true value) : (false value) and works like an if - else statement, but within an expression. The last piece of the puzzle is that Stan allows the sizes of data and parameter arrays to depend on arbitrary expressions computed from data. The model that can handle both known and unknown standard deviation follows:

model_optional_code <- "
data {
  int N;
  vector[N] X;

  //Just a verbose way to specify boolean variable
  int<lower = 0, upper = 1> sigma_known; 

  //sigma_data is size 0 if sigma_known is FALSE
  real<lower=0> sigma_data[sigma_known ? 1 : 0]; 
}

parameters {
  real mu;

  //sigma_param is size 0 if sigma_known is TRUE
  real<lower=0> sigma_param[sigma_known ? 0 : 1]; 
}

transformed parameters {
  real<lower=0> sigma;
  if (sigma_known) {
    sigma = sigma_data[1];
  } else {
    sigma = sigma_param[1];
  }
}

model {
  X ~ normal(mu, sigma);

  //And some priors
  mu ~ normal(0, 10);
  if (!sigma_known) {
    sigma_param ~ student_t(3, 0, 1);
  }
}

"

model_optional <- stan_model(model_code = model_optional_code)

We had to add some boilerplate code, but now we don’t have to maintain two separate models. This trick is also sometimes useful if you want to test multiple variants during development: the model compiles only once, so you can test both variants while modifying other parts of your code and spend less time waiting for compilation.

Just to make sure the model works and see how to correctly specify the data, let’s fit it assuming the standard deviation is to be estimated:

data_optional <- list(
  N = N,
  X = X,
  sigma_known = 0,
  sigma_data = numeric(0) #This produces an array of size 0
)

fit_optional <- sampling(model_optional, 
                         data = data_optional, 
                         iter = 500, pars = c("mu","sigma"))
summary(fit_optional, probs = c(0.1, 0.9))$summary %>% kable()
| | mean | se_mean | sd | 10% | 90% | n_eff | Rhat |
|:--|--:|--:|--:|--:|--:|--:|--:|
| mu | 7.854036 | 0.0198265 | 0.5440900 | 7.181837 | 8.531780 | 753.0924 | 0.9981102 |
| sigma | 1.730077 | 0.0152808 | 0.3918781 | 1.308565 | 2.270505 | 657.6701 | 0.9989029 |
| lp__ | -11.992770 | 0.0503044 | 0.9811551 | -13.383729 | -11.089657 | 380.4199 | 1.0016842 |
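The flag `sigma_known` and the size of `sigma_data` always have to agree, so it may be convenient to build the data list in one place. A minimal sketch, assuming a hypothetical helper called `make_optional_data` (not part of the original code) that treats `sigma = NULL` as "estimate the standard deviation":

```r
# Hypothetical helper: builds the data list for model_optional.
# sigma = NULL means the standard deviation should be estimated.
make_optional_data <- function(X, sigma = NULL) {
  sigma_known <- !is.null(sigma)
  list(
    N = length(X),
    X = X,
    sigma_known = as.integer(sigma_known),
    # array(sigma, dim = 1) keeps a dimension attribute, so a single value
    # is passed as an array of size 1; numeric(0) is an array of size 0
    sigma_data = if (sigma_known) array(sigma, dim = 1) else numeric(0)
  )
}

# Inspect both variants on a few made-up observations
str(make_optional_data(c(7.1, 8.4, 9.0)))
str(make_optional_data(c(7.1, 8.4, 9.0), sigma = 2))
```

The returned list has the same element names as the data lists passed to `sampling` in this post.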

And now let’s run the model and give it the correct standard deviation:

data_optional_sigma_known <- list(
  N = N,
  X = X,
  sigma_known = 1,
  sigma_data = array(sigma_true, 1) 
  #The array conversion is necessary, otherwise Stan complains about dimensions
)

fit_optional_sigma_known <- sampling(model_optional, 
                                     data = data_optional_sigma_known, 
                                     iter = 500, pars = c("mu","sigma"))
summary(fit_optional_sigma_known, probs = c(0.1, 0.9))$summary %>% kable()
| | mean | se_mean | sd | 10% | 90% | n_eff | Rhat |
|:--|--:|--:|--:|--:|--:|--:|--:|
| mu | 7.808058 | 0.0292710 | 0.6273565 | 7.017766 | 8.622762 | 459.3600 | 1.006164 |
| sigma | 2.000000 | 0.0000000 | 0.0000000 | 2.000000 | 2.000000 | 1000.0000 | NaN |
| lp__ | -11.072234 | 0.0321233 | 0.6750295 | -11.917321 | -10.585280 | 441.5753 | 1.002187 |
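When sigma is known, this model is conjugate, so the posterior of mu is available in closed form and can serve as a rough sanity check on the fit. The sketch below assumes the normal(0, 10) prior from the model; the function name `posterior_mu_known_sigma` and the example observations are made up for illustration and are not the simulated X from this post:

```r
# Closed-form posterior of mu for X ~ normal(mu, sigma) with sigma known
# and the prior mu ~ normal(prior_mean, prior_sd).
posterior_mu_known_sigma <- function(X, sigma, prior_mean = 0, prior_sd = 10) {
  prior_precision <- 1 / prior_sd^2
  data_precision <- length(X) / sigma^2
  post_precision <- prior_precision + data_precision
  post_mean <- (prior_precision * prior_mean + data_precision * mean(X)) /
    post_precision
  c(mean = post_mean, sd = sqrt(1 / post_precision))
}

# Made-up observations, NOT the X simulated earlier in this post
X_example <- c(7.2, 9.1, 8.4, 6.3, 10.0, 7.7, 8.8, 5.9, 9.4, 8.2)
posterior_mu_known_sigma(X_example, sigma = 2)
```

The posterior standard deviation depends only on N, sigma and the prior. For N = 10 and sigma = 2 it is sqrt(1 / (1/100 + 10/4)), roughly 0.63, which is in line with the sd of about 0.627 reported for mu in the table above.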

Extending

Obviously this method lets you do all sorts of more complicated things, in particular:

  • When the optional parameter is a vector you can have something like

vector[sigma_known ? 0 : n_sigma] sigma;

  • You can have more than two variants to choose from and then use something akin to

real param[variant == 5 ? 0 : 1];

  • If your conditions become more complex, you can always put them into a user-defined function (for optional data) or into the transformed data block (for optional parameters), as in:

functions {
  int compute_whatever_size(int X, int Y, int Z) {
    //do stuff
  }
}

data {
  ...
  real whatever[compute_whatever_size(X,Y,Z)];
  real<lower = 0> whatever_sigma[compute_whatever_size(X,Y,Z)];
}

transformed data {
  int carebear_size;

  //do stuff
  carebear_size = magic_result;
}

parameters {
  vector[carebear_size] carebear;
  matrix[carebear_size,carebear_size] spatial_carebear;
}
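To make the "more than two variants" idea a bit more concrete, here is a small sketch in the same style as the models above. It is an assumption-laden illustration, not code from the original post: the names `model_variant_code`, `variant` and `nu`, the choice of a Student-t likelihood as the second variant, and the gamma(2, 0.1) prior are all made up for the example.

```r
# Illustrative only: an integer code selects the likelihood, and the
# degrees-of-freedom parameter nu exists only for the Student-t variant.
model_variant_code <- "
data {
  int N;
  vector[N] X;

  //1 = normal likelihood, 2 = Student-t likelihood
  int<lower = 1, upper = 2> variant;
}

parameters {
  real mu;
  real<lower=0> sigma;

  //nu is size 0 unless the Student-t variant is selected
  real<lower=1> nu[variant == 2 ? 1 : 0];
}

model {
  if (variant == 1) {
    X ~ normal(mu, sigma);
  } else {
    X ~ student_t(nu[1], mu, sigma);
    nu ~ gamma(2, 0.1);
  }

  mu ~ normal(0, 10);
  sigma ~ student_t(3, 0, 1);
}
"

# Data for the Student-t variant; no placeholder is needed for nu
# because it is a parameter, not data. The observations are made up.
data_variant <- list(N = 3, X = c(7.1, 8.4, 9.0), variant = 2)
```

The string would be compiled with `stan_model(model_code = model_variant_code)` in the same way as the other models in this post. Further variants could be added by widening the `upper` bound on `variant` and extending the `if` chain.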
