Machine Learning Workflow: Regression Trees
Biostat 200C

Load useful packages

library(tidyverse)
library(tidymodels)
library(broom)
library(gt)
library(patchwork)
library(tictoc)
library(ISLR2)
library(janitor)
# Load dunnr R package and set the ggplot theme
library(dunnr)
# extrafont::loadfonts(device = "win", quiet = TRUE)
# theme_set(theme_td_minimal())
# set_geom_fonts()
# set_palette()

Overview
We illustrate the typical machine learning workflow for regression trees using the Hitters data set from the R ISLR2 package. The steps are (a compact tidymodels sketch of the same steps follows this list):
1. Initial splitting into test and non-test sets.
2. Pre-processing of data: not much is needed for regression trees.
3. Tune the cost complexity pruning hyper-parameter(s) using 6-fold cross-validation (CV) on the non-test data.
4. Choose the best model by CV and refit it on the whole non-test data.
5. Final prediction on the test data.
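For reference, here is a hedged sketch of what these steps could look like with the tidymodels packages loaded above, using the rpart engine. The object names, seed, and tuning grid are illustrative, and the sketch assumes the cleaned hitters data frame created in the next section; the rest of this document instead works through the steps by hand with the tree package.

# Sketch only: a tidymodels version of the workflow steps above (not run here)
set.seed(200)
hitters_split_tm <- initial_split(hitters, prop = 0.5)        # step 1: initial split
hitters_folds_tm <- vfold_cv(training(hitters_split_tm), v = 6)

tree_spec <- decision_tree(cost_complexity = tune()) %>%      # step 3: tune cost complexity
  set_engine("rpart") %>%
  set_mode("regression")

tree_wf <- workflow() %>%
  add_formula(salary ~ .) %>%                                 # step 2: minimal pre-processing
  add_model(tree_spec)

tree_res <- tune_grid(
  tree_wf,
  resamples = hitters_folds_tm,
  grid = grid_regular(cost_complexity(), levels = 10)
)

# step 4: pick the CV-best penalty and refit on the whole non-test set;
# step 5: evaluate once on the test set
final_wf <- finalize_workflow(tree_wf, select_best(tree_res, metric = "rmse"))
last_fit(final_wf, hitters_split_tm) %>% collect_metrics()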
Hitters data
Documentation for the Hitters data is available in the ISLR2 package (see ?Hitters). The goal is to predict the log(Salary) (at the opening of the 1987 season) of MLB players from their performance metrics in the 1986-87 season.
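Before cleaning, it can be worth confirming how many salaries are missing in the raw data (ISLR reports 59 of the 322 players):

# Count players with a missing 1987 salary in the raw data
sum(is.na(ISLR2::Hitters$Salary))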
hitters <- ISLR2::Hitters %>% janitor::clean_names()

# As per the text, we remove missing `salary` values and log-transform it
hitters <- hitters %>%
  filter(!is.na(salary)) %>%
  mutate(salary = log(salary))
glimpse(hitters)
Rows: 263
Columns: 20
$ at_bat <int> 315, 479, 496, 321, 594, 185, 298, 323, 401, 574, 202, 418,…
$ hits <int> 81, 130, 141, 87, 169, 37, 73, 81, 92, 159, 53, 113, 60, 43…
$ hm_run <int> 7, 18, 20, 10, 4, 1, 0, 6, 17, 21, 4, 13, 0, 7, 20, 2, 8, 1…
$ runs <int> 24, 66, 65, 39, 74, 23, 24, 26, 49, 107, 31, 48, 30, 29, 89…
$ rbi <int> 38, 72, 78, 42, 51, 8, 24, 32, 66, 75, 26, 61, 11, 27, 75, …
$ walks <int> 39, 76, 37, 30, 35, 21, 7, 8, 65, 59, 27, 47, 22, 30, 73, 1…
$ years <int> 14, 3, 11, 2, 11, 2, 3, 2, 13, 10, 9, 4, 6, 13, 15, 5, 8, 1…
$ c_at_bat <int> 3449, 1624, 5628, 396, 4408, 214, 509, 341, 5206, 4631, 187…
$ c_hits <int> 835, 457, 1575, 101, 1133, 42, 108, 86, 1332, 1300, 467, 39…
$ c_hm_run <int> 69, 63, 225, 12, 19, 1, 0, 6, 253, 90, 15, 41, 4, 36, 177, …
$ c_runs <int> 321, 224, 828, 48, 501, 30, 41, 32, 784, 702, 192, 205, 309…
$ crbi <int> 414, 266, 838, 46, 336, 9, 37, 34, 890, 504, 186, 204, 103,…
$ c_walks <int> 375, 263, 354, 33, 194, 24, 12, 8, 866, 488, 161, 203, 207,…
$ league <fct> N, A, N, N, A, N, A, N, A, A, N, N, A, N, N, A, N, N, A, N,…
$ division <fct> W, W, E, E, W, E, W, W, E, E, W, E, E, E, W, W, W, E, W, W,…
$ put_outs <int> 632, 880, 200, 805, 282, 76, 121, 143, 0, 238, 304, 211, 12…
$ assists <int> 43, 82, 11, 40, 421, 127, 283, 290, 0, 445, 45, 11, 151, 45…
$ errors <int> 10, 14, 3, 4, 25, 7, 9, 19, 0, 22, 11, 7, 6, 8, 10, 16, 2, …
$ salary <dbl> 6.163315, 6.173786, 6.214608, 4.516339, 6.620073, 4.248495,…
$ new_league <fct> N, A, N, N, A, A, A, N, A, A, N, N, A, N, N, A, N, N, N, N,…
# install.packages("tree")
library(tree)
<- tree(salary ~ years + hits, data = hitters,
hitters_tree # In order to limit the tree to just two partitions,
# need to set the `control` option
control = tree.control(nrow(hitters), minsize = 100))
Use the built-in plot() and text() functions to visualize the tree, as in Figure 8.1 of ISLR:
plot(hitters_tree)
text(hitters_tree)
To work with the regions directly: there is no broom::tidy() method for tree objects, but we can get the cut points from the frame$splits component:

hitters_tree$frame$splits
cutleft cutright
[1,] "<4.5" ">4.5"
[2,] "" ""
[3,] "<117.5" ">117.5"
[4,] "" ""
[5,] "" ""
splits <- hitters_tree$frame$splits %>%
  as_tibble() %>%
  filter(cutleft != "") %>%
  mutate(val = readr::parse_number(cutleft)) %>%
  pull(val)
splits

[1] 4.5 117.5
hitters %>%
  ggplot2::ggplot(aes(x = years, y = hits)) +
  geom_point(color = td_colors$nice$soft_orange) +
  geom_vline(xintercept = splits[1], size = 1, color = "forestgreen") +
  geom_segment(aes(x = splits[1], xend = 25, y = splits[2], yend = splits[2]),
               size = 1, color = "forestgreen") +
  annotate("text", x = 10, y = 50, label = "R[2]", size = 6, parse = TRUE) +
  annotate("text", x = 10, y = 200, label = "R[3]", size = 6, parse = TRUE) +
  annotate("text", x = 2, y = 118, label = "R[1]", size = 6, parse = TRUE) +
  coord_cartesian(xlim = c(0, 25), ylim = c(0, 240)) +
  scale_x_continuous(breaks = c(1, 4.5, 24)) +
  scale_y_continuous(breaks = c(1, 117.5, 238))
Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
ℹ Please use `linewidth` instead.
The regions \(R_1\), \(R_2\), and \(R_3\) are known as terminal nodes or leaves of the tree. The splits along the way are referred to as internal nodes – the connections between nodes are called branches.
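As a quick recap of how a regression tree makes predictions (the standard setup, as in ISLR): each terminal node predicts the average training response of the observations falling into its region, and the regions are chosen greedily to minimize the residual sum of squares,

\[
\hat{y}_{R_j} = \frac{1}{|\{i : x_i \in R_j\}|} \sum_{i:\, x_i \in R_j} y_i,
\qquad
\text{RSS} = \sum_{j=1}^{J} \sum_{i:\, x_i \in R_j} \left( y_i - \hat{y}_{R_j} \right)^2,
\]

where \(R_1, \ldots, R_J\) are the terminal-node regions.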
A key advantage of a simple decision tree like this is its ease of interpretation:
We might interpret the regression tree displayed in Figure 8.1 as follows: Years is the most important factor in determining Salary, and players with less experience earn lower salaries than more experienced players. Given that a player is less experienced, the number of hits that he made in the previous year seems to play little role in his salary. But among players who have been in the major leagues for five or more years, the number of hits made in the previous year does affect salary, and players who made more hits last year tend to have higher salaries.
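As a quick sanity check (not part of the original analysis; the region definitions are read off the two splits found above), the leaf predictions of this tree are simply the average log-salary within each region:

# Mean log-salary in each of the three regions defined by the splits above
hitters %>%
  mutate(
    region = case_when(
      years < 4.5 ~ "R1",
      hits < 117.5 ~ "R2",   # years >= 4.5 here
      TRUE ~ "R3"            # years >= 4.5 and hits >= 117.5
    )
  ) %>%
  group_by(region) %>%
  summarise(mean_log_salary = mean(salary), n_players = n())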
Initial split into test and non-test sets
set.seed(-203)
hitters_split <- initial_split(hitters,
                               # Bumped up the prop to get 132 training observations
                               prop = 0.505)
hitters_train <- training(hitters_split)
hitters_test <- testing(hitters_split)

hitters_resamples <- vfold_cv(hitters_train, v = 6)
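A quick check (not shown in the original output) that the bumped-up proportion gives the expected sizes, i.e. 132 training and 131 test observations:

# Number of observations in the non-test (training) and test sets
c(train = nrow(hitters_train), test = nrow(hitters_test))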
Then fit a decision tree with six features:
hitters_train_tree <- tree(
  salary ~ years + hits + rbi + put_outs + walks + runs,
  data = hitters_train
)

plot(hitters_train_tree)
text(hitters_train_tree, digits = 3)
Do the same with each CV split:
hitters_resamples_tree <-
  # Compile all of the analysis data sets from the six splits
  map_dfr(hitters_resamples$splits, analysis, .id = "split") %>%
  # For each split...
  group_by(split) %>%
  nest() %>%
  mutate(
    # ... fit a tree to the analysis set
    tree_mod = map(
      data,
      ~ tree(
        salary ~ years + hits + rbi + put_outs + walks + runs,
        data = .x
      )
    )
  )
Next, we prune the large tree above, varying the number of terminal nodes from 1 to 10. For this, I'll vary the best argument in the prune.tree() function:
hitters_tree_pruned <-
  tibble(n_terminal = 1:10) %>%
  mutate(
    train_tree_pruned = map(n_terminal,
                            ~ prune.tree(hitters_train_tree, best = .x))
  )
hitters_tree_pruned
# A tibble: 10 × 2
n_terminal train_tree_pruned
<int> <list>
1 1 <singlend>
2 2 <tree>
3 3 <tree>
4 4 <tree>
5 5 <tree>
6 6 <tree>
7 7 <tree>
8 8 <tree>
9 9 <tree>
10 10 <tree>
Note that, for n_terminal = 1, the returned object has class singlenode, not tree. This makes sense (a single node can't really be called a tree), but unfortunately it means that I can't use the predict() function to calculate MSE later on. Mathematically, a single node just predicts the mean of the training set, so I will replace the n_terminal = 1 entry with an lm model with just an intercept:
hitters_tree_pruned <- hitters_tree_pruned %>%
  mutate(
    train_tree_pruned = ifelse(
      n_terminal == 1,
      list(lm(salary ~ 1, data = hitters_train)),
      train_tree_pruned
    )
  )
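As a small sanity check (not in the original), the intercept-only lm predicts exactly the training-set mean of log-salary, which is what a single-node tree would predict:

# The fitted intercept equals the mean of the response in the training data
all.equal(
  unname(coef(lm(salary ~ 1, data = hitters_train))),
  mean(hitters_train$salary)
)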
Do the same for each CV split:
hitters_resamples_tree_pruned <- hitters_resamples_tree %>%
  crossing(n_terminal = 1:10) %>%
  mutate(
    tree_pruned = map2(tree_mod, n_terminal,
                       ~ prune.tree(.x, best = .y)),
    # As above, replace the single node trees with lm
    tree_pruned = ifelse(
      n_terminal == 1,
      map(data, ~ lm(salary ~ 1, data = .x)),
      tree_pruned
    )
  )
Warning: There were 3 warnings in `mutate()`.
The first warning was:
ℹ In argument: `tree_pruned = map2(tree_mod, n_terminal, ~prune.tree(.x, best =
.y))`.
Caused by warning in `prune.tree()`:
! best is bigger than tree size
ℹ Run `dplyr::last_dplyr_warnings()` to see the 2 remaining warnings.
Note the warnings: some of the trees fit to the CV splits already had fewer terminal nodes than the requested best value, so no pruning was performed for those combinations.
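To see which fits triggered the warning, one can count the terminal nodes of each CV tree (a quick check, not in the original); leaves are marked "<leaf>" in a tree object's frame component:

# Number of terminal nodes in the tree fit to each CV analysis set
map_int(hitters_resamples_tree$tree_mod, ~ sum(.x$frame$var == "<leaf>"))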
Finally, compute the MSE for the different data sets. The training and testing sets:
# Simple helper function to compute mean squared error
calc_mse <- function(mod, data) {
  mean((predict(mod, newdata = data) - data$salary)^2)
}

hitters_tree_pruned_mse <- hitters_tree_pruned %>%
  mutate(
    train_mse = map_dbl(
      train_tree_pruned,
      ~ calc_mse(.x, hitters_train)
    ),
    test_mse = map_dbl(
      train_tree_pruned,
      ~ calc_mse(.x, hitters_test)
    )
  )
hitters_tree_pruned_mse
# A tibble: 10 × 4
n_terminal train_tree_pruned train_mse test_mse
<int> <list> <dbl> <dbl>
1 1 <lm> 0.671 0.906
2 2 <tree> 0.400 0.487
3 3 <tree> 0.329 0.377
4 4 <tree> 0.280 0.443
5 5 <tree> 0.280 0.443
6 6 <tree> 0.264 0.433
7 7 <tree> 0.252 0.401
8 8 <tree> 0.233 0.393
9 9 <tree> 0.233 0.393
10 10 <tree> 0.225 0.385
And the CV splits:
hitters_resamples_tree_pruned_mse <- hitters_resamples_tree_pruned %>%
  select(split, n_terminal, tree_pruned) %>%
  left_join(
    map_dfr(hitters_resamples$splits, assessment, .id = "split") %>%
      group_by(split) %>%
      nest() %>%
      rename(assessment_data = data),
    by = "split"
  ) %>%
  mutate(
    cv_mse = map2_dbl(
      tree_pruned, assessment_data,
      ~ calc_mse(.x, .y)
    )
  ) %>%
  group_by(n_terminal) %>%
  summarise(cv_mse = mean(cv_mse), .groups = "drop")
Finally, put it all together (without standard error bars):
hitters_tree_pruned_mse %>%
  select(-train_tree_pruned) %>%
  left_join(hitters_resamples_tree_pruned_mse, by = "n_terminal") %>%
  pivot_longer(cols = c(train_mse, test_mse, cv_mse), names_to = "data_set") %>%
  mutate(
    data_set = factor(data_set,
                      levels = c("train_mse", "cv_mse", "test_mse"),
                      labels = c("Training", "Cross-validation", "Test"))
  ) %>%
  ggplot(aes(x = n_terminal, y = value, color = data_set)) +
  geom_point(size = 3) +
  geom_line(size = 1) +
  scale_y_continuous("Mean squared error", breaks = seq(0, 1.0, 0.2)) +
  expand_limits(y = c(0, 1.0)) +
  scale_x_continuous("Tree size", breaks = seq(2, 10, 2)) +
  scale_color_manual(NULL, values = c("black", "darkgreen", "darkorange")) +
  theme(legend.position = c(0.7, 0.8))
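To close the loop on the workflow from the overview, here is a short sketch (reusing the objects defined above; best_size and final_tree are new names) of choosing the tree size with the lowest CV MSE, pruning the tree fit on the whole non-test data to that size, and evaluating it once on the test set:

# Tree size with the lowest cross-validated MSE
best_size <- hitters_resamples_tree_pruned_mse %>%
  slice_min(cv_mse, n = 1) %>%
  pull(n_terminal)
best_size

# Prune the non-test tree to that size and report its test MSE
# (if best_size were 1, we would fall back to the intercept-only lm as above)
final_tree <- prune.tree(hitters_train_tree, best = best_size)
calc_mse(final_tree, hitters_test)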