Making Predictions from Stan models in R

Background

Stan (http://mc-stan.org) is a probabilistic programming language for estimating flexible statistical models. If you’re interested in Bayesian modeling, you usually don’t have to look further than Stan. It has almost everything you’ll need to define arbitrarily complex models, explicitly specify prior distributions, and diagnose model performance.

One area where Stan is lacking, however, is reusing estimated models for predictions on new data. Often we fit a model y ∼ x and need to save the model for use as new x become available. This might be for a monthly report, a production system where real-time predictions are necessary, or a competition where judgments are based on predictions on unseen data.

Though Stan does not yet have a robust workflow for this process, there are a few workarounds that can get the job done. In this post I’ll highlight three:

  1. Fit-and-predict: This approach involves specifying the predictive model in Stan’s generated quantities block and re-estimating the model every time you need to make new predictions. It can be computationally intensive and slow, but it is robust and ensures predictions are always based on the most recently available data. The Stan development team recommends this approach.
  2. Predict outside of Stan: This approach involves estimating a model in Stan, then extracting posterior distributions of parameters and rebuilding the predictive structure in another language, such as R or Python.
  3. Predict with Stan: This approach involves writing another Stan program with only data and generated quantities blocks, where the data block contains the posterior draws of the original program’s parameters along with the new values of the independent variables.

Approaches

We’ll start by creating some fake data for this example, simulating from a logistic regression model that we’ll then estimate with Stan. Stan is not necessary for estimating this simple model, but the example is useful for illustrating the three approaches to making predictions with Stan. The data generating process is:

y ∼ Bernoulli(π);

π = inv_logit(α+β∗x)

We’ll want to estimate α and β so that we can make predictions for unseen y based on new data x as it becomes available.
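As a small worked check of the data generating process, the inverse logit can be computed from its definition, 1/(1 + exp(-z)), or with base R's stats::plogis(). The helper inv_logit() here is a hypothetical R name chosen to mirror Stan's built-in function, and x = 0.5 is an illustrative value, not part of the simulated data.

```r
# Hypothetical helper mirroring Stan's inv_logit(): 1 / (1 + exp(-z))
inv_logit <- function(z) 1 / (1 + exp(-z))

z <- c(-3, 0, 2)
# Base R's logistic CDF, plogis(), computes the same function
isTRUE(all.equal(inv_logit(z), plogis(z)))   # TRUE

# pi = inv_logit(alpha + beta * x) with alpha = 1, beta = 2 and an illustrative x = 0.5
round(inv_logit(1 + 2 * 0.5), 3)             # 0.881
```

With α = 1 and β = 2, the linear predictor at x = 0.5 is 2, so that observation would be a 1 with probability of roughly 0.88.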

library(dplyr)
library(ggplot2)
library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)
# Create some fake data - logistic regression
set.seed(56)
N <- 500
alpha <- 1
beta <- 2
x <- rnorm(N)
prob <- 1/(1 + exp(-(alpha + beta*x)))
y <- rbinom(N, 1, prob)
# Distribution of y
table(y)
## y
## 0 1
## 171 329
# Split into training and testing
N_train <- N*0.8
N_test <- N*0.2
train_ind <- sample(c(1:N), size = N_train, replace = FALSE)
x_train <- x[train_ind]
x_test <- x[-train_ind]
y_train <- y[train_ind]
y_test <- y[-train_ind]
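Because Stan is not needed for this simple model, a maximum-likelihood fit with base R's glm() offers a quick sanity check that the training data carry the signal we expect. This is a sketch, not part of the Stan workflow. It re-runs the same simulation so it stands alone, and the estimates should land near α = 1 and β = 2. The exact values depend on the seed, so none are shown here.

```r
# Self-contained re-run of the simulation and train/test split from the setup code
set.seed(56)
N <- 500
x <- rnorm(N)
y <- rbinom(N, 1, 1/(1 + exp(-(1 + 2*x))))
train_ind <- sample(c(1:N), size = N*0.8, replace = FALSE)
train <- data.frame(x = x[train_ind], y = y[train_ind])

# Maximum-likelihood logistic regression (no priors) on the training rows only
fit_glm <- glm(y ~ x, family = binomial(link = "logit"), data = train)
coef(fit_glm)   # "(Intercept)" estimates alpha, "x" estimates beta
```

The Stan fit uses priors, so its posterior means will not match these estimates exactly. With 400 training observations, however, they should be close.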

Fit-and-predict

The fit-and-predict approach uses Stan’s generated quantities block to make predictions from x_test in the same program that we used to estimate the relationship between x_train and y_train. From the Stan manual (http://mc-stan.org/users/documentation/index.html):

The generated quantities program block is rather different than the other blocks. Nothing in the generated quantities block affects the sampled parameter values. The block is executed only after a sample has been generated. Among the applications of posterior inference that can be coded in the generated quantities block are

  • forward sampling to generate simulated data for model testing,
  • generating predictions for new data,
  • calculating posterior event probabilities, including multiple comparisons, sign tests, etc.,
  • calculating posterior expectations,
  • transforming parameters for reporting,
  • applying full Bayesian decision theory,
  • calculating log likelihoods, deviances, etc. for model comparison.

This Stan program simultaneously fits the logistic regression model based on the training data and generates predictions for y_test based on x_test.

data {
  int<lower = 1> N_train;
  vector[N_train] x_train;
  array[N_train] int<lower = 0, upper = 1> y_train;
  int<lower = 1> N_test;
  vector[N_test] x_test;
}
parameters {
  real alpha;
  real beta;
}
model {
  y_train ~ bernoulli_logit(alpha + beta * x_train);
  alpha ~ normal(5, 10);
  beta ~ normal(5, 10);
}
generated quantities {
  // One simulated 0/1 outcome per test point for each posterior draw
  array[N_test] int y_test;
  for (i in 1:N_test) {
    y_test[i] = bernoulli_rng(inv_logit(alpha + beta * x_test[i]));
  }
}

We’ll estimate this model (the code is saved in a separate file model_fit.stan) and check that it was able to recover the parameters α and β.

# Recover parameters with Stan (rstan needs a named data list)
fit <- stan(file = "model_fit.stan",
            data = list(x_train = x_train, y_train = y_train, N_train = N_train,
                        x_test = x_test, N_test = N_test),
            chains = 3, iter = 1000)
plot(fit, pars = c("alpha", "beta"))
traceplot(fit, pars = c("alpha", "beta"))
# Accuracy
ext_fit <- extract(fit)
mean(apply(ext_fit$y_test, 2, median) == y_test)
## [1] 0.75

The model converges, and the posterior distributions of the parameters are centered around their ‘true’ values. The accuracy of the model on new data is 0.75. This is a robust approach for making predictions for new data with Stan, but it is impractical if predictions must be made frequently, because it requires re-estimating the entire model every time new predictions are needed.
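The accuracy check takes the median of each column of ext_fit$y_test, which for 0/1 draws amounts to a majority vote. With an even number of draws (3 chains × 500 post-warmup iterations = 1,500 here), a column that is exactly half ones has median 0.5. That value never equals an observed outcome, so the point always counts as a miss. A slightly more informative summary is the per-observation posterior predictive probability, computed with colMeans(). The small matrix below is a hypothetical stand-in for the real draws, so the block runs without Stan.

```r
# Hypothetical stand-in for ext_fit$y_test: 4 posterior draws (rows) x 3 test points (columns)
draws <- matrix(c(1, 1, 0, 1,
                  0, 0, 1, 0,
                  1, 0, 1, 0), nrow = 4)
y_obs <- c(1, 0, 0)   # hypothetical observed outcomes

# Majority vote via the median, as in the accuracy check
apply(draws, 2, median)                  # 1.0 0.0 0.5  (the tie gives 0.5)
mean(apply(draws, 2, median) == y_obs)   # 2/3: the tied point can never match

# Posterior predictive probability that y = 1 for each test point
p_hat <- colMeans(draws)                 # 0.75 0.25 0.50

# Hard 0/1 predictions; ties go to 0 here, which is an arbitrary choice
y_hat <- as.integer(p_hat > 0.5)         # 1 0 0
mean(y_hat == y_obs)                     # 1
```

Keeping p_hat, rather than only the hard predictions, also lets you choose a threshold other than 0.5 if the costs of false positives and false negatives differ.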

Predict with R

Another option is to extract the posterior distributions of the parameters and use them to recreate the model in R. The code below extracts the posterior distributions (alpha_post and beta_post) and uses them in a prediction function, gen_quant_r. The function simulates the data generating process with samples from the parameters’ posterior distributions. Because it draws a single simulated outcome for each observation, it doesn’t generate a full posterior distribution for each prediction, but it is fast and easy to implement.

# Extract posterior distributions
alpha_post <- ext_fit$alpha
beta_post <- ext_fit$beta
# Function for simulating y based on new x
# Note: alpha and beta are resampled independently, one draw of each per observation,
# so each prediction is a single simulated outcome
gen_quant_r <- function(x) {
  lin_comb <- sample(alpha_post, size = length(x)) + x*sample(beta_post, size = length(x))
  prob <- 1/(1 + exp(-lin_comb))
  out <- rbinom(length(x), 1, prob)
  return(out)
}
# Run the function on x_test
set.seed(56)
y_pred_r <- gen_quant_r(x_test)
# Accuracy
mean(y_pred_r == y_test)
## [1] 0.75

The accuracy of this predictive approach (0.75) matches that of the fit-and-predict approach.
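gen_quant_r() samples alpha and beta independently, once per observation. As a result, it ignores any posterior correlation between the two parameters and returns one outcome per test point. The sketch below is a vectorized alternative. It assumes alpha_post[s] and beta_post[s] come from the same posterior draw, keeps each (α, β) pair together, and returns a full draws-by-observations matrix of predictions. The function name post_pred_r() is hypothetical, and the simulated draws are stand-ins so the block runs without Stan.

```r
# Hypothetical stand-ins for ext_fit$alpha and ext_fit$beta so this block runs without Stan
set.seed(1)
alpha_post <- rnorm(1500, mean = 1, sd = 0.1)
beta_post  <- rnorm(1500, mean = 2, sd = 0.15)

# Posterior predictive simulation: row s uses the s-th joint draw (alpha[s], beta[s])
post_pred_r <- function(x_new, alpha_post, beta_post) {
  # eta[s, n] = alpha[s] + beta[s] * x_new[n]; the length-S alpha vector recycles down each column
  eta <- alpha_post + outer(beta_post, x_new)
  prob <- plogis(eta)
  # rbinom() works elementwise on prob; reshape into a draws x observations matrix
  matrix(rbinom(length(prob), size = 1, prob = prob),
         nrow = length(alpha_post), ncol = length(x_new))
}

x_new <- c(-1, 0, 1)
pp <- post_pred_r(x_new, alpha_post, beta_post)
dim(pp)        # 1500 3
colMeans(pp)   # posterior predictive P(y = 1) for each new x
```

With the real fit, the same idea would be post_pred_r(x_test, ext_fit$alpha, ext_fit$beta), followed by thresholding colMeans() of the result to get hard predictions.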

Predict with Stan

The third approach is to write another Stan program to make predictions without refitting the old model. The parameter estimates from the original program become the data for the prediction program. The code below shows how this program might look. Note that the parameters and model blocks are empty because we are not estimating parameter distributions from a model.

data {
  int<lower = 1> N;
  int<lower = 1> N_samples;
  vector[N] x_test;
  vector[N_samples] alpha;
  vector[N_samples] beta;
}
parameters {
}
model {
}
generated quantities {
  // Row i holds simulated outcomes for posterior draw (alpha[i], beta[i])
  matrix[N_samples, N] y_test;
  for (n in 1:N) {
    for (i in 1:N_samples) {
      y_test[i, n] = bernoulli_rng(inv_logit(alpha[i] + beta[i] * x_test[n]));
    }
  }
}

When we run this program, we have to set the algorithm to Fixed_param so Stan knows that it’s not estimating parameters. We have to go through a bit of effort to extract the distributions of the generated quantities, but once we do, we’ll have full posterior distributions of our predictions with similar accuracy to the two approaches outlined above.

pred <- stan(file = "model_pred.stan",
             data = list(x_test = x_test, N = N_test,
                         N_samples = length(alpha_post),
                         alpha = alpha_post,
                         beta = beta_post),
             chains = 1, iter = 100,
             algorithm = "Fixed_param")

# Extract and format output
# ext_pred$y_test has dimensions iterations x N_samples x N_test
ext_pred <- extract(pred)
out_mat <- matrix(NA, nrow = dim(ext_pred$y_test)[2],
                  ncol = dim(ext_pred$y_test)[3])
# Average over iterations for each (posterior draw, test point) pair
for(i in 1:dim(ext_pred$y_test)[2]) {
  for(j in 1:dim(ext_pred$y_test)[3]) {
    out_mat[i, j] <- mean(ext_pred$y_test[, i, j])
  }
}
# Accuracy
(apply(out_mat, 2, median) %>% round(0) == y_test) %>% mean()
## [1] 0.75
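extract() returns y_test from the prediction program as a three-dimensional array (iterations × N_samples × N_test), and the extraction code averages over the first dimension with a nested loop. The sketch below shows that apply() over the last two dimensions gives the same matrix in one call. Averaging over everything except the observation dimension then gives each test point's posterior predictive probability. The small array is a hypothetical stand-in for ext_pred$y_test.

```r
set.seed(2)
# Hypothetical stand-in for ext_pred$y_test: 5 iterations x 4 posterior draws x 3 test points
arr <- array(rbinom(5 * 4 * 3, size = 1, prob = 0.6), dim = c(5, 4, 3))

# Nested-loop version, mirroring the extraction code
out_loop <- matrix(NA, nrow = dim(arr)[2], ncol = dim(arr)[3])
for (i in seq_len(dim(arr)[2])) {
  for (j in seq_len(dim(arr)[3])) {
    out_loop[i, j] <- mean(arr[, i, j])
  }
}

# One-call equivalent: average over iterations for every (draw, test point) pair
out_apply <- apply(arr, c(2, 3), mean)
isTRUE(all.equal(out_loop, out_apply))   # TRUE

# Posterior predictive P(y = 1) for each test point, pooling iterations and draws
p_hat <- apply(arr, 3, mean)
# Hard predictions by thresholding at 0.5
y_hat <- as.integer(p_hat > 0.5)
```

apply() with MARGIN = c(2, 3) keeps the draw and observation dimensions and collapses the iteration dimension, which is exactly what the loop does element by element.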

Conclusion

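Stan is a powerful language for Bayesian inference with a robust mechanism for out-of-sample prediction in its generated quantities block. However, it is not easy to ‘re-use’ Stan models on new data as it becomes available, particularly if the data is streaming in and predictions must be made regularly. This post has outlined three options for making live predictions with Stan. Each has strengths and weaknesses. Fit-and-predict is the most robust but requires refitting the model every time. Predicting in R is fast but, as written here, returns a single simulated outcome per observation. A separate Stan prediction program returns full posterior predictive distributions without refitting, at the cost of some extra effort to pass in the draws and reshape the output. Which is most appropriate depends on the use-case.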