I have recently become a huge fan of Bayesian statistics. It makes so much sense, and if you are doing any kind of inference from data, you should check it out, especially the Stan language. Without much further intro, this is my first blog post on statistical topics.
Not long ago, I came across a nice blog post by Kathryn Morrison called “A gentle INLA tutorial”. The post was nice and helped me better appreciate INLA. But as a fan of the Stan probabilistic programming language, I felt that comparing INLA to JAGS is not really that relevant, as Stan should, at least in theory, be way faster and better than JAGS. Here, I run a comparison of INLA to Stan on the second example from that tutorial, called “Poisson GLM with an iid random effect”.
The TL;DR is: for this model, Stan scales considerably better than JAGS, but still cannot scale to very large models. Also, for this model Stan and INLA give almost the same results. It seems that Stan becomes useful only when your model cannot be coded in INLA.
Please let me know (via an issue on GitHub) should you find any error or anything else that should be included in this post. Also, if you run the experiment on a different machine and/or with a different seed, let me know the results.
Here are the original numbers from Kathryn’s blog (run times in seconds):
N | kathryn_rjags | kathryn_rinla |
---|---|---|
100 | 30.394 | 0.383 |
500 | 142.532 | 1.243 |
5000 | 1714.468 | 5.768 |
25000 | 8610.32 | 30.077 |
100000 | got bored after 6 hours | 166.819 |
The full source of this post is available at this blog’s GitHub repo. Keep in mind that installing RStan is unfortunately not as straightforward as running install.packages. Please consult https://github.com/stan-dev/rstan/wiki/RStan-Getting-Started if you don’t have RStan installed already.
The model
The model we are interested in is a simple GLM with partial pooling of a random effect:
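y_i ~ poisson(mu_i)
log(mu_i) = alpha + beta * x_i + nu_i
nu_i ~ normal(0, sd = 1 / sqrt(tau_nu))
Here tau_nu is the precision of the random effect, following the original JAGS code.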
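For reference, here is the same model written out in full, including the priors that the measurement code below passes to Stan as data: normal priors with precision 0.001 (i.e. variance 1000) on alpha and beta, and a Gamma(0.01, 0.01) prior in the shape/rate parametrization on the precision tau_nu. In this block the second argument of each normal distribution is its variance.

```latex
\begin{aligned}
y_i &\sim \mathrm{Poisson}(\mu_i) \\
\log \mu_i &= \alpha + \beta x_i + \nu_i \\
\nu_i &\sim \mathcal{N}\!\left(0,\ \tau_\nu^{-1}\right) \\
\alpha &\sim \mathcal{N}(0,\ 1000), \quad \beta \sim \mathcal{N}(0,\ 1000) \\
\tau_\nu &\sim \mathrm{Gamma}(0.01,\ 0.01)
\end{aligned}
```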
The comparison
Let’s set up our libraries.
library(rstan)
library(brms)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
library(INLA)
library(tidyverse)
library(knitr) #provides kable(), used to print the tables below
set.seed(6619414)
The results are stored in files within the repository, so that I can easily rebuild the site with blogdown. Delete the cache directory to force a complete rerun.
cache_dir = "_stan_vs_inla_cache/"
if(!dir.exists(cache_dir)){
dir.create(cache_dir)
}
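The caching pattern used throughout this post (read the CSV if it exists, otherwise compute the results and write them) could be factored into a small helper. This is only a sketch: cached_csv is a hypothetical name that does not appear in the original code, and unlike the original code it does not check that the cached file contains all the expected N values.

```r
# Hypothetical helper, not part of the original post: return the cached
# data frame if the CSV file exists, otherwise compute it, save it and return it.
cached_csv = function(file, compute) {
  if (file.exists(file)) {
    return(read.csv(file))
  }
  result = compute()
  write.csv(result, file, row.names = FALSE)
  result
}

# Usage: the first call computes the data frame and writes the file ...
cache_file = file.path(tempdir(), "example_cache.csv")
if (file.exists(cache_file)) file.remove(cache_file)
first = cached_csv(cache_file, function() data.frame(N = c(100, 500), time = c(1.5, 2.5)))
# ... the second call only reads the file, so compute() is never evaluated
second = cached_csv(cache_file, function() stop("should not be recomputed"))
```

Writing with row.names = FALSE matters: otherwise read.csv returns an extra column named X holding the old row names.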
Let’s start by simulating the data:
#The sizes of datasets to work with
N_values = c(100, 500, 5000, 25000)
data = list()
for(N in N_values) {
x = rnorm(N, mean=5,sd=1)
nu = rnorm(N,0,0.1)
mu = exp(1 + 0.5*x + nu)
y = rpois(N,mu)
data[[N]] = list(
N = N,
x = x,
y = y
)
}
Measuring Stan
Here is the model code in Stan (it is good practice to put it into a separate file, but I wanted to keep this post self-contained). It is an almost 1:1 rewrite of the original JAGS code, with a few important changes:
- JAGS parametrizes the normal distribution via precision, Stan via sd. The model converts the precision to sd.
- I added the ability to explicitly set the parameters of the prior distributions as data, which is useful later in this post.
- With multilevel models, Stan works waaaaaay better with the so-called non-centered parametrization. This means that instead of having
nu ~ N(0, nu_sigma), log(mu) = alpha + beta * x + nu
we have
nu_normalized ~ N(0, 1), log(mu) = alpha + beta * x + nu_normalized * nu_sigma.
This gives exactly the same inferences, but results in a geometry that Stan can explore efficiently.
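Both changes can be checked numerically in plain R. The sketch below (an illustration only, not part of the measurements) converts the precision 0.001 used for the alpha and beta priors to an sd, and shows by simulation that scaling standard normal draws by nu_sigma gives the same distribution as drawing from N(0, nu_sigma) directly. The value nu_sigma = 0.1 matches the sd used to simulate the data above.

```r
# Precision -> sd conversion, as in the transformed data block of the Stan model
precision_to_sd = function(precision) sqrt(1 / precision)
alpha_beta_prior_sigma = precision_to_sd(0.001)

# Non-centered parametrization: scaling standard normal draws by nu_sigma
# yields the same distribution as drawing nu ~ N(0, nu_sigma) directly
set.seed(42)
n_draws = 1e5
nu_sigma = 0.1
nu_centered = rnorm(n_draws, mean = 0, sd = nu_sigma)
nu_normalized = rnorm(n_draws, mean = 0, sd = 1)
nu_noncentered = nu_normalized * nu_sigma

# Up to Monte Carlo noise, both samples should have the same quantiles
probs = c(0.05, 0.25, 0.5, 0.75, 0.95)
comparison = rbind(centered = quantile(nu_centered, probs),
                   noncentered = quantile(nu_noncentered, probs))
print(round(comparison, 3))
```

The distributions are identical, so the posterior is the same. What changes is the set of parameters the sampler moves in: nu_normalized is a priori independent of nu_sigma, which avoids the funnel-shaped geometry of the centered version.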
There are also packages that let you specify common models (including this one) without writing Stan code, using a syntax similar to R-INLA: check out rstanarm and brms. The latter is more flexible, while the former is easier to get running, as its models are precompiled (no C++ compiler is needed to fit them) and it can be installed simply with install.packages.
Note also that Stan developers advise against a Gamma(0.01, 0.01) prior on the precision, in favor of a (half-)normal or (half-)Cauchy prior on the sd; see https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations.
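One way to get a feeling for the Gamma(0.01, 0.01) prior is to look at its quantiles and at the sd they imply. This is my own quick illustration, not taken from the recommendations page: with a shape of 0.01, most of the prior mass is piled up extremely close to zero precision, i.e. on absurdly large sds of the random effect.

```r
# Quantiles of the Gamma(shape = 0.01, rate = 0.01) prior on the precision tau_nu
probs = c(0.05, 0.5, 0.95)
tau_quantiles = qgamma(probs, shape = 0.01, rate = 0.01)

# Implied quantiles of the sd, sigma_nu = 1 / sqrt(tau_nu).
# The sd decreases with the precision, so the order of the quantiles flips.
sd_quantiles = 1 / sqrt(tau_quantiles)

print(data.frame(prob = probs,
                 tau_nu = signif(tau_quantiles, 3),
                 implied_sd = signif(sd_quantiles, 3)))
```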
model_code = "
data {
int N;
vector[N] x;
int y[N];
//Allowing to parametrize the priors (useful later)
real alpha_prior_mean;
real beta_prior_mean;
real<lower=0> alpha_beta_prior_precision;
real<lower=0> tau_nu_prior_shape;
real<lower=0> tau_nu_prior_rate;
}
transformed data {
//Stan parametrizes normal with sd not precision
real alpha_beta_prior_sigma = sqrt(1 / alpha_beta_prior_precision);
}
parameters {
real alpha;
real beta;
vector[N] nu_normalized;
real<lower=0> tau_nu;
}
model {
real nu_sigma = sqrt(1 / tau_nu);
vector[N] nu = nu_normalized * nu_sigma;
//taking advantage of Stan's implicit vectorization here
nu_normalized ~ normal(0,1);
//The built-in poisson_log(x) === poisson(exp(x))
y ~ poisson_log(alpha + beta*x + nu);
alpha ~ normal(alpha_prior_mean, alpha_beta_prior_sigma);
beta ~ normal(beta_prior_mean, alpha_beta_prior_sigma);
tau_nu ~ gamma(tau_nu_prior_shape,tau_nu_prior_rate);
}
//Uncomment this to have the model generate mu values as well
//Currently commented out as storing the samples of mu consumes
//a lot of memory for the big models
/*
generated quantities {
//nu_sigma is local to the model block, so it is recomputed here
vector[N] mu = exp(alpha + beta * x + nu_normalized * sqrt(1 / tau_nu));
}
*/
"
model = stan_model(model_code = model_code)
Below is the code to make the actual measurements. Some caveats:
- The compilation of the Stan model is not counted (rstanarm avoids it entirely; otherwise it needs to be done only once per model).
- There is some overhead in transferring the posterior samples from Stan to R. This overhead is non-negligible for the larger models, but you can get rid of it by storing the samples in a file and reading them separately. The overhead is not measured here.
- Stan took more than 16 hours to converge for the largest data size (N = 1e5), and then I had issues fitting the posterior samples into memory on my computer. Notably, R-INLA also crashed on my computer for this size. The largest size is thus excluded here, but I have to conclude that if you get bored after 6 hours, Stan is not practical for such a big model.
- I was not able to get rjags to run in a reasonable amount of time, so I did not rerun the JAGS version of the model; the JAGS times below are Kathryn’s.
stan_times_file = paste0(cache_dir, "stan_times.csv")
stan_summary_file = paste0(cache_dir, "stan_summary.csv")
run_stan = TRUE
if(file.exists(stan_times_file) && file.exists(stan_summary_file)) {
stan_times = read.csv(stan_times_file)
stan_summary = read.csv(stan_summary_file)
if(setequal(stan_times$N, N_values) && setequal(stan_summary$N, N_values)) {
run_stan = FALSE
}
}
if(run_stan) {
stan_times_values = numeric(length(N_values))
stan_summary_list = list()
step = 1
for(N in N_values) {
data_stan = data[[N]]
data_stan$alpha_prior_mean = 0
data_stan$beta_prior_mean = 0
data_stan$alpha_beta_prior_precision = 0.001
data_stan$tau_nu_prior_shape = 0.01
data_stan$tau_nu_prior_rate = 0.01
fit = sampling(model, data = data_stan);
stan_summary_list[[step]] =
as.data.frame(
rstan::summary(fit, pars = c("alpha","beta","tau_nu"))$summary
) %>% rownames_to_column("parameter")
stan_summary_list[[step]]$N = N
all_times = get_elapsed_time(fit)
stan_times_values[step] = max(all_times[,"warmup"] + all_times[,"sample"])
step = step + 1
}
stan_times = data.frame(N = N_values, stan_time = stan_times_values)
stan_summary = do.call(rbind, stan_summary_list)
write.csv(stan_times, stan_times_file,row.names = FALSE)
write.csv(stan_summary, stan_summary_file,row.names = FALSE)
}
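The reported Stan time is the maximum over chains of warmup + sampling time. Since options(mc.cores = ...) above lets the chains run in parallel, the slowest chain approximates the wall-clock time, assuming there are at least as many cores as chains; if the chains ran one after another, the sum would be the relevant number. The matrix below is a made-up stand-in for the output of get_elapsed_time() (one row per chain, columns "warmup" and "sample", in seconds); its numbers are purely illustrative.

```r
# Made-up stand-in for the matrix returned by rstan::get_elapsed_time(fit):
# one row per chain, columns "warmup" and "sample" (seconds).
# The numbers are purely illustrative, not measured.
all_times = matrix(c(1.2, 0.8,
                     1.5, 0.9,
                     1.1, 0.7,
                     1.3, 1.0),
                   ncol = 2, byrow = TRUE,
                   dimnames = list(paste0("chain", 1:4), c("warmup", "sample")))

per_chain_total = all_times[, "warmup"] + all_times[, "sample"]
# Chains run in parallel: wall-clock time is roughly that of the slowest chain
time_parallel = max(per_chain_total)
# Chains run sequentially: the times add up
time_sequential = sum(per_chain_total)
print(c(parallel = time_parallel, sequential = time_sequential))
```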
Measuring INLA
inla_times_file = paste0(cache_dir,"inla_times.csv")
inla_summary_file = paste0(cache_dir,"inla_summary.csv")
run_inla = TRUE
if(file.exists(inla_times_file) && file.exists(inla_summary_file)) {
inla_times = read.csv(inla_times_file)
inla_summary = read.csv(inla_summary_file)
if(setequal(inla_times$N, N_values) && setequal(inla_summary$N, N_values)) {
run_inla = FALSE
}
}
if(run_inla) {
inla_times_values = numeric(length(N_values))
inla_summary_list = list()
step = 1
for(N in N_values) {
nu = 1:N
fit_inla = inla(y ~ x + f(nu,model="iid"), family = c("poisson"),
data = data[[N]], control.predictor=list(link=1))
inla_times_values[step] = fit_inla$cpu.used["Total"]
inla_summary_list[[step]] =
rbind(fit_inla$summary.fixed %>% select(-kld),
fit_inla$summary.hyperpar) %>%
rownames_to_column("parameter")
inla_summary_list[[step]]$N = N
step = step + 1
}
inla_times = data.frame(N = N_values, inla_time = inla_times_values)
inla_summary = do.call(rbind, inla_summary_list)
write.csv(inla_times, inla_times_file,row.names = FALSE)
write.csv(inla_summary, inla_summary_file,row.names = FALSE)
}
Checking inferences
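Here are side-by-side comparisons of the inferences. The estimates of alpha and beta agree closely between Stan and INLA for all N, and the estimates of tau_nu agree reasonably well for N = 500 and above. For N = 100 the precision estimates differ a lot, but both posteriors are very wide (see the sd column).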
for(N_to_show in N_values) {
print(kable(stan_summary %>% filter(N == N_to_show) %>%
select(c("parameter","mean","sd")),
caption = paste0("Stan results for N = ", N_to_show)))
print(kable(inla_summary %>% filter(N == N_to_show) %>%
select(c("parameter","mean","sd")),
caption = paste0("INLA results for N = ", N_to_show)))
}
Stan results for N = 100

| parameter | mean | sd |
| --- | --- | --- |
| alpha | 1.013559 | 0.0989778 |
| beta | 0.495539 | 0.0176988 |
| tau_nu | 162.001608 | 82.7700473 |

INLA results for N = 100

| parameter | mean | sd |
| --- | --- | --- |
| (Intercept) | 1.009037e+00 | 9.15248e-02 |
| x | 4.971302e-01 | 1.61486e-02 |
| Precision for nu | 1.819654e+04 | 1.71676e+04 |
Stan results for N = 500

| parameter | mean | sd |
| --- | --- | --- |
| alpha | 1.0046284 | 0.0555134 |
| beta | 0.4977522 | 0.0102697 |
| tau_nu | 71.6301530 | 13.8264812 |

INLA results for N = 500

| parameter | mean | sd |
| --- | --- | --- |
| (Intercept) | 1.0053202 | 0.0538456 |
| x | 0.4977124 | 0.0099593 |
| Precision for nu | 77.3311793 | 16.0255430 |
Stan results for N = 5000

| parameter | mean | sd |
| --- | --- | --- |
| alpha | 1.009930 | 0.0159586 |
| beta | 0.496859 | 0.0029250 |
| tau_nu | 101.548580 | 7.4655716 |

INLA results for N = 5000

| parameter | mean | sd |
| --- | --- | --- |
| (Intercept) | 1.0099282 | 0.0155388 |
| x | 0.4968718 | 0.0028618 |
| Precision for nu | 103.1508773 | 7.6811740 |
Stan results for N = 25000

| parameter | mean | sd |
| --- | --- | --- |
| alpha | 0.9874707 | 0.0066864 |
| beta | 0.5019566 | 0.0012195 |
| tau_nu | 104.3599424 | 3.5391938 |

INLA results for N = 25000

| parameter | mean | sd |
| --- | --- | --- |
| (Intercept) | 0.9876218 | 0.0067978 |
| x | 0.5019341 | 0.0012452 |
| Precision for nu | 104.8948949 | 3.4415929 |
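Since tau_nu is a precision, it may help to translate the posterior means from the tables back to the sd scale, where the data were simulated with sd 0.1 (i.e. precision 100). The snippet copies the tau_nu / "Precision for nu" means from the tables above. Note that 1 / sqrt(E[tau_nu]) is only a rough plug-in summary, not the posterior mean of the sd.

```r
# Posterior means of the precision tau_nu, copied from the tables above
precision_means = data.frame(
  N = c(100, 500, 5000, 25000),
  stan = c(162.001608, 71.6301530, 101.548580, 104.3599424),
  inla = c(18196.54, 77.3311793, 103.1508773, 104.8948949)
)

# Plug-in sd implied by each mean precision (sd = 1 / sqrt(precision))
implied_sd = data.frame(
  N = precision_means$N,
  stan = 1 / sqrt(precision_means$stan),
  inla = 1 / sqrt(precision_means$inla)
)
print(signif(implied_sd, 3))
```

For N = 5000 and N = 25000 both methods land close to the true sd of 0.1.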
Summary of the timing
You can see that Stan keeps reasonable runtimes for larger N than JAGS did in the original blog post, but INLA is still way faster. Also, Kathryn probably got very lucky with her seed for N = 25 000, as her INLA run completed very quickly. In my (few) tests, INLA always took at least several minutes for N = 25 000. This may mean that Kathryn’s JAGS time is also too short.
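#Kathryn's times from the table at the top of this post (N = 1e5 omitted)
kathryn_results = data.frame(
N = c(100, 500, 5000, 25000),
kathryn_rjags = c(30.394, 142.532, 1714.468, 8610.32),
kathryn_rinla = c(0.383, 1.243, 5.768, 30.077)
)
my_results = merge.data.frame(inla_times, stan_times, by = "N")
kable(merge.data.frame(my_results, kathryn_results, by = "N"))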
N | inla_time | stan_time | kathryn_rjags | kathryn_rinla |
---|---|---|---|---|
100 | 1.061742 | 1.885 | 30.394 | 0.383 |
500 | 1.401597 | 11.120 | 142.532 | 1.243 |
5000 | 10.608704 | 388.514 | 1714.468 | 5.768 |
25000 | 611.505543 | 5807.670 | 8610.32 | 30.077 |
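To put "way faster" into numbers, the ratio of the Stan and INLA times from the table above can be computed directly. The ratio is not monotone in N: it peaks at N = 5000 and drops again at N = 25 000, where INLA's time grew much faster than Stan's.

```r
# My timings (seconds), copied from the table above
timings = data.frame(
  N = c(100, 500, 5000, 25000),
  inla_time = c(1.061742, 1.401597, 10.608704, 611.505543),
  stan_time = c(1.885, 11.120, 388.514, 5807.670)
)

# How many times slower Stan was than INLA for each N
timings$stan_over_inla = timings$stan_time / timings$inla_time
print(timings)
```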
You could obviously do multiple runs to reduce the uncertainty etc., but this post has already taken too much of my time, so this is left to others.
Testing quality of the results
I also had a hunch that INLA might be less precise than Stan, but that turned out to be based on an error. Thus, without much commentary, here is my code to test this. Basically, I modify the random data generator to actually draw the parameters from the priors (the priors are quite constrained, to yield values of alpha, beta and tau_nu similar to the original). I then give both algorithms the knowledge of these priors. I compute both the difference between the true parameters and a point estimate (the posterior mean) and the quantile of the posterior distribution at which the true parameter is found. If the algorithms give the best possible estimates, the distribution of such quantiles should be uniform over (0,1). It turns out that INLA and Stan give almost exactly the same results for almost all runs, and the differences in quality are (for this particular model) negligible.
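The uniformity argument can be illustrated on a toy model where the exact posterior is known. This sketch is my illustration, not part of the original experiment: it uses a conjugate normal model, draws "posterior samples" directly from the exact posterior as a stand-in for Stan's output, and computes the quantile of the true value with the same ecdf() idiom that the test_precision function below uses for the Stan samples.

```r
# Conjugate normal model with a known exact posterior:
#   theta ~ N(0, 1),  y_1..y_n | theta ~ N(theta, 1)
#   theta | y ~ N(n * mean(y) / (n + 1), variance = 1 / (n + 1))
set.seed(2018)
n_obs = 20
n_reps = 2000
n_draws = 1000
q = numeric(n_reps)
for (r in 1:n_reps) {
  theta = rnorm(1, 0, 1)                       # true parameter drawn from the prior
  y = rnorm(n_obs, theta, 1)                   # data simulated given theta
  post_mean = n_obs * mean(y) / (n_obs + 1)
  post_sd = sqrt(1 / (n_obs + 1))
  draws = rnorm(n_draws, post_mean, post_sd)   # stand-in for MCMC samples
  q[r] = ecdf(draws)(theta)                    # posterior quantile of the true value
}

# Counts per decile should be roughly equal (about n_reps / 10 each)
print(table(cut(q, breaks = seq(0, 1, by = 0.1), include.lowest = TRUE)))
```

If the posterior were biased, the counts would pile up on one side; if it were too narrow, they would pile up at both ends; if too wide, in the middle.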
test_precision = function(N) {
rejects <- 0
repeat {
#Set the priors so that they generate similar parameters as in the example above
alpha_beta_prior_precision = 5
prior_sigma = sqrt(1/alpha_beta_prior_precision)
alpha_prior_mean = 1
beta_prior_mean = 0.5
alpha = rnorm(1, alpha_prior_mean, prior_sigma)
beta = rnorm(1, beta_prior_mean, prior_sigma)
tau_nu_prior_shape = 2
tau_nu_prior_rate = 0.01
tau_nu = rgamma(1,tau_nu_prior_shape,tau_nu_prior_rate)
sigma_nu = sqrt(1 / tau_nu)
x = rnorm(N, mean=5,sd=1)
nu = rnorm(N,0,sigma_nu)
linear = alpha + beta*x + nu
#Rejection sampling to avoid NAs and ill-posed problems
if(max(linear) < 15) {
mu = exp(linear)
y = rpois(N,mu)
if(mean(y == 0) < 0.7) {
break;
}
}
rejects = rejects + 1
}
#cat(rejects, "rejects\n")
data = list(
N = N,
x = x,
y = y
)
#cat("A:",alpha,"B:", beta, "T:", tau_nu,"\n")
#print(linear)
#print(data)
#=============== Fit INLA
nu = 1:N
fit_inla = inla(y ~ x + f(nu,model="iid",
hyper=list(theta=list(prior="loggamma",
param=c(tau_nu_prior_shape,tau_nu_prior_rate)))),
family = c("poisson"),
control.fixed = list(mean = beta_prior_mean,
mean.intercept = alpha_prior_mean,
prec = alpha_beta_prior_precision,
prec.intercept = alpha_beta_prior_precision
),
data = data, control.predictor=list(link=1)
)
time_inla = fit_inla$cpu.used["Total"]
alpha_mean_diff_inla = fit_inla$summary.fixed["(Intercept)","mean"] - alpha
beta_mean_diff_inla = fit_inla$summary.fixed["x","mean"] - beta
tau_nu_mean_diff_inla = fit_inla$summary.hyperpar[,"mean"] - tau_nu
alpha_q_inla = inla.pmarginal(alpha, fit_inla$marginals.fixed$`(Intercept)`)
beta_q_inla = inla.pmarginal(beta, fit_inla$marginals.fixed$x)
tau_nu_q_inla = inla.pmarginal(tau_nu, fit_inla$marginals.hyperpar$`Precision for nu`)
#================ Fit Stan
data_stan = data
data_stan$alpha_prior_mean = alpha_prior_mean
data_stan$beta_prior_mean = beta_prior_mean
data_stan$alpha_beta_prior_precision = alpha_beta_prior_precision
data_stan$tau_nu_prior_shape = tau_nu_prior_shape
data_stan$tau_nu_prior_rate = tau_nu_prior_rate
fit = sampling(model, data = data_stan, control = list(adapt_delta = 0.95));
all_times = get_elapsed_time(fit)
max_total_time_stan = max(all_times[,"warmup"] + all_times[,"sample"])
samples = rstan::extract(fit, pars = c("alpha","beta","tau_nu"))
alpha_mean_diff_stan = mean(samples$alpha) - alpha
beta_mean_diff_stan = mean(samples$beta) - beta
tau_nu_mean_diff_stan = mean(samples$tau_nu) - tau_nu
alpha_q_stan = ecdf(samples$alpha)(alpha)
beta_q_stan = ecdf(samples$beta)(beta)
tau_nu_q_stan = ecdf(samples$tau_nu)(tau_nu)
return(data.frame(time_rstan = max_total_time_stan,
time_rinla = time_inla,
alpha_mean_diff_stan = alpha_mean_diff_stan,
beta_mean_diff_stan = beta_mean_diff_stan,
tau_nu_mean_diff_stan = tau_nu_mean_diff_stan,
alpha_q_stan = alpha_q_stan,
beta_q_stan = beta_q_stan,
tau_nu_q_stan = tau_nu_q_stan,
alpha_mean_diff_inla = alpha_mean_diff_inla,
beta_mean_diff_inla = beta_mean_diff_inla,
tau_nu_mean_diff_inla = tau_nu_mean_diff_inla,
alpha_q_inla= alpha_q_inla,
beta_q_inla = beta_q_inla,
tau_nu_q_inla = tau_nu_q_inla
))
}
Now actually running the comparison. On some occasions Stan does not converge; my best guess is that the data are somehow pathological, but I didn’t investigate thoroughly. You can see that the results for Stan and INLA are very similar, both as point estimates and in the distribution of posterior quantiles. AFAIK, the accuracy of the INLA approximation is also going to improve with more data.
library(skimr) #Uses skimr to summarize results easily
precision_results_file = paste0(cache_dir,"precision_results.csv")
if(file.exists(precision_results_file)) {
results_precision_df = read.csv(precision_results_file)
} else {
results_precision = list()
for(i in 1:100) {
results_precision[[i]] = test_precision(50)
}
results_precision_df = do.call(rbind, results_precision)
write.csv(results_precision_df,precision_results_file,row.names = FALSE)
}
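#Remove uninteresting skim statistics (this uses the skimr 1.x API)
skim_with(numeric = list(missing = NULL, complete = NULL, n = NULL))
#Drop the row-number column "X" if the cached CSV contains one (no-op otherwise)
results_precision_df$X = NULL
skimmed = results_precision_df %>% skim()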
#Now a hack to display skim histograms properly in the output:
skimmed_better = skimmed %>% rowwise() %>% mutate(formatted =
if_else(stat == "hist",
utf8ToInt(formatted) %>% as.character() %>% paste0("&#", . ,";", collapse = ""),
formatted))
mostattributes(skimmed_better) = attributes(skimmed)
skimmed_better %>% kable(escape = FALSE)
Skim summary statistics
n obs: 100
n variables: 14
Variable type: numeric
variable | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|
alpha_mean_diff_inla | -0.0021 | 0.2 | -0.85 | -0.094 | 0.0023 | 0.095 | 0.53 | ▁▁▁▂▇▇▁▁ |
alpha_mean_diff_stan | -0.0033 | 0.2 | -0.84 | -0.097 | -0.00012 | 0.093 | 0.52 | ▁▁▁▂▇▇▁▂ |
alpha_q_inla | 0.5 | 0.29 | 0.00084 | 0.25 | 0.5 | 0.73 | 0.99 | ▅▇▇▆▇▆▆▇ |
alpha_q_stan | 0.5 | 0.28 | 0.001 | 0.26 | 0.5 | 0.73 | 0.99 | ▅▇▇▆▇▆▆▇ |
beta_mean_diff_inla | -0.00088 | 0.04 | -0.12 | -0.016 | -0.001 | 0.014 | 0.17 | ▁▁▃▇▂▁▁▁ |
beta_mean_diff_stan | -0.001 | 0.04 | -0.12 | -0.016 | -5e-04 | 0.014 | 0.16 | ▁▁▂▇▂▁▁▁ |
beta_q_inla | 0.51 | 0.28 | 0.0068 | 0.26 | 0.52 | 0.75 | 1 | ▆▆▅▆▇▅▆▆ |
beta_q_stan | 0.51 | 0.28 | 0.0065 | 0.27 | 0.51 | 0.75 | 1 | ▆▆▅▇▆▅▆▆ |
tau_nu_mean_diff_inla | 4.45 | 90.17 | -338.58 | -26.74 | 4.49 | 53.38 | 193 | ▁▁▁▂▅▇▃▂ |
tau_nu_mean_diff_stan | 5.21 | 90 | -339.89 | -24.62 | 4.29 | 54.48 | 191.94 | ▁▁▁▂▅▇▃▂ |
tau_nu_q_inla | 0.53 | 0.26 | 0.023 | 0.32 | 0.52 | 0.74 | 0.99 | ▃▅▆▆▇▆▅▅ |
tau_nu_q_stan | 0.53 | 0.26 | 0.021 | 0.32 | 0.53 | 0.75 | 0.99 | ▃▅▅▆▇▃▅▅ |
time_rinla | 0.97 | 0.093 | 0.86 | 0.91 | 0.93 | 0.98 | 1.32 | ▇▇▂▁▁▁▁▁ |
time_rstan | 1.79 | 1.4 | 0.55 | 0.89 | 1.45 | 2.09 | 10.04 | ▇▂▁▁▁▁▁▁ |