Inference for Distributional Random Forests

Confidence intervals for a powerful nonparametric method

Jeffrey Näf
Towards Data Science


Features of (Distributional) Random Forests. In this article: The ability to provide uncertainty measures. Source: Author.

In a previous article, I extensively discussed the Distributional Random Forest method, a Random Forest-type algorithm that can nonparametrically estimate multivariate conditional distributions. This means that we are able to learn the whole distribution of a multivariate response Y given some covariates X nonparametrically, instead of “just” learning an aspect such as its conditional expectation. DRF does this by learning weights w_i(x) for the i=1,…,n training points that define the distribution and can be used to estimate a wide range of targets.

So far this method only produced a “point estimate” of the distribution (i.e. a point estimate for the n weights w_i(x)). While this is enough to predict the whole distribution of a response, it doesn’t give a way to make inference that considers the randomness of the data-generating mechanism. That is, even though this point estimate gets increasingly close to the truth for large sample sizes (under a list of assumptions), there is still uncertainty in its estimate for finite sample sizes. Luckily there is now a (provable) method to quantify this uncertainty as I lay out in this article. This is based on our new paper on arXiv.

The goal of this article is twofold: First, I want to discuss how to add uncertainty estimates to the DRF, based on our paper. The paper is quite theoretical, so I start with a few examples. The subsequent sections take a quick glance at these theoretical results, for those interested. I then explain how this can be used to get a (sampling-based) uncertainty measure for a wide range of targets. Second, I discuss the CoDiTE of [1] and a particularly interesting example of this concept, the conditional witness function. This function is a complicated object, yet, as we will see below, we can estimate it easily with DRF and can even provide asymptotic confidence bands, based on the concepts introduced in this article. An extensive real-data example of how this could be applied is given in this article.

Throughout, we assume that we observe an i.i.d. sample (Y_1, X_1), …, (Y_n, X_n), where each Y_i is a d-variate vector of variables of interest and each X_i is a p-variate vector of covariates. The goal is to estimate the conditional distribution of Y given X=x.

We will need the following packages and functions for this:

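library(kernlab)
library(drf)
library(Matrix)
library(Hmisc)
library(MASS)      # mvrnorm() in the first simulation
library(ggplot2)   # density plot in the causal example

source("CIdrf.R")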

The functions in the file “CIdrf.R” can be found below.

In the following, all images, unless otherwise noted, are by the author.

Examples

We simulate from a simple example with d=1 and p=2:

 
set.seed(2)

n<-2000
beta1<-1
beta2<--1.8


# Model Simulation
X<-mvrnorm(n = n, mu=c(0,0), Sigma=matrix(c(1,0.7,0.7,1), nrow=2,ncol=2))
u<-rnorm(n=n, sd = sqrt(exp(X[,1])))
Y<- matrix(beta1*X[,1] + beta2*X[,2] + u, ncol=1)

Note that this is simply a heteroskedastic linear model, in which the variance of the error term depends on X_1. Of course, if you knew that the effect of X on Y is linear, you would not use DRF, or any Random Forest for that matter, but would go directly with linear regression. For our purposes, however, it is convenient to know the truth. Since DRF's job is to estimate the conditional distribution given X=x, we now fix x and estimate the conditional expectation and variance given X=x.
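Because the model is known, both targets at a fixed x can be computed directly: $\mathbb{E}[Y \mid X = x] = \beta_1 x_1 + \beta_2 x_2$ and $\mathrm{Var}(Y \mid X = x) = \exp(x_1)$. The short Monte Carlo sketch below checks this at x = (1, 1), the test point used below, by drawing Y many times with X held fixed. The names x_fixed, m and y_given_x are illustrative, and since the sketch calls set.seed() and consumes random numbers, it is best run separately from the main example.

```r
# Monte Carlo check of the true conditional mean and variance at a fixed x,
# for the heteroskedastic linear model simulated above.
set.seed(1)
beta1 <- 1
beta2 <- -1.8
x_fixed <- c(1, 1)   # the test point used in the examples
m <- 1e6             # number of draws of Y given X = x_fixed

# error term with variance exp(x_1), as in the simulation
u <- rnorm(m, sd = sqrt(exp(x_fixed[1])))
y_given_x <- beta1 * x_fixed[1] + beta2 * x_fixed[2] + u

print(c(mc_mean = mean(y_given_x), true_mean = beta1 * x_fixed[1] + beta2 * x_fixed[2]))
print(c(mc_var = var(y_given_x), true_var = exp(x_fixed[1])))
```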

We choose a point that lies well within the bulk of the X distribution, with many observations surrounding it. In general, one should be careful when using any Random Forest method for points on the border of the X observations.

# Choose an x that is not too far out
x<-matrix(c(1,1),ncol=2)

# Choose alpha for CIs
alpha<-0.05

Finally, we fit our DRF and obtain the weights w_i(x):

## Fit the new DRF framework
drf_fit <- drfCI(X=X, Y=Y, min.node.size = 2, splitting.rule='FourierMMD', num.features=10, B=100)

## predict weights
DRF = predictdrf(drf_fit, x=x)
weights <- DRF$weights[1,]

As explained below, the DRF object we built here contains not only the weights w_i(x), but also B further weight vectors that correspond to draws from the distribution of w_i(x). We can use these B draws to approximate the distribution of anything we want to estimate, as I illustrate now in two examples.
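Both examples follow the same recipe: compute a target from the overall weights, recompute it from each of the B subsample weight vectors, and turn the spread of those draws into an interval. That recipe can be packaged as a small helper. The functions ci_from_draws and drf_functional_ci are hypothetical conveniences, not part of CIdrf.R or the drf package; they assume the weights are passed as plain numeric vectors.

```r
# Hypothetical helpers (not part of CIdrf.R): normal-approximation CI for any
# target that can be written as a function of the weight vector.

# theta_hat: estimate from the overall weights w(x)
# theta_b:   the same estimate recomputed from each subsample weight vector w^S(x)
ci_from_draws <- function(theta_hat, theta_b, alpha = 0.05) {
  se <- sqrt(var(theta_b - theta_hat))  # spread of the subsample draws
  z <- qnorm(1 - alpha / 2)
  c(lower = theta_hat - z * se, estimate = theta_hat, upper = theta_hat + z * se)
}

# f:        function mapping a numeric weight vector to a scalar target
# weights:  overall weights (numeric vector of length n)
# weightsb: list of B numeric weight vectors of length n
drf_functional_ci <- function(f, weights, weightsb, alpha = 0.05) {
  theta_hat <- f(weights)
  theta_b <- vapply(weightsb, f, numeric(1))
  ci_from_draws(theta_hat, theta_b, alpha)
}
```

With the objects of this section, the interval of Example 1 would correspond to drf_functional_ci(function(w) sum(w * Y), weights, lapply(DRF$weightsb, function(wb) wb[1, ]), alpha).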

Example 1: Conditional Expectation

First, we simply do what most prediction methods do: We estimate the conditional expectation. With our new method, we also build a confidence interval around it.

# Estimate the conditional expectation at x:
condexpest<- sum(weights*Y)

# Use the distribution of weights, see below
distofcondexpest<-unlist(lapply(DRF$weightsb, function(wb) sum(wb[1,]*Y) ))

# Can either use the above directly to build confidence interval, or can use the normal approximation.
# We will use the latter
varest<-var(distofcondexpest-condexpest)

# build 95%-CI
lower<-condexpest - qnorm(1-alpha/2)*sqrt(varest)
upper<-condexpest + qnorm(1-alpha/2)*sqrt(varest)
c(lower, condexpest, upper)

(-1.00, -0.69, -0.37)

Importantly, although the estimate is a bit off, this CI contains the truth, which is given as

$\mathbb{E}[Y \mid X = x] = \beta_1 x_1 + \beta_2 x_2 = 1 - 1.8 = -0.8.$

Example 2: Conditional Variance

Assume now we would like to find the variance Var(Y|X=x) instead of the conditional mean. This is quite a challenging example for a nonparametric method that cannot make use of the linearity. The truth is given as

$\mathrm{Var}(Y \mid X = x) = \exp(x_1) = e \approx 2.72.$

Using DRF, we can estimate this as follows:

# Estimate the conditional variance at x:
condvarest<- sum(weights*Y^2) - condexpest^2

distofcondvarest<-unlist(lapply(DRF$weightsb, function(wb) {
  sum(wb[1,]*Y^2) - sum(wb[1,]*Y)^2
}))

# Can either use the above directly to build confidence interval, or can use the normal approximation.
# We will use the latter
varest<-var(distofcondvarest-condvarest)

# build 95%-CI
lower<-condvarest - qnorm(1-alpha/2)*sqrt(varest)
upper<-condvarest + qnorm(1-alpha/2)*sqrt(varest)

c(lower, condvarest, upper)

(1.89, 2.65, 3.42)

Thus the true parameter is contained in the CI, as we would hope, and in fact, we are quite close to the truth with our estimate!
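The code comments in both examples mention that the draws can also be used directly instead of through the normal approximation. One standard way to do so, sketched here as an assumption rather than as what CIdrf.R implements, is the basic (pivotal) bootstrap interval: since $\hat\theta_S - \hat\theta$ mimics the distribution of $\hat\theta - \theta$, its quantiles give the interval $[\hat\theta - q_{1-\alpha/2},\ \hat\theta - q_{\alpha/2}]$. The name pivotal_ci is hypothetical.

```r
# Basic (pivotal) interval from subsample draws, without a normal approximation.
# Uses that theta_b - theta_hat approximates the distribution of theta_hat - theta.
pivotal_ci <- function(theta_hat, theta_b, alpha = 0.05) {
  q <- quantile(theta_b - theta_hat, probs = c(1 - alpha / 2, alpha / 2),
                names = FALSE)
  c(lower = theta_hat - q[1], estimate = theta_hat, upper = theta_hat - q[2])
}
```

For example, pivotal_ci(condvarest, distofcondvarest, alpha) would give such an interval for the conditional variance. Unlike the normal interval, it need not be symmetric around the estimate, which can matter for skewed targets such as variances.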

We now study the theory underlying these examples, before we come to a third example in Causal Analysis.

Asymptotic Normality in the RKHS

In this and the next section, we briefly focus on the theoretical results derived in the paper. As explained above and in the previous article, DRF presents a distributional prediction at a test point x. That is, we obtain an estimate

$\hat{P}_{Y \mid X = x} = \sum_{i=1}^{n} w_i(x)\, \delta_{Y_i}$

of the conditional distribution of Y given X=x, where $\delta_{Y_i}$ is the Dirac measure at Y_i. This is just a typical way of writing an empirical measure; the magic lies in the weights w_i(x): they can be used to easily obtain estimators of quantities of interest, or even to sample directly from the distribution.
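For instance, with weights $w_i(x) \ge 0$ summing to one, a conditional quantile and a sample from the estimated conditional distribution each take a couple of lines. The helpers weighted_quantile and sample_from_weights are hypothetical names for this sketch, written for univariate Y:

```r
# Hypothetical helpers for the weighted empirical measure sum_i w_i(x) * delta_{Y_i}
# (univariate Y, nonnegative weights summing to one).

# Smallest Y_i whose cumulative weight reaches prob
weighted_quantile <- function(y, w, prob) {
  ord <- order(y)
  cum_w <- cumsum(w[ord])
  y[ord][which(cum_w >= prob - 1e-12)[1]]  # small tolerance for rounding
}

# Draw from the estimated conditional distribution: pick Y_i with probability w_i(x)
sample_from_weights <- function(y, w, size) {
  y[sample.int(length(y), size = size, replace = TRUE, prob = w)]
}
```

In the first example, weighted_quantile(Y[, 1], weights, 0.5) would estimate the conditional median at x.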

To obtain this estimate, DRF actually estimates the conditional mean, but in a reproducing kernel Hilbert space (RKHS). An RKHS is defined through a kernel function k(y_1, y_2). With this kernel, we can map each observation Y_i into the Hilbert space as the function k(Y_i, ·). There is a myriad of methods using this extremely powerful tool, such as kernel ridge regression. The key point is that for a suitable (characteristic) kernel, such as the Gaussian kernel, every distribution is represented by a unique element of this RKHS. It turns out that the true conditional distribution can be represented in the RKHS as the following expectation:

$\mu(x) = \mathbb{E}[k(Y, \cdot) \mid X = x].$

So this is just another way of expressing the conditional distribution of Y given X=x. We then try to estimate this element with DRF like this:

$\hat{\mu}_n(x) = \sum_{i=1}^{n} w_i(x)\, k(Y_i, \cdot) \qquad (1)$

Again we use the weights obtained from DRF, but now form a weighted sum of the functions k(Y_i, ·) instead of the Dirac measures above. Since both estimates use the same weights, we can map back and forth between them. The reason this matters is that we can write the conditional distribution estimate as a weighted mean in the RKHS! Just as the original Random Forest estimates a mean in the real numbers (the conditional expectation of Y given X=x), DRF estimates a mean in the RKHS. The difference is that a mean in the RKHS also gives us an estimate of the whole conditional distribution.
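To make the RKHS element concrete: since $k(Y_i, \cdot)$ is itself a function, the estimate $\sum_i w_i(x) k(Y_i, \cdot)$ can be evaluated at any point y. The sketch below does this for univariate Y with the Gaussian kernel in the parameterization $k(z, y) = \exp(-\sigma (z - y)^2)$ used by kernlab's rbfdot; the function names and the default sigma = 1 are illustrative assumptions.

```r
# Evaluate the weighted kernel mean embedding y -> sum_i w_i * k(Y_i, y)
# for univariate Y and the Gaussian kernel k(z, y) = exp(-sigma * (z - y)^2).
gauss_kernel <- function(z, y, sigma) {
  exp(-sigma * outer(z, y, "-")^2)   # length(z) x length(y) matrix
}

embedding_eval <- function(Y, w, ygrid, sigma = 1) {
  K <- gauss_kernel(Y, ygrid, sigma)
  as.vector(crossprod(w, K))         # one value per grid point
}
```

With all weight on a single Y_i, the result is just the kernel bump k(Y_i, ·); spreading the weights over many Y_i gives a smoothed picture of the whole weighted distribution.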

The reason this is important for our story is that this weighted mean in the RKHS behaves quite similarly in some regards to a (weighted) mean in d dimensions. That is, we can study its consistency and asymptotic normality using the myriad of tools that are available for averages. This is quite remarkable, as essentially all interesting RKHS are infinite-dimensional. The first DRF paper already establishes consistency of the estimator in (1) in the RKHS. Our new paper now proves that, in addition,

$\sigma_n^{-1}\big(\hat{\mu}_n(x) - \mu(x)\big) \xrightarrow{D} N(0, \Sigma_x),$

where σ_n is a standard deviation that goes to zero and Σ_x is an operator that takes the place of a covariance matrix (again, it all works quite similarly to d-dimensional Euclidean space).

Obtaining the sampling distribution

OK, so we have an asymptotic normality result in an infinite-dimensional space. What exactly does that mean? Well, first, it means that estimators derived from the DRF estimate that are “smooth” enough will also tend to be asymptotically normal. But this alone is still not useful, as we also need a variance estimate. Here a further result of our paper comes into play.

We omit a lot of details here, but essentially we can use the following subsampling scheme: instead of just fitting, say, N trees to build our forest, we build B groups of L trees (such that N = B·L). For each group of trees, or mini forest, we draw at random a subsample S containing about half of the data points and fit the mini forest using only this subsample. For each drawn S, we then get another DRF estimator in the Hilbert space,

$\hat{\mu}_S(x) = \sum_{i \in S} w_i^S(x)\, k(Y_i, \cdot) \qquad (2)$

which only uses the samples in S. Note that, as in bootstrapping, we now have two sources of randomness, even disregarding the randomness of the forest(s), which in theory we make negligible by taking B large enough. One source is the data themselves; the other is artificial and introduced when we choose S at random. Crucially, the randomness from S, given the data, is in our control: we can draw as many subsets S as we want. So the question is, what happens to our estimator in (2) if we only consider the randomness of S and fix the data? Remarkably, we can show that, conditionally on the data,

$\sigma_n^{-1}\big(\hat{\mu}_S(x) - \hat{\mu}_n(x)\big) \xrightarrow{D} N(0, \Sigma_x).$

This just means that if we fix the data and only consider the randomness from S, the estimator (2) minus the estimator (1) converges in distribution to the same limit as the original estimator minus the truth! This is exactly how bootstrap theory works: we have shown that something we can sample from, namely

$\hat{\mu}_S(x) - \hat{\mu}_n(x),$

converges (after scaling by σ_n) to the same limit as something we cannot access, namely

$\hat{\mu}_n(x) - \mu(x).$

So to make inference about the latter, we can use the former! This is the standard argument used in bootstrap theory to justify why the bootstrap approximates the sampling distribution. That's right: even the bootstrap, a technique people often use in small samples, only really makes sense (theoretically) in a large-sample regime.
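The argument can be checked by hand for the simplest estimator, a plain sample mean, where everything is explicit. For half-samples of size n/2 drawn without replacement, a standard finite-population calculation gives $\mathrm{Var}(\bar y_S - \bar y_n \mid \text{data}) = s^2 / n$ exactly, where $s^2$ is the usual sample variance, so the subsample differences reproduce the usual estimate of $\mathrm{Var}(\bar y_n)$ without any rescaling. The toy sketch below (not DRF; the exponential data and the sizes are arbitrary choices) confirms this numerically.

```r
# Toy check of the half-sampling idea for a plain sample mean.
set.seed(42)
n_toy <- 1000
y_toy <- rexp(n_toy)         # any data set; held fixed from here on
mean_n <- mean(y_toy)

B_toy <- 5000
diffs <- replicate(B_toy, {
  S <- sample.int(n_toy, n_toy / 2)  # half-sample without replacement
  mean(y_toy[S]) - mean_n            # (estimator on S) - (estimator on all data)
})

var_half <- var(diffs)               # variance over S, data fixed
var_classic <- var(y_toy) / n_toy    # usual estimate of Var(mean_n)
print(c(var_half = var_half, var_classic = var_classic, ratio = var_half / var_classic))
```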

Let’s use this now.

What does this actually mean?

We now show what this means in practice. In the following, we define the functions drfCI and predictdrf, derived from the drf function of the CRAN package drf, together with Witdrf, which we use for the witness function below.

## Functions in CIdrf.R, which is sourced above ##

drfCI <- function(X, Y, B, sampling = "binomial", ...) {

  ### Function that uses DRF with subsampling to obtain confidence regions
  ### as described in https://arxiv.org/pdf/2302.05761.pdf
  ### X: Matrix of predictors
  ### Y: Matrix of variables of interest
  ### B: Number of half-samples/mini-forests
  ### sampling: "binomial" (keep each point with prob. 1/2) or exactly floor(n/2) points

  n <- dim(X)[1]

  # For each half-sample S, fit a DRF on S and store it with the indices of S
  DRFlist <- lapply(seq_len(B), function(b) {

    # half-sample index
    indexb <- if (sampling == "binomial") {
      seq_len(n)[as.logical(rbinom(n, size = 1, prob = 0.5))]
    } else {
      sample(seq_len(n), floor(n / 2), replace = FALSE)
    }

    ## Refit DRF on S
    DRFb <- drf(X = X[indexb, , drop = FALSE], Y = Y[indexb, , drop = FALSE],
                ci.group.size = 1, ...)

    return(list(DRF = DRFb, indices = indexb))
  })

  return(list(DRFlist = DRFlist, X = X, Y = Y))
}


predictdrf <- function(DRF, x, ...) {

  ### Function to predict from DRF with confidence bands
  ### DRF: object returned by drfCI
  ### x: Test point(s), one per row

  ntest <- nrow(x)
  n <- nrow(DRF$Y)

  ## extract the weights w^S(x), with zeros for points outside S
  weightsb <- lapply(DRF$DRFlist, function(l) {
    weightsbfinal <- Matrix(0, nrow = ntest, ncol = n, sparse = TRUE)
    weightsbfinal[, l$indices] <- predict(l$DRF, x)$weights
    return(weightsbfinal)
  })

  ## obtain the overall weights w as the average of the w^S
  weights <- Reduce("+", weightsb) / length(weightsb)

  return(list(weights = weights, weightsb = weightsb))
}



Witdrf <- function(DRF, x, groupingvar, alpha = 0.05, ...) {

  ### Function to calculate the conditional witness function with
  ### confidence bands from DRF
  ### DRF: object returned by drfCI (DRF$Y must contain the column groupingvar)
  ### x: Test point (matrix with one row)
  ### groupingvar: name of the 0/1 group column in DRF$Y
  ### alpha: level of the confidence band

  if (is.null(dim(x))) {
    stop("x needs to have dim(x) > 0")
  }

  ntest <- nrow(x)
  n <- nrow(DRF$Y)
  coln <- colnames(DRF$Y)

  ## Collect w^S
  weightsb <- lapply(DRF$DRFlist, function(l) {
    weightsbfinal <- Matrix(0, nrow = ntest, ncol = n, sparse = TRUE)
    weightsbfinal[, l$indices] <- predict(l$DRF, x)$weights
    return(weightsbfinal)
  })

  ## Obtain w
  weightsall <- Reduce("+", weightsb) / length(weightsb)

  # Get the weights of the respective classes (need to standardize by propensity!)
  weightsall0 <- weightsall * (DRF$Y[, groupingvar] == 0) / sum(weightsall * (DRF$Y[, groupingvar] == 0))
  weightsall1 <- weightsall * (DRF$Y[, groupingvar] == 1) / sum(weightsall * (DRF$Y[, groupingvar] == 1))

  # Gaussian kernel; the kernel matrix K uses the response without the grouping column
  bandwidth_Y <- drf:::medianHeuristic(DRF$Y)
  k_Y <- rbfdot(sigma = bandwidth_Y)

  K <- kernelMatrix(k_Y, DRF$Y[, coln[coln != groupingvar]], y = DRF$Y[, coln[coln != groupingvar]])

  nulldist <- sapply(weightsb, function(wb) {
    # class-wise weights for this half-sample
    wb0 <- wb * (DRF$Y[, groupingvar] == 0) / sum(wb * (DRF$Y[, groupingvar] == 0))
    wb1 <- wb * (DRF$Y[, groupingvar] == 1) / sum(wb * (DRF$Y[, groupingvar] == 1))

    # squared RKHS norm of (subsample witness function - overall witness function)
    diag((wb0 - weightsall0 - (wb1 - weightsall1)) %*% K %*% t(wb0 - weightsall0 - (wb1 - weightsall1)))
  })

  # Choose the right quantile
  c <- quantile(nulldist, 1 - alpha)

  return(list(c = c, k_Y = k_Y, Y = DRF$Y[, coln[coln != groupingvar]], nulldist = nulldist,
              weightsall0 = weightsall0, weightsall1 = weightsall1))
}



So from our method, we get not only the point estimate in the form of weights w_i(x), but also a sample of B weight vectors, each representing an independent draw (given the data) from the distribution of the estimator of the conditional distribution (this sounds more confusing than it is; please keep the examples in mind). In other words, we have not only an estimator, but also an approximation to its distribution!

I now turn to a more interesting example of something we can only do with DRF (as far as I know).

Causal Analysis Example: Witness Function

Let’s assume we have two sets of observations, say group W=1 and group W=0, and we want to find the causal relationship between group membership and a variable Y. In the example of this article, the two groups would be male and female and Y would be the hourly wage. In addition, we have confounders X, which we assume affect both W and Y. We assume here that X really includes all relevant confounders. This is a BIG assumption. Formally, writing Y(1) and Y(0) for the potential outcomes under W=1 and W=0, we assume unconfoundedness:

$\{Y(1), Y(0)\} \perp W \mid X,$

and overlap:

$0 < P(W = 1 \mid X = x) < 1.$

Often people then compare the conditional expectations of the two groups:

$\mathbb{E}[Y \mid W = 1, X = x] - \mathbb{E}[Y \mid W = 0, X = x].$

This is the Conditional Average Treatment Effect (CATE) at x. This is a natural starting point, but in a recent paper ([1]), the CoDiTE was introduced as a generalization of this idea. Instead of just looking at the difference in expected values, the CoDiTE proposes to look at differences in other quantities as well. A particularly interesting example of this idea is the conditional witness function: for both groups, we take, as above,

$\mu_j(x) = \mathbb{E}[k(Y, \cdot) \mid W = j, X = x], \quad j = 0, 1.$

So we consider the representations of the two conditional distributions in the RKHS. In addition to being representations of the conditional distributions, these quantities are also real-valued functions: for j = 0, 1,

$y \mapsto \mu_j(x)(y) = \mathbb{E}[k(Y, y) \mid W = j, X = x].$

The function that gives the difference between these two quantities,

$y \mapsto \mu_1(x)(y) - \mu_0(x)(y),$

is called the conditional witness function.

Why is this function interesting? It turns out that this function shows how the two conditional densities behave in relation to each other. Writing p_j(· | x) for the conditional density of group j, its value at y equals ∫ k(z, y) (p_1(z | x) − p_0(z | x)) dz, so it is the difference of the two densities smoothed by the kernel k. Roughly speaking, for values of y at which the function is negative, the conditional density of group 1 at y is smaller than that of group 0; where it is positive, the density of group 1 at y is higher (where “conditional” always refers to conditioning on X=x). Crucially, this can be done without having to estimate the densities, which is hard, especially for multivariate Y.
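The smoothing interpretation can be seen numerically. The sketch below uses two plain samples with no covariates (a toy setting, not DRF): two normal distributions with means 0.8 and -0.2, chosen to echo the example that follows, and an arbitrary kernel parameter sigma = 0.5 in the parameterization exp(-sigma (z - y)^2) used by kernlab's rbfdot. The helper witness() is a hypothetical name; it averages kernel evaluations within each group and compares the sign with the difference of the true densities.

```r
# Toy sketch (not DRF): empirical witness function from two samples,
# f(y) = mean_i k(Y1_i, y) - mean_i k(Y0_i, y), with k(z, y) = exp(-sigma * (z - y)^2).
set.seed(7)
y1 <- rnorm(5000, mean = 0.8)   # toy group 1
y0 <- rnorm(5000, mean = -0.2)  # toy group 0
sigma <- 0.5

witness <- function(ygrid, y1, y0, sigma) {
  k1 <- exp(-sigma * outer(y1, ygrid, "-")^2)  # length(y1) x length(ygrid)
  k0 <- exp(-sigma * outer(y0, ygrid, "-")^2)
  colMeans(k1) - colMeans(k0)
}

ygrid <- seq(-3, 3, by = 0.1)
wit <- witness(ygrid, y1, y0, sigma)

# Compare signs with the difference of the true densities
dens_diff <- dnorm(ygrid, mean = 0.8) - dnorm(ygrid, mean = -0.2)
print(mean(sign(wit) == sign(dens_diff)))
```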

Finally, we can provide uniform confidence bands for our estimated conditional witness functions by using the B samples from above. I do not go into details here, but these are essentially the analog of the confidence intervals for the conditional mean we used above. Crucially, these bands are meant to be valid uniformly over all values of y, for one specific x.
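One way to see why a single constant c yields a band that is uniform in y uses the reproducing property. Write f for the true conditional witness function at x and $\hat f$ for its estimate, both elements of the RKHS, and assume a kernel with k(y, y) = 1, as for the Gaussian kernel used in Witdrf. By Cauchy-Schwarz, for every y,

$$|\hat f(y) - f(y)| = |\langle \hat f - f,\, k(y, \cdot) \rangle_{\mathcal{H}}| \le \|\hat f - f\|_{\mathcal{H}} \sqrt{k(y, y)} = \|\hat f - f\|_{\mathcal{H}}.$$

Hence, if $\|\hat f - f\|_{\mathcal{H}}^2 \le c$ holds with probability approximately $1 - \alpha$, the band $\hat f(y) \pm \sqrt{c}$ covers f(y) for all y simultaneously. For a function of the form $g = \sum_i a_i k(Y_i, \cdot)$, one has $\|g\|_{\mathcal{H}}^2 = a^\top K a$ with $K_{il} = k(Y_i, Y_l)$. In Witdrf, a is the difference between the class-wise subsample weights and the overall class-wise weights, so nulldist collects B draws of this squared norm and c is its (1 - α)-quantile; hatmun then adds and subtracts sqrt(c).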

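Let’s illustrate this with an example. We simulate the following data-generating process:

$X_1, X_2 \overset{\text{iid}}{\sim} U(0,1), \quad W \mid X \sim \text{Bernoulli}\left(\frac{e^{-X_2}}{1 + e^{-X_2}}\right), \quad Y = (W - 0.2)\,X_1 + \varepsilon, \quad \varepsilon \sim N(0,1).$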

That is, X_1, X_2 are independently uniformly distributed on (0,1), W is either 0 or 1, with a probability depending on X_2 and Y is a function of W and X_1. This is a really hard problem; not only does X influence the probability of belonging to class 1 (i.e. the propensity), it also changes the treatment effect of W on Y. In fact, a small calculation shows that the CATE is given as:

(1 - 0.2)*X_1 - (0 - 0.2)*X_1 = X_1.

Graph corresponding to the data-generating process
set.seed(2)

n<-4000
p<-2



X<-matrix(runif(n*p), ncol=2)
W<-rbinom(n,size=1, prob= exp(-X[,2])/(1+exp(-X[,2])))

Y<-(W-0.2)*X[,1] + rnorm(n)
Y<-matrix(Y,ncol=1)

We now randomly choose a test point x and use the following code to estimate the witness function plus confidence band:


x<-matrix(runif(1*p), ncol=2)
Yall<-cbind(Y,W)
## For the current version of the Witdrf function, we need to give
## colnames to Yall
colnames(Yall) <- c("Y", "W")

## Fit the new DRF framework
drf_fit <- drfCI(X=X, Y=Yall, min.node.size = 5, splitting.rule='FourierMMD', num.features=10, B=100)

Witobj<-Witdrf(drf_fit, x=x, groupingvar="W", alpha=0.05)

hatmun <- function(y, Witobj) {

  ## Evaluate the estimated witness function and its uniform band at the points y
  c <- Witobj$c
  k_Y <- Witobj$k_Y
  Y <- Witobj$Y
  weightsall1 <- Witobj$weightsall1
  weightsall0 <- Witobj$weightsall0
  Ky <- t(kernelMatrix(k_Y, Y, y = y))

  out <- list()
  out$val <- tcrossprod(Ky, weightsall1) - tcrossprod(Ky, weightsall0)
  out$upper <- out$val + sqrt(c)
  out$lower <- out$val - sqrt(c)

  return(out)
}

all<-hatmun(sort(Witobj$Y),Witobj)

plot(sort(Witobj$Y),all$val , type="l", col="darkblue", lwd=2, ylim=c(min(all$lower), max(all$upper)),
xlab="y", ylab="witness function")
lines(sort(Witobj$Y),all$upper , type="l", col="darkgreen", lwd=2 )
lines(sort(Witobj$Y),all$lower , type="l", col="darkgreen", lwd=2 )
abline(h=0)

We can read from this plot that:

(1) The conditional density of group 1 is lower than the density of group 0 for values of y between -3 and 0.3. Moreover, this difference gets larger the larger y is until about y = -1, after which point the difference in densities starts to decrease again until the two densities are the same at around 0.3.

(2) Symmetrically, the density of group 1 is higher than the density of group 0 for values of y between 0.3 and 3 and this difference gets larger until it reaches a maximum at about y = 1.5. After this point, the difference decreases until it is almost zero again at y = 3.

(3) The difference between the two densities is statistically significant at the 95% level, as can be seen from the fact that for y approximately between -1.5 and -0.5 and between 1 and 2, the asymptotic confidence bands do not include the zero line.

Let’s check (1) and (2) against the simulated true densities. That is, we draw a large number of samples from the true data-generating process:

# Simulate truth for a large number of samples ntest
ntest<-10000
Xtest<-matrix(runif(ntest*p), ncol=2)

Y1<-(1-0.2)*Xtest[,1] + rnorm(ntest)
Y0<-(0-0.2)*Xtest[,1] + rnorm(ntest)


## Plot the test data without adjustment
plotdf = data.frame(Y=c(Y1,Y0), W=c(rep(1,ntest),rep(0,ntest) ))
plotdf$weight=1
plotdf$plotweight[plotdf$W==0] = plotdf$weight[plotdf$W==0]/sum(plotdf$weight[plotdf$W==0])
plotdf$plotweight[plotdf$W==1] = plotdf$weight[plotdf$W==1]/sum(plotdf$weight[plotdf$W==1])

plotdf$W <- factor(plotdf$W)

#plot pooled data
ggplot(plotdf, aes(Y)) +
geom_density(adjust=2.5, alpha = 0.3, show.legend=TRUE, aes(fill=W, weight=plotweight)) +
theme_light()+
scale_fill_discrete(name = "Group", labels = c('0', "1"))+
theme(legend.position = c(0.83, 0.66),
legend.text=element_text(size=18),
legend.title=element_text(size=20),
legend.background = element_rect(fill=alpha('white', 0.5)),
axis.text.x = element_text(size=14),
axis.text.y = element_text(size=14),
axis.title.x = element_text(size=19),
axis.title.y = element_text(size=19))+
labs(x='y')

This leads to:

It is a bit hard to compare visually, but we see that the two densities behave much as the witness function above predicted. In particular, the densities are about the same around 0.3, and the difference in densities appears to be largest approximately around -1 and 1.5. Thus both points (1) and (2) can be seen in the actual densities!
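As a complementary check at the test point itself: the simulation above draws X_1 at random, so the plotted densities average over X. Because the data-generating process is Gaussian, the true conditional witness function at a fixed x can also be written in closed form. Given X = x, the outcome of group j is $N((j - 0.2)x_1, 1)$, and for $k(z, y) = \exp(-\sigma(z - y)^2)$ one has $\mathbb{E}[k(Z, y)] = (1 + 2\sigma)^{-1/2} \exp(-\sigma (y - m)^2 / (1 + 2\sigma))$ when $Z \sim N(m, 1)$. In this sketch, sigma = 0.5 and x_1 = 0.7 are arbitrary illustrative values, not the median-heuristic bandwidth or the random test point used above.

```r
# Closed-form conditional witness function for the simulated process:
# given X = x, the outcome of group j is N((j - 0.2) * x_1, 1).
# Kernel: k(z, y) = exp(-sigma * (z - y)^2).

# E[k(Z, y)] for Z ~ N(m, 1)
true_embedding <- function(y, m, sigma) {
  exp(-sigma * (y - m)^2 / (1 + 2 * sigma)) / sqrt(1 + 2 * sigma)
}

# mu_1(x)(y) - mu_0(x)(y) as a function of y, for a given first coordinate x1
true_witness <- function(y, x1, sigma = 0.5) {
  true_embedding(y, (1 - 0.2) * x1, sigma) - true_embedding(y, (0 - 0.2) * x1, sigma)
}

x1 <- 0.7                                  # illustrative first coordinate of x
ygrid <- seq(-3, 3, length.out = 601)
wit_true <- true_witness(ygrid, x1)
plot(ygrid, wit_true, type = "l", xlab = "y", ylab = "true witness function")
abline(h = 0)
```

Since both groups share the same variance, the true witness function changes sign exactly at the midpoint of the two conditional means, $0.3\,x_1$, whatever the value of sigma.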

Moreover, to put (3) into context, a repeated simulation in the paper shows how the estimated witness function tends to look when there is no effect:

Simulation of 1000 witness functions in a setting similar to the one described here. In blue are the 1000 estimated witness functions, while in grey one can see the corresponding confidence bands. Taken from our paper on arXiv. There is no effect in this example, and 99% of the confidence bands contain the zero line.

A real data example in Causal Inference is given in this article.

Conclusion

In this article, I discussed the new inferential tools available for Distributional Random Forests. I also looked at an important application of these new capabilities: estimating the conditional witness function with uniform confidence bands.

However, I also want to offer a few words of warning:

  1. The results are only valid for a given test point x
  2. The results are only valid asymptotically
  3. The current code is much slower than it could be

The first point is actually not so bad: in simulations, asymptotic normality often also holds over a range of x. Just be careful with test points that are close to the boundary of your sample! Intuitively, DRF (like all other nearest-neighbor methods) needs many sample points around the test point x to estimate the response at x. So if the covariates X in your training set are standard normal, with most points between -2 and 2, then predicting at an x in [-1,1] should be no problem. But as x approaches -2 or 2, performance starts to deteriorate fast.

Random Forests (and nearest-neighbor methods in general) are not good at predicting for points that only have a few neighbors in the training set, such as points at the boundary of the support of X.
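A crude way to see this is to count how many training points fall near a test point. The helper n_neighbors below is hypothetical (not part of drf), and the box half-width h = 0.25 is an arbitrary choice; with standard normal covariates, the count collapses as x moves towards the edge of the sample.

```r
# Hypothetical diagnostic (not part of drf): number of training points within
# a box of half-width h around a test point x.
n_neighbors <- function(X, x, h = 0.25) {
  inside <- apply(abs(sweep(X, 2, x)) <= h, 1, all)  # all coordinates within h
  sum(inside)
}

set.seed(3)
Xtrain <- matrix(rnorm(2000 * 2), ncol = 2)  # standard normal covariates, n = 2000
counts <- c(center   = n_neighbors(Xtrain, c(0, 0)),
            moderate = n_neighbors(Xtrain, c(1, 1)),
            boundary = n_neighbors(Xtrain, c(2, 2)))
print(counts)
```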

The second point is also quite important. Asymptotic results have fallen somewhat out of fashion in contemporary research, in favor of finite-sample results that in turn require assumptions such as “sub-Gaussianity”. I personally find this a bit ridiculous: asymptotic results provide extremely powerful approximations in complicated settings like these. And in fact, this approximation is pretty accurate for many targets once you have more than 1000 or 2000 data points (maybe you have 92% coverage instead of 95% for your conditional mean/quantile). However, the witness function we introduced is a complicated object, so the more data points you have to estimate the uncertainty bands around it, the better!
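Coverage statements like the one above are checked by repeating a simulation many times and counting how often the interval contains the truth. A toy sketch for the simplest case, a plain sample mean with the same half-sampling normal interval used in the examples (not DRF itself; n = 500, B = 200 and 300 repetitions are arbitrary choices, and half_sample_ci is a hypothetical name):

```r
# Toy coverage check (plain sample mean, not DRF): how often does the
# half-sampling normal interval cover the true mean?
set.seed(11)
half_sample_ci <- function(y, B = 200, alpha = 0.05) {
  n <- length(y)
  est <- mean(y)
  # B half-sample estimates, each on floor(n/2) points drawn without replacement
  draws <- replicate(B, mean(y[sample.int(n, floor(n / 2))]))
  se <- sqrt(var(draws - est))
  est + c(-1, 1) * qnorm(1 - alpha / 2) * se
}

true_mean <- 1      # mean of the Exp(1) distribution
n_rep <- 300        # number of simulated data sets
covered <- replicate(n_rep, {
  ci <- half_sample_ci(rexp(500))
  ci[1] <= true_mean && true_mean <= ci[2]
})
coverage <- mean(covered)
print(coverage)
```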

Finally, point three is just a shortcoming on our side: while DRF itself is efficiently written in C++, the subsampling over S used to estimate the uncertainty is implemented entirely in R for the moment. Fixing this would provide a large speed-up to the code. We hope to be able to fix this in the future.

Citations

[1] Junhyung Park, Uri Shalit, Bernhard Schölkopf, and Krikamol Muandet. “Conditional distributional treatment effect with kernel conditional mean embeddings and U-statistic regression.” In Proceedings of the 38th International Conference on Machine Learning (ICML), volume 139, pages 8401–8412. PMLR, July 2021.


I am a researcher with a PhD in statistics and always happy to study and share research, data-science skills, deep math and life-changing views.