## 33.4 Tuning decision trees

In the previous section, we fit a model with the default settings. Can we improve performance by changing these? Let’s find out! But first, what might we change?

### 33.4.1 Decision tree hyperparameters

Decision trees have three hyperparameters, as shown below. These are standard hyperparameters and are implemented in {rpart}, the engine we used for fitting decision tree models in the previous section. Alternative implementations may have slightly different hyperparameters (see the documentation for `parsnip::decision_tree()` for details on other engines).

Hyperparameter | Function | Description
---|---|---
Cost complexity | `cost_complexity()` | A regularization term that introduces a penalty to the objective function and controls the amount of pruning.
Tree depth | `tree_depth()` | The maximum depth to which the tree should be grown.
Minimum \(n\) | `min_n()` | The minimum number of observations that must be present in a terminal node.

Perhaps the most important hyperparameter is the cost complexity parameter, which is a regularization parameter that penalizes the objective function by model complexity. In other words, the deeper the tree, the higher the penalty. The cost complexity parameter is typically denoted \(\alpha\), and penalizes the sum of squared errors by

\[ SSE + \alpha |T| \]

where \(|T|\) is the number of terminal nodes. Any non-negative value can be used for \(\alpha\), but typical values are less than 0.1. The cost complexity parameter helps control model complexity through a process called **pruning**, in which a decision tree is first grown very deep and then *pruned* back to a smaller subtree. The tree is initially grown just like any standard decision tree, but it is then cut back to the subtree that optimizes the penalized objective function above. Different values of \(\alpha\) will, of course, lead to different subtrees. The best values are typically determined via grid search with cross-validation. Larger cost complexity values will result in smaller trees, while smaller values will result in more complex trees.

Note that, similar to penalized regression, if you are using cost complexity to prune a tree it is important that all features are placed on the same scale (normalized) so the scale of the feature doesn’t influence the penalty.

The tree depth and minimum \(n\) are more straightforward methods of controlling model complexity. The tree depth is simply the maximum depth to which a tree can be grown (the maximum number of splits). The minimum \(n\) controls the splitting criteria: a node cannot be split further once the \(n\) within that node falls below the specified minimum.

### 33.4.2 Conducting the grid search

Let’s tune our model using `cost_complexity()` and `min_n()`, and let the depth be controlled by these parameters.

First, we’ll modify our model from before to set the parameters to tune.
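A sketch of what this might look like, assuming a decision tree specification like the one from the previous section (recreated here from scratch; the `tune()` placeholders mark the arguments for the grid search):

```
dt_tune <- decision_tree() %>%
  set_mode("classification") %>%
  set_engine("rpart") %>%
  set_args(
    cost_complexity = tune(),
    min_n = tune()
  )
```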

Next, we’ll set up our grid. We can use helper functions from the {dials} package (part of *tidymodels*) to help us come up with reasonable values. Let’s use a regular grid with 10 possible values for cost complexity and 5 possible values for the minimum \(n\).
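With {dials}, a regular grid along these lines would do it (the exact call here is one reasonable choice, not necessarily what was run originally):

```
grid <- grid_regular(
  cost_complexity(),
  min_n(),
  levels = c(cost_complexity = 10, min_n = 5)
)
```

The `cost_complexity()` and `min_n()` helpers supply default ranges, and `levels` controls how many values are drawn from each.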

Let’s see what the space we’re evaluating actually looks like for our hyperparameters.
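One way to visualize a grid like this (the log scale is useful because the default `cost_complexity()` range spans many orders of magnitude):

```
ggplot(grid, aes(cost_complexity, min_n)) +
  geom_point() +
  scale_x_log10()
```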

As we can see, there’s a big gap in cost complexity, so we’ll want to be careful when we investigate the optimal values there.

Now let’s actually conduct the search. I have again included timing so you can see how long this takes; on my local computer, the below took about 20 minutes to run.
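A sketch of the call, assuming the recipe (`rec`) and cross-validation folds (`cv`) from the previous section; wrapping it in base R’s `system.time()` is just one way to capture the timing:

```
tune_time <- system.time({
  dt_tune_fit <- tune_grid(
    dt_tune,
    preprocessor = rec,
    resamples = cv,
    grid = grid
  )
})
```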

First let’s look at our results by hyperparameter. We can use `collect_metrics()` to get a summary (the mean) of the metrics we set across folds. We didn’t set any, so we’ll get the defaults, which in this case are roc_auc and accuracy. You can optionally get the results by fold (no summarizing) by specifying `summarize = FALSE`.
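Assuming the tuning results are stored in `dt_tune_fit`, that looks like:

```
dt_tune_metrics <- collect_metrics(dt_tune_fit)
dt_tune_metrics
```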

```
## # A tibble: 100 x 8
## cost_complexity min_n .metric .estimator mean n std_err .config
## <dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.0000000001 2 accuracy multiclass 0.466 10 0.00196 Model01
## 2 0.0000000001 2 roc_auc hand_till 0.607 10 0.00126 Model01
## 3 0.000000001 2 accuracy multiclass 0.466 10 0.00196 Model02
## 4 0.000000001 2 roc_auc hand_till 0.607 10 0.00126 Model02
## 5 0.00000001 2 accuracy multiclass 0.466 10 0.00196 Model03
## 6 0.00000001 2 roc_auc hand_till 0.607 10 0.00126 Model03
## 7 0.0000001 2 accuracy multiclass 0.466 10 0.00196 Model04
## 8 0.0000001 2 roc_auc hand_till 0.607 10 0.00126 Model04
## 9 0.000001 2 accuracy multiclass 0.466 10 0.00196 Model05
## 10 0.000001 2 roc_auc hand_till 0.607 10 0.00126 Model05
## # … with 90 more rows
```

To get an idea of how things are performing, let’s look at our two hyperparameters with `roc_auc` as our metric. This is a metric we want to maximize, with a value of 1.0 indicating perfect predictions.

```
dt_tune_metrics %>%
  filter(.metric == "roc_auc") %>%
  ggplot(aes(cost_complexity, mean)) +
  geom_point() +
  facet_wrap(~min_n)
```

So generally it’s looking like lower values of cost complexity are leading to better-performing models, and the minimum sample size for a terminal node is looking best at the larger values in our grid. Let’s look at our best models in a more tabular form. This time we’ll use `show_best()` to show our best hyperparameter combinations.
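Assuming the same `dt_tune_fit` object as above:

```
show_best(dt_tune_fit, metric = "roc_auc")
```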

```
## # A tibble: 5 x 8
## cost_complexity min_n .metric .estimator mean n std_err .config
## <dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.0001 40 roc_auc hand_till 0.741 10 0.00167 Model47
## 2 0.0001 2 roc_auc hand_till 0.740 10 0.00185 Model07
## 3 0.0001 30 roc_auc hand_till 0.740 10 0.00164 Model37
## 4 0.0001 11 roc_auc hand_till 0.740 10 0.00186 Model17
## 5 0.0001 21 roc_auc hand_till 0.740 10 0.00169 Model27
```

And we see a somewhat different picture here. Our best model has a minimum \(n\) of 40 with a relatively higher cost complexity. But the amount this model is “better” is trivial and could easily be due to sampling variability. All of the best models share the same cost complexity, with the minimum \(n\) playing essentially no role at that value. This may lead us to consider *not* pruning by the cost complexity parameter at all. Let’s use all the results again to look at the minimum \(n\) a little closer. We’ll filter for a very low cost complexity.

```
dt_tune_metrics %>%
  filter(.metric == "roc_auc" &
           cost_complexity == 0.0000000001) %>%
  arrange(desc(mean))
```

```
## # A tibble: 5 x 8
## cost_complexity min_n .metric .estimator mean n std_err .config
## <dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.0000000001 40 roc_auc hand_till 0.734 10 0.00181 Model41
## 2 0.0000000001 30 roc_auc hand_till 0.727 10 0.00168 Model31
## 3 0.0000000001 21 roc_auc hand_till 0.717 10 0.00105 Model21
## 4 0.0000000001 11 roc_auc hand_till 0.688 10 0.00132 Model11
## 5 0.0000000001 2 roc_auc hand_till 0.607 10 0.00126 Model01
```

Unsurprisingly, the larger minimum \(n\) values are looking best. To be sure we’ve got this right, though, let’s fix the cost complexity and tune on *just* the minimum \(n\), searching around these values to see if there’s any more room for optimization.

This shouldn’t take quite as long as our previous grid search, but on my local computer it still took about five and a half minutes.

```
grid_min_n <- tibble(min_n = 23:37)

dt_tune2 <- dt_tune %>%
  set_args(cost_complexity = 0.0000000001)

dt_tune_fit2 <- tune_grid(
  dt_tune2,
  preprocessor = rec,
  resamples = cv,
  grid = grid_min_n
)
```

Let’s look at our best metrics now and see if we’ve made any improvements.
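Again with `show_best()`:

```
show_best(dt_tune_fit2, metric = "roc_auc")
```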

```
## # A tibble: 5 x 7
## min_n .metric .estimator mean n std_err .config
## <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 37 roc_auc hand_till 0.732 10 0.00164 Model15
## 2 36 roc_auc hand_till 0.732 10 0.00172 Model14
## 3 35 roc_auc hand_till 0.731 10 0.00162 Model13
## 4 34 roc_auc hand_till 0.730 10 0.00170 Model12
## 5 33 roc_auc hand_till 0.730 10 0.00166 Model11
```

And look at that! It’s a (very) marginal improvement, but we have optimized our model a bit more.

### 33.4.3 Finalizing our model fit

Generally, before moving to our final fit, we’d probably want to do a bit more work with the model to make sure we were confident it was really the best model we could produce. I’d be particularly interested in looking at the minimum \(n\) around the 0.0001 cost complexity parameter (given that the overall optimum in our original grid search had this value). But for illustration purposes, let’s assume we’re ready to go (and really, decision trees don’t have a lot more tuning we can do with them, at least using the {rpart} engine).

First, let’s finalize our model using the best `min_n` we found from our grid search. We’ll use `finalize_model()` along with `select_best()` (rather than `show_best()`) to set the final model parameters.

```
best_params <- select_best(dt_tune_fit2, metric = "roc_auc")
final_mod <- finalize_model(dt_tune2, best_params)
final_mod
```

```
## Decision Tree Model Specification (classification)
##
## Main Arguments:
## cost_complexity = 1e-10
## min_n = 37
##
## Computational engine: rpart
```

Note that `min_n` is now set. If we had done any tuning with our recipe, we could follow a similar process to finalize it.

Next, we’re going to use our original `initial_split()` object to, with a single function, fit our model to our full training data (rather than by fold), make predictions on the test set, and evaluate the performance of the model. We do all of this through the `last_fit()` function.
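A sketch of the call, assuming the initial split object was named `splt` (that name is hypothetical; substitute whatever your `initial_split()` object is called):

```
final_fit <- last_fit(
  final_mod,
  preprocessor = rec,
  split = splt
)
final_fit
```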

```
## # Resampling results
## # Monte Carlo cross-validation (0.75/0.25) with 1 resamples
## # A tibble: 1 x 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [65K/… train/test… <tibble [2 ×… <tibble [0… <tibble [21,682… <workflo…
```

The output we get doesn’t look terrifically helpful, but it is. It contains basically everything we need. For example, let’s look at our metrics.
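We can pull these with `collect_metrics()`:

```
collect_metrics(final_fit)
```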

```
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.543
## 2 roc_auc hand_till 0.735
```

Unsurprisingly, our AUC is a bit lower for our test set. What if we want our predictions?
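Those are available via `collect_predictions()`; storing them as `predictions` matches how they’re used later:

```
predictions <- collect_predictions(final_fit)
predictions
```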

```
## # A tibble: 21,682 x 7
## .pred_0 .pred_1 .pred_2 .pred_3 .row .pred_class accuracy_group
## <dbl> <dbl> <dbl> <dbl> <int> <fct> <ord>
## 1 0.686 0.314 0 0 1 0 0
## 2 0.706 0.167 0.0816 0.0461 6 0 0
## 3 0.154 0.0769 0.462 0.308 7 2 3
## 4 0.0143 0.0571 0.1 0.829 9 3 3
## 5 0.12 0.09 0.06 0.73 11 3 3
## 6 0 0.737 0.175 0.0877 18 1 1
## 7 0.725 0.266 0.00917 0 28 0 0
## 8 0.511 0.216 0.144 0.129 32 0 2
## 9 0 1 0 0 36 1 1
## 10 0.286 0.476 0 0.238 47 1 1
## # … with 21,672 more rows
```

This shows us the predicted probability that each case would be in each class, along with a “hard” prediction into a class, and their observed class (`accuracy_group`

). We can use this for further visualizations and to better understand how our model makes predictions and where it is wrong. For example, let’s look at a quick heat map of the predicted class versus the observed.

```
counts <- predictions %>%
  count(.pred_class, accuracy_group) %>%
  drop_na() %>%
  group_by(accuracy_group) %>%
  mutate(prop = n / sum(n))

ggplot(counts, aes(.pred_class, accuracy_group)) +
  geom_tile(aes(fill = prop)) +
  geom_label(aes(label = round(prop, 2))) +
  colorspace::scale_fill_continuous_diverging(
    palette = "Blue-Red2",
    mid = .25,
    rev = TRUE
  )
```

Notice that I’ve omitted NAs here, which is less than ideal, because we have a lot of them. This is mostly because the original data have so much missing data on the outcome, so it’s hard to know how well we’re actually doing with those cases. Instead, we’re just evaluating our model on the cases for which we actually have an observed outcome. The plot above shows the proportion *by row*; in other words, each row sums to 1.0.

We can fairly quickly see that our model has some fairly significant issues. We are doing okay predicting Classes 0 and 3 (about 75% correct in each case), but we’re not a whole lot better than random chance level (which would be about 0.25 in each cell) when predicting Classes 1 and 2. It’s fairly concerning that 32% of cases that were actually Class 2 were predicted to be Class 0. We would likely want to conduct a post-mortem with these cases to see if we could understand why our model was failing in this particular direction.

Decision trees, generally, are easily interpretable and easy to communicate with stakeholders. They also make no assumptions about the data, and can be applied in a large number of situations. Unfortunately, they often suffer from rapid overfitting to the data leading to poor generalizations to unseen data. In the next chapter, we’ll build on decision trees to talk about *ensemble* methods, where we use multiple trees to make a single prediction.