33.4 Tuning decision trees

In the previous section, we fit a model with the default settings. Can we improve performance by changing these? Let’s find out! But first, what might we change?

33.4.1 Decision tree hyperparameters

Decision trees have three hyperparameters, shown below. These are standard hyperparameters and are implemented in {rpart}, the engine we used for fitting decision tree models in the previous section. Alternative implementations may have slightly different hyperparameters (see the documentation for parsnip::decision_tree() for details on other engines).

| Hyperparameter  | Function            | Description |
|-----------------|---------------------|-------------|
| Cost complexity | `cost_complexity()` | A regularization term that introduces a penalty to the objective function and controls the amount of pruning. |
| Tree depth      | `tree_depth()`      | The maximum depth to which the tree should be grown. |
| Minimum \(n\)   | `min_n()`           | The minimum number of observations that must be present in a terminal node. |

Perhaps the most important hyperparameter is the cost complexity parameter, which is a regularization parameter that penalizes the objective function by model complexity. In other words, the deeper the tree, the higher the penalty. The cost complexity parameter is typically denoted \(\alpha\), and penalizes the sum of squared errors by

\[ SSE + \alpha |T| \]

where \(|T|\) is the number of terminal nodes. Any value can be used for \(\alpha\), but typical values are less than 0.1. Cost complexity helps control model complexity through a process called pruning, in which a decision tree is first grown very deep and then pruned back to a smaller subtree. The tree is initially grown just like any standard decision tree, but is then pruned back to the subtree that optimizes the penalized objective function above. Different values of \(\alpha\) will, of course, lead to different subtrees. The best values are typically determined via grid search with cross-validation. Larger cost complexity values result in smaller trees, while smaller values result in more complex trees.
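As a rough illustration of pruning, here is a sketch using the {rpart} package directly (with the built-in iris data rather than our modeling data, so the object names here are illustrative only):

```r
library(rpart)

# Grow a deliberately deep tree by setting cp (rpart's cost
# complexity parameter) to 0, i.e., no penalty while growing
deep_tree <- rpart(Species ~ ., data = iris,
                   control = rpart.control(cp = 0, minsplit = 2))

# The CP table lists candidate subtrees and the cost complexity
# value at which each becomes the optimal subtree
printcp(deep_tree)

# Prune back to the subtree implied by a larger penalty;
# larger cp values produce smaller trees
pruned_tree <- prune(deep_tree, cp = 0.1)
```

In practice we rarely call {rpart} directly like this; the point is just that the full tree is grown first, and the penalty then selects a subtree after the fact.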

Note that, unlike penalized regression, the cost complexity penalty depends only on the size of the tree (the number of terminal nodes), not on the scale of the features, so decision trees do not require that features be normalized before fitting.

The tree depth and minimum \(n\) are more straightforward methods of controlling model complexity. The tree depth is simply the maximum depth to which a tree can be grown (the maximum number of splits). The minimum \(n\) controls the splitting criterion: a node cannot be split further once the \(n\) within that node falls below the specified minimum.
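Putting the three hyperparameters together, a {parsnip} model specification that marks all of them for tuning might look like the following sketch (the object name dt_spec is illustrative):

```r
library(tidymodels)

dt_spec <- decision_tree(
  cost_complexity = tune(),  # amount of pruning
  tree_depth = tune(),       # maximum depth of the tree
  min_n = tune()             # minimum n required to split a node
) %>%
  set_engine("rpart") %>%
  set_mode("classification")
```

Each tune() placeholder is then filled in during the grid search.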

33.4.3 Finalizing our model fit

Generally, before moving to our final fit, we'd want to do a bit more work with the model to make sure we were confident it was really the best model we could produce. I'd be particularly interested in looking at the minimum \(n\) around the 0.001 cost complexity parameter (given that the overall optimum in our original grid search had this value with a minimum \(n\) of 11). But for illustration purposes, let's assume we're ready to go (and really, decision trees don't have a lot more tuning we can do, at least with the rpart engine).

First, let's finalize our model using the best min_n we found from our grid search. We'll use finalize_model() along with select_best() (rather than show_best()) to set the final model parameters.

best_params <- select_best(dt_tune_fit2, metric = "roc_auc")
final_mod <- finalize_model(dt_tune2, best_params)
final_mod
## Decision Tree Model Specification (classification)
## Main Arguments:
##   cost_complexity = 1e-10
##   min_n = 37
## Computational engine: rpart

Note that the min_n is now set. If we had done any tuning with our recipe we could follow a similar process.

Next, we're going to use our original initial_split() object to, with a single function, fit our model to the full training data (rather than by fold), make predictions on the test set, and evaluate the performance of the model. We do this all through the last_fit() function.

dt_finalized <- last_fit(final_mod,
                         preprocessor = rec,
                         split = splt)
## # Resampling results
## # Monte Carlo cross-validation (0.75/0.25) with 1 resamples  
## # A tibble: 1 x 6
##   splits        id          .metrics      .notes      .predictions     .workflow
##   <list>        <chr>       <list>        <list>      <list>           <list>   
## 1 <split [65K/… train/test… <tibble [2 ×… <tibble [0… <tibble [21,682… <workflo…

The output doesn't look terrifically helpful at first glance, but it is. It's basically everything we need. For example, let's look at our metrics.

dt_finalized$.metrics[[1]]
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy multiclass     0.543
## 2 roc_auc  hand_till      0.735

Unsurprisingly, our AUC is a bit lower for the test set. What if we want our predictions?

predictions <- dt_finalized$.predictions[[1]]
## # A tibble: 21,682 x 7
##    .pred_0 .pred_1 .pred_2 .pred_3  .row .pred_class accuracy_group
##      <dbl>   <dbl>   <dbl>   <dbl> <int> <fct>       <ord>         
##  1  0.686   0.314  0        0          1 0           0             
##  2  0.706   0.167  0.0816   0.0461     6 0           0             
##  3  0.154   0.0769 0.462    0.308      7 2           3             
##  4  0.0143  0.0571 0.1      0.829      9 3           3             
##  5  0.12    0.09   0.06     0.73      11 3           3             
##  6  0       0.737  0.175    0.0877    18 1           1             
##  7  0.725   0.266  0.00917  0         28 0           0             
##  8  0.511   0.216  0.144    0.129     32 0           2             
##  9  0       1      0        0         36 1           1             
## 10  0.286   0.476  0        0.238     47 1           1             
## # … with 21,672 more rows

This shows us the predicted probability that each case would be in each class, along with a “hard” prediction into a class, and their observed class (accuracy_group). We can use this for further visualizations and to better understand how our model makes predictions and where it is wrong. For example, let’s look at a quick heat map of the predicted class versus the observed.

counts <- predictions %>% 
  count(.pred_class, accuracy_group) %>% 
  drop_na() %>% 
  group_by(accuracy_group) %>% 
  mutate(prop = n/sum(n)) 

ggplot(counts, aes(.pred_class, accuracy_group)) +
  geom_tile(aes(fill = prop)) +
  geom_label(aes(label = round(prop, 2))) +
  colorspace::scale_fill_continuous_diverging(
    palette = "Blue-Red2",
    mid = .25,
    rev = TRUE)

Notice that I've omitted NAs here, which is less than ideal, because we have a lot of them. This is mostly because the original data have so much missing data on the outcome, so it's hard to know how well we're actually doing with those cases. Instead, we're evaluating our model only on the cases for which we have an observed outcome. The plot above shows proportions by row; in other words, each row sums to 1.0.

We can quickly see that our model has some significant issues. We do okay predicting Classes 0 and 3 (about 75% correct in each case), but we're not much better than random chance (which would be roughly 0.25 in each cell) when predicting Classes 1 and 2. It's fairly concerning that 32% of cases that were actually Class 2 were predicted to be Class 0. We would likely want to conduct a post-mortem on these cases to see if we could understand why our model fails in this particular direction.
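As a side note, the same cross-tabulation can be produced as a formal confusion matrix with {yardstick}, which also has a built-in heat map plot. A sketch, using the predictions object extracted above:

```r
library(tidyverse)
library(yardstick)

# Drop cases with a missing observed outcome, then cross-tabulate
# predicted versus observed classes
cm <- predictions %>%
  drop_na(accuracy_group) %>%
  conf_mat(truth = accuracy_group, estimate = .pred_class)

# Quick heat map of the confusion matrix
autoplot(cm, type = "heatmap")
```

Note that conf_mat() displays raw counts rather than the row proportions we computed by hand above.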

Decision trees are generally easy to interpret and easy to communicate to stakeholders. They also make no assumptions about the data and can be applied in a large number of situations. Unfortunately, they often overfit the data, leading to poor generalization to unseen data. In the next chapter, we'll build on decision trees to discuss ensemble methods, where we use multiple trees to make a single prediction.