Machine Learning Workflow: Regression Trees

Biostat 200C

Author

Dr. Jin Zhou @ UCLA

Published

June 7, 2023

Load useful packages

library(tidyverse)
library(tidymodels)
library(broom)
library(gt)
library(patchwork)
library(tictoc)
library(ISLR2)
library(janitor)

# Load dunnr R package and set the ggplot theme
library(dunnr)
# extrafont::loadfonts(device = "win", quiet = TRUE)
# theme_set(theme_td_minimal())
# set_geom_fonts()
# set_palette()

Overview

We illustrate the typical machine learning workflow for regression trees using the Hitters data set from the R package ISLR2.

  1. Initial splitting to test and non-test sets.

  2. Pre-processing of data: not much is needed for regression trees.

  3. Tune the cost complexity pruning hyper-parameter(s) using 6-fold cross-validation (CV) on the non-test data.

  4. Choose the best model by CV and refit it on the whole non-test data.

  5. Final prediction on the test data.

Hitters data

Documentation of the Hitters data is available via `?ISLR2::Hitters`. The goal is to predict log(Salary) (annual salary on opening day of the 1987 season) of MLB players from their 1986 performance metrics and career totals (the `c_*` columns).

# Read the raw data and standardise the column names to snake_case. Then, as per
# the text, drop players with a missing `salary` and model salary on the log
# scale.
hitters <- ISLR2::Hitters %>%
  janitor::clean_names() %>%
  filter(!is.na(salary)) %>%
  mutate(salary = log(salary))

glimpse(hitters)
Rows: 263
Columns: 20
$ at_bat     <int> 315, 479, 496, 321, 594, 185, 298, 323, 401, 574, 202, 418,…
$ hits       <int> 81, 130, 141, 87, 169, 37, 73, 81, 92, 159, 53, 113, 60, 43…
$ hm_run     <int> 7, 18, 20, 10, 4, 1, 0, 6, 17, 21, 4, 13, 0, 7, 20, 2, 8, 1…
$ runs       <int> 24, 66, 65, 39, 74, 23, 24, 26, 49, 107, 31, 48, 30, 29, 89…
$ rbi        <int> 38, 72, 78, 42, 51, 8, 24, 32, 66, 75, 26, 61, 11, 27, 75, …
$ walks      <int> 39, 76, 37, 30, 35, 21, 7, 8, 65, 59, 27, 47, 22, 30, 73, 1…
$ years      <int> 14, 3, 11, 2, 11, 2, 3, 2, 13, 10, 9, 4, 6, 13, 15, 5, 8, 1…
$ c_at_bat   <int> 3449, 1624, 5628, 396, 4408, 214, 509, 341, 5206, 4631, 187…
$ c_hits     <int> 835, 457, 1575, 101, 1133, 42, 108, 86, 1332, 1300, 467, 39…
$ c_hm_run   <int> 69, 63, 225, 12, 19, 1, 0, 6, 253, 90, 15, 41, 4, 36, 177, …
$ c_runs     <int> 321, 224, 828, 48, 501, 30, 41, 32, 784, 702, 192, 205, 309…
$ crbi       <int> 414, 266, 838, 46, 336, 9, 37, 34, 890, 504, 186, 204, 103,…
$ c_walks    <int> 375, 263, 354, 33, 194, 24, 12, 8, 866, 488, 161, 203, 207,…
$ league     <fct> N, A, N, N, A, N, A, N, A, A, N, N, A, N, N, A, N, N, A, N,…
$ division   <fct> W, W, E, E, W, E, W, W, E, E, W, E, E, E, W, W, W, E, W, W,…
$ put_outs   <int> 632, 880, 200, 805, 282, 76, 121, 143, 0, 238, 304, 211, 12…
$ assists    <int> 43, 82, 11, 40, 421, 127, 283, 290, 0, 445, 45, 11, 151, 45…
$ errors     <int> 10, 14, 3, 4, 25, 7, 9, 19, 0, 22, 11, 7, 6, 8, 10, 16, 2, …
$ salary     <dbl> 6.163315, 6.173786, 6.214608, 4.516339, 6.620073, 4.248495,…
$ new_league <fct> N, A, N, N, A, A, A, N, A, A, N, N, A, N, N, A, N, N, N, N,…
# install.packages("tree")
library(tree)

