randomforest in rpart — part 2: a simple decision tree
So I have something to compare my rpart random forest to, I am going to build a simple decision tree using the rpart package. After all, if I can’t beat a single rpart tree then all my work was pointless!
library(rpart)
set.seed(20) # predictable randomness
# split the data into training and test data (arbitrarily!)
training_data <- data[1:52000,]
test_data <- data[52001:nrow(data),]
# set the target variable
targ <- "is.premium"
# set the predictors
preds <- c("carat", "depth", "table", "price", "x", "y", "z", "clarity", "color")
# build a simple rpart decision tree using the default settings
# (fit on the training data only, so the test data stays unseen)
dtree <- rpart(formula = reformulate(preds, response = targ), data = training_data)
Using the rpart.plot() function from the rpart.plot package, we can visualise the tree.
It looks like the tree only uses two of our predictors - table, and depth.
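If you would rather not read the splits off the plot, one way to check which predictors a fitted tree actually uses is to look at the frame component of the rpart object. This is only a minimal sketch: it uses the kyphosis data that ships with rpart as a stand-in, not the diamonds data from this post, and the names fit, split_vars and used_vars are my own.

```r
library(rpart)

# kyphosis ships with rpart; it is a stand-in here, not the diamonds data
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# each row of fit$frame is a node; `var` holds the split variable,
# or "<leaf>" for terminal nodes
split_vars <- as.character(fit$frame$var)
used_vars <- setdiff(unique(split_vars), "<leaf>")
print(used_vars)

# how many times each predictor is used for a split
print(table(split_vars[split_vars != "<leaf>"]))
```

Running the same two lines against dtree should list only the predictors that appear in the plot.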
To give you an idea of how tuneable rpart is, the following code adjusts one control parameter: the complexity parameter (cp). By default rpart sets it to 0.01, so I’ve set it to 0.005. I’ve always understood cp as telling the tree algorithm not to keep a split unless it decreases the overall relative error by at least the value cp is set to. So, the lower the cp, the bigger the tree. A downside of setting a tiny cp is the compute cost of building a bigger tree. In some ways, cp could be looked on as a form of pre-pruning.
dtree.cp <- rpart(formula = reformulate(preds, response = targ), data = training_data, control = rpart.control(cp = 0.005))
The first four splits are the same as in dtree: they are the ‘best’ based on the criteria rpart is using. After this, the tree splits on table again and brings in two more predictors, x and y.
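The link between cp and tree size is easy to check on toy data. The sketch below is an illustration under my own assumptions: the data frame sim is simulated (it is not the diamonds data), and count_leaves() is a hypothetical helper, not part of rpart. It fits one tree with the default cp and one with a much smaller cp, then compares the number of leaves.

```r
library(rpart)

set.seed(20)
# simulated stand-in data (an assumption, not the diamonds data)
n <- 2000
sim <- data.frame(x1 = runif(n), x2 = runif(n), x3 = runif(n))
sim$y <- as.numeric(sim$x1 + 0.5 * sim$x2 + rnorm(n, sd = 0.3) > 0.8)

# hypothetical helper: count the terminal nodes of a fitted rpart tree
count_leaves <- function(tree) sum(tree$frame$var == "<leaf>")

tree_default  <- rpart(y ~ ., data = sim)  # default cp = 0.01
tree_small_cp <- rpart(y ~ ., data = sim, control = rpart.control(cp = 0.001))

leaves_default  <- count_leaves(tree_default)
leaves_small_cp <- count_leaves(tree_small_cp)
print(c(default = leaves_default, small_cp = leaves_small_cp))

# the cp table shows the number of splits and relative error at each cp value
printcp(tree_small_cp)
```

The smaller cp should never give a smaller tree than the default, and printcp() is a handy way to see where lowering cp stops buying much improvement.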
You can look at many things to assess a model’s performance, but the most important thing is to have something to measure performance against. Otherwise it’s pretty hard to know if you have built a useful model. The benchmark could be an existing model, some metric set by a client, or industry standards.
To assess the performance of dtree I’m going to run the model over the test data and do two fairly common things:
- Plot the percentage of the caseload assessed against the percentage of premiums classified
- Create a confusion matrix
First, run the model with the test data:
dtree_preds <- predict(dtree, test_data)
For the first plot, I’m using some hacked-together dplyr and ggplot (for learning) to sort the data by score and then cumulatively calculate the percentages I need for plotting. There are a few packages that do this for you, e.g. riskchart() in Rattle.
library(dplyr)
library(ggplot2)
# stick the actual values to the predictions, naming the columns as we go
# (this assumes is.premium is a numeric 0/1 column)
outcomes <- data.frame(actual = test_data$is.premium,
                       predicted_r = ifelse(dtree_preds >= 0.5, 1, 0),
                       predicted = dtree_preds)
# order the dataframe by predicted score, highest first
outcomes <- outcomes %>% arrange(desc(predicted)) %>%
  # number of rows so far / number of rows overall gives the percentage of caseload
  mutate(percent_caseload = row_number() / n()) %>%
  # what percentage of the targets have we captured so far
  mutate(percent_targets = cumsum(actual) / sum(actual))
# take a peek at the dataframe
head(outcomes)
# plot the line we have created
ggplot(data = outcomes, aes(percent_caseload, percent_targets)) + geom_line(colour = 'darkgreen')
In the scope of the example I’m playing with, this really doesn’t mean much: there are only three unique scores (0.6.., 0.05…, and 0.00), so the arbitrary ordering within these groups influences how the chart displays. That being said, it isn’t a terrible model. The curve is relatively steep at the beginning and peters out as you look at more of the caseload.
In the real world, if you had X analysts who could assess Y% of overall cases, you could use this to show which model of a set of models performs best at Y. In this case, if we could only look at 25% of the caseload, we would capture around 60% of the targets (the premium diamonds).
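Reading a single point off the curve, as in the 25% example, can also be done directly. Here is a minimal base R sketch on made-up scores and outcomes; targets_captured() is a hypothetical helper I wrote for illustration, not a function from any package.

```r
# toy outcomes and scores, made up for illustration
actual <- c(1, 1, 0, 1, 0, 0, 1, 0, 0, 0)
score  <- c(0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05)

# hypothetical helper: share of all targets found in the top `fraction` of cases
targets_captured <- function(actual, score, fraction) {
  # sort the outcomes by score, highest score first
  ranked <- actual[order(score, decreasing = TRUE)]
  # how many cases we can afford to assess
  n_assessed <- round(length(ranked) * fraction)
  sum(ranked[seq_len(n_assessed)]) / sum(ranked)
}

captured_30 <- targets_captured(actual, score, 0.30)  # 2 of 4 targets: 0.5
captured_50 <- targets_captured(actual, score, 0.50)  # 3 of 4 targets: 0.75
print(c(top_30 = captured_30, top_50 = captured_50))
```

With several candidate models you could call the helper once per model at the same fraction and compare the results. Note that tied scores are ordered arbitrarily here too, so the caveat about the three unique scores still applies.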
To create the confusion matrix I’m using the confusionMatrix() function from the caret package. To generate the confusion matrix we have to convert dtree_preds to 0/1 so it matches the actual outcomes.
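library(caret) # confusionMatrix() expects factors with matching levels
confusionMatrix(data = factor(ifelse(dtree_preds >= 0.5, 1, 0), levels = c(0, 1)), reference = factor(test_data$is.premium, levels = c(0, 1)), positive = "1")
It looks like dtree does a decent job of locating diamonds that are premium (true positives) but misclassifies a number of non-premium diamonds as premium (false positives).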
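To make the true positive and false positive counts concrete, the same kind of table can be built by hand with base R. This is a sketch on made-up 0/1 vectors, not the output from dtree, and the variable names are my own.

```r
# toy actual outcomes and 0/1 predictions, made up for illustration
actual    <- c(1, 1, 1, 0, 0, 0, 0, 1, 0, 0)
predicted <- c(1, 1, 0, 1, 0, 0, 1, 1, 0, 0)

# rows are predictions, columns are actual outcomes
cm <- table(predicted = factor(predicted, levels = c(0, 1)),
            actual    = factor(actual, levels = c(0, 1)))
print(cm)

tp <- cm["1", "1"]  # predicted premium, actually premium
fp <- cm["1", "0"]  # predicted premium, actually not premium
fn <- cm["0", "1"]  # predicted not premium, actually premium
tn <- cm["0", "0"]  # predicted not premium, actually not premium

sensitivity <- tp / (tp + fn)  # share of actual premiums we found: 0.75
precision   <- tp / (tp + fp)  # share of predicted premiums that are right: 0.6
print(c(sensitivity = sensitivity, precision = precision))
```

A model like the one described above, which finds most premiums but also flags a number of non-premiums, would show up as high sensitivity with lower precision.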
And this is the end of part two. Part three will involve using a for loop to call rpart() multiple times, building a list of trees (a forest).