# Grow a deliberately small tree on two predictors. Raising `minsize` (the
# smallest allowed node size, see ?tree.control) to 100 stops growth early,
# so the tree only has two splits.
hitters_tree <- tree(
  formula = salary ~ years + hits,
  data = hitters,
  control = tree.control(nobs = nrow(hitters), minsize = 100)
)

Use the built-in plot() and text() functions to visualize the tree (compare Figure 8.1 in ISLR):

# Draw the tree skeleton, then label its splits and terminal nodes.
plot(x = hitters_tree)
text(x = hitters_tree)

There is no broom::tidy() method for tree objects, but to work with the regions we can get the cut points from the frame$splits matrix:

# Left/right cut labels for every node of the tree; leaf rows are empty strings.
hitters_tree[["frame"]][["splits"]]
     cutleft  cutright
[1,] "<4.5"   ">4.5"  
[2,] ""       ""      
[3,] "<117.5" ">117.5"
[4,] ""       ""      
[5,] ""       ""      
# Numeric cut points of the internal nodes, in node order. Leaves have an empty
# `cutleft`, so keep only non-empty labels, then strip the "<" from labels such
# as "<4.5".
splits <- hitters_tree$frame$splits[, "cutleft"] %>%
  keep(nzchar) %>%
  readr::parse_number()
splits
[1]   4.5 117.5
# Show the partition of the (years, hits) predictor space induced by
# `hitters_tree`:
# - a vertical line at the `years` cut (splits[1]);
# - to its right only, a horizontal segment at the `hits` cut (splits[2]).
# Axis breaks are taken from `splits` so they always match the fitted tree.
hitters %>%
  ggplot(aes(x = years, y = hits)) +
  geom_point(color = td_colors$nice$soft_orange) +
  # `linewidth` (not `size`) controls line thickness since ggplot2 3.4.0.
  geom_vline(xintercept = splits[1], linewidth = 1, color = "forestgreen") +
  # annotate() draws exactly one segment; geom_segment() with constant aes()
  # would draw one identical copy for every row of `hitters`.
  annotate(
    "segment",
    x = splits[1], xend = 25, y = splits[2], yend = splits[2],
    linewidth = 1, color = "forestgreen"
  ) +
  annotate("text", x = 10, y = 50, label = "R[2]", size = 6, parse = TRUE) +
  annotate("text", x = 10, y = 200, label = "R[3]", size = 6, parse = TRUE) +
  annotate("text", x = 2, y = 118, label = "R[1]", size = 6, parse = TRUE) +
  coord_cartesian(xlim = c(0, 25), ylim = c(0, 240)) +
  scale_x_continuous(breaks = c(1, splits[1], 24)) +
  scale_y_continuous(breaks = c(1, splits[2], 238))
Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
ℹ Please use `linewidth` instead.

The regions \(R_1\), \(R_2\), and \(R_3\) are known as terminal nodes or leaves of the tree. The splits along the way are referred to as internal nodes – the connections between nodes are called branches.

A key advantage of a simple decision tree like this is its ease of interpretation:

We might interpret the regression tree displayed in Figure 8.1 as follows: Years is the most important factor in determining Salary, and players with less experience earn lower salaries than more experienced players. Given that a player is less experienced, the number of hits that he made in the previous year seems to play little role in his salary. But among players who have been in the major leagues for five or more years, the number of hits made in the previous year does affect salary, and players who made more hits last year tend to have higher salaries.

Initial split into test and non-test sets

# Fix the RNG so that both the test split and the CV folds are reproducible.
# prop = 0.505 (rather than 0.5) puts 132 of the 263 players in the
# training set and 131 in the test set.
set.seed(-203)
hitters_split <- initial_split(data = hitters, prop = 0.505)
hitters_train <- training(hitters_split)
hitters_test <- testing(hitters_split)

# Six folds of the non-test data, used to cross-validate the tree size.
hitters_resamples <- vfold_cv(data = hitters_train, v = 6)

Next, fit a decision tree to the training set using six features:

# Grow a full (unpruned) regression tree on the training set with six
# predictors, then plot it. Node labels are rounded to 3 significant digits.
# No trailing comma after the last argument: it would pass an empty positional
# argument to tree().
hitters_train_tree <- tree(
  salary ~ years + hits + rbi + put_outs + walks + runs,
  data = hitters_train
)
plot(hitters_train_tree)
text(hitters_train_tree, digits = 3)

Do the same with each CV split:

# Fit the same six-predictor tree to the analysis set of every CV fold.
# Result: one (ungrouped) row per fold with columns
#   split    - fold id ("1", "2", ...), as produced by map_dfr(.id = )
#   data     - that fold's analysis data
#   tree_mod - the tree fitted to `data`
hitters_resamples_tree <-
  # Stack the analysis sets of all six folds, tagged by fold id...
  map_dfr(hitters_resamples$splits, analysis, .id = "split") %>%
  # ...then re-nest so each fold's analysis data is one list-column cell.
  group_by(split) %>%
  nest() %>%
  ungroup() %>%
  mutate(
    tree_mod = map(
      data,
      ~ tree(
        salary ~ years + hits + rbi + put_outs + walks + runs,
        data = .x
      )
    )
  )

Next, we prune the large tree above to every size from 10 terminal nodes down to 1. For this, I’ll vary the best parameter in the prune.tree() function:

# Prune the full training tree to each size from 1 to 10 terminal nodes.
# prune.tree(best = size) requests the subtree with `size` terminal nodes from
# the cost-complexity sequence.
hitters_tree_pruned <- tibble(n_terminal = 1:10) %>%
  mutate(
    train_tree_pruned = map(
      n_terminal,
      function(size) prune.tree(hitters_train_tree, best = size)
    )
  )
hitters_tree_pruned
# A tibble: 10 × 2
   n_terminal train_tree_pruned
        <int> <list>           
 1          1 <singlend>       
 2          2 <tree>           
 3          3 <tree>           
 4          4 <tree>           
 5          5 <tree>           
 6          6 <tree>           
 7          7 <tree>           
 8          8 <tree>           
 9          9 <tree>           
10         10 <tree>           

Note that, for n_terminal = 1, the object is singlend, not tree. This makes sense – a single node can’t really be called a tree – but unfortunately it means that I can’t use the predict() function to calculate MSE later on. Mathematically, a single node is just a prediction of the mean of the training set, so I will replace the n_terminal = 1 model with an intercept-only lm model:

# A tree pruned down to a single node is a "singlend" object, and predict()
# cannot be used on it. Its prediction is just the training mean, so swap any
# such object for the equivalent intercept-only linear model.
# The check is on the object's class (the thing that actually breaks
# predict()) rather than on `n_terminal`. The lm() is only fitted for rows
# that need it, instead of relying on ifelse() recycling a list.
hitters_tree_pruned <- hitters_tree_pruned %>%
  mutate(
    train_tree_pruned = map(
      train_tree_pruned,
      function(mod) {
        if (inherits(mod, "singlend")) {
          lm(salary ~ 1, data = hitters_train)
        } else {
          mod
        }
      }
    )
  )

Do the same for each CV split:

# Prune each fold's tree to every size from 1 to 10 terminal nodes.
# prune.tree() warns "best is bigger than tree size" when a fold's tree already
# has fewer leaves than requested; those warnings are intentionally left
# visible.
# As for the training tree, single-node ("singlend") results are replaced by an
# intercept-only lm() fitted to that fold's analysis data, because predict()
# cannot be used on singlend objects. The lm() is fitted only for the rows
# that need it, not for every row.
hitters_resamples_tree_pruned <- hitters_resamples_tree %>%
  crossing(n_terminal = 1:10) %>%
  mutate(
    tree_pruned = map2(tree_mod, n_terminal, ~ prune.tree(.x, best = .y)),
    tree_pruned = map2(
      tree_pruned, data,
      function(mod, fold_data) {
        if (inherits(mod, "singlend")) {
          lm(salary ~ 1, data = fold_data)
        } else {
          mod
        }
      }
    )
  )
Warning: There were 3 warnings in `mutate()`.
The first warning was:
ℹ In argument: `tree_pruned = map2(tree_mod, n_terminal, ~prune.tree(.x, best =
  .y))`.
Caused by warning in `prune.tree()`:
! best is bigger than tree size
ℹ Run `dplyr::last_dplyr_warnings()` to see the 2 remaining warnings.

Note the warnings. They say that some of the trees fit to the CV splits had fewer than 10 terminal nodes to begin with. For those trees, asking for a larger size performs no pruning, and the full tree is returned.

Next, compute the MSE for each data set, starting with the training and test sets:

# Mean squared error of a fitted model's predictions on `data`.
#
# mod:     any model with a predict(mod, newdata = ) method (tree, lm, ...).
# data:    data frame holding the predictors and the outcome column.
# outcome: name of the outcome column; defaults to "salary", the response used
#          throughout this document.
# Returns a single numeric value. Errors if `outcome` is not a column of
# `data`; without this check, indexing a missing column gives NULL and the
# result would silently be NaN.
calc_mse <- function(mod, data, outcome = "salary") {
  stopifnot(is.character(outcome), length(outcome) == 1L)
  if (!(outcome %in% names(data))) {
    stop("`data` has no column named '", outcome, "'.", call. = FALSE)
  }
  # `[[` avoids `$`'s partial matching on plain data.frames.
  mean((predict(mod, newdata = data) - data[[outcome]])^2)
}

# Score every pruned model twice: on the training data it was fitted to and
# on the held-out test set. calc_mse() receives each model as its first
# argument and the data via `data =`.
hitters_tree_pruned_mse <- hitters_tree_pruned %>%
  mutate(
    train_mse = map_dbl(train_tree_pruned, calc_mse, data = hitters_train),
    test_mse = map_dbl(train_tree_pruned, calc_mse, data = hitters_test)
  )

hitters_tree_pruned_mse
# A tibble: 10 × 4
   n_terminal train_tree_pruned train_mse test_mse
        <int> <list>                <dbl>    <dbl>
 1          1 <lm>                  0.671    0.906
 2          2 <tree>                0.400    0.487
 3          3 <tree>                0.329    0.377
 4          4 <tree>                0.280    0.443
 5          5 <tree>                0.280    0.443
 6          6 <tree>                0.264    0.433
 7          7 <tree>                0.252    0.401
 8          8 <tree>                0.233    0.393
 9          9 <tree>                0.233    0.393
10         10 <tree>                0.225    0.385

And the CV splits:

# Cross-validated MSE for each tree size.
# 1. Attach each fold's assessment (held-out) data to the models fitted on that
#    fold's analysis data.
# 2. Score each model.
# 3. Average the per-fold MSEs within each tree size.
hitters_resamples_tree_pruned_mse <- hitters_resamples_tree_pruned %>%
  select(split, n_terminal, tree_pruned) %>%
  left_join(
    hitters_resamples$splits %>%
      map_dfr(assessment, .id = "split") %>%
      nest(assessment_data = -split),
    by = "split"
  ) %>%
  mutate(cv_mse = map2_dbl(tree_pruned, assessment_data, calc_mse)) %>%
  group_by(n_terminal) %>%
  summarise(cv_mse = mean(cv_mse), .groups = "drop")

Finally, put it all together (without standard error bars):

# Plot training, cross-validation and test MSE against tree size (without
# standard error bars). After pivoting there is one row per
# (tree size, data set) pair.
hitters_tree_pruned_mse %>%
  select(-train_tree_pruned) %>%
  left_join(hitters_resamples_tree_pruned_mse, by = "n_terminal") %>%
  pivot_longer(
    cols = c(train_mse, test_mse, cv_mse),
    names_to = "data_set"
  ) %>%
  mutate(
    data_set = factor(
      data_set,
      levels = c("train_mse", "cv_mse", "test_mse"),
      labels = c("Training", "Cross-validation", "Test")
    )
  ) %>%
  ggplot(aes(x = n_terminal, y = value, color = data_set)) +
  geom_point(size = 3) +
  # Line thickness is `linewidth`; `size` for lines is deprecated since
  # ggplot2 3.4.0.
  geom_line(linewidth = 1) +
  scale_y_continuous("Mean squared error", breaks = seq(0, 1.0, 0.2)) +
  expand_limits(y = c(0, 1.0)) +
  scale_x_continuous("Tree size", breaks = seq(2, 10, 2)) +
  # Colours are matched to the factor levels: Training, CV, Test.
  scale_color_manual(NULL, values = c("black", "darkgreen", "darkorange")) +
  theme(legend.position = c(0.7, 0.8))