Credit: This note draws heavily on material from the books An Introduction to Statistical Learning: with Applications in R (ISL2) and The Elements of Statistical Learning: Data Mining, Inference, and Prediction (ESL2).

Display system information for reproducibility.

sessionInfo()
## R version 4.2.3 (2023-03-15)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Ventura 13.4
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## loaded via a namespace (and not attached):
##  [1] digest_0.6.31   R6_2.5.1        jsonlite_1.8.5  evaluate_0.21  
##  [5] cachem_1.0.8    rlang_1.1.1     cli_3.6.1       rstudioapi_0.14
##  [9] jquerylib_0.1.4 bslib_0.4.2     rmarkdown_2.22  tools_4.2.3    
## [13] xfun_0.39       yaml_2.3.7      fastmap_1.1.1   compiler_4.2.3 
## [17] htmltools_0.5.5 knitr_1.43      sass_0.4.6
library(tidyverse)
## ── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
## ✔ dplyr     1.1.2     ✔ readr     2.1.4
## ✔ forcats   1.0.0     ✔ stringr   1.5.0
## ✔ ggplot2   3.4.2     ✔ tibble    3.2.1
## ✔ lubridate 1.9.2     ✔ tidyr     1.3.0
## ✔ purrr     1.0.1     
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
## ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(faraway)

Overview

Pros and cons of decision trees

  • Tree-based methods are simple and useful for interpretation.

  • However, they typically are not competitive with the best supervised learning approaches in terms of prediction accuracy.

  • Hence we also discuss bagging, random forests, and boosting. These methods grow multiple trees which are then combined to yield a single consensus prediction.

  • Combining a large number of trees can often result in dramatic improvements in prediction accuracy, at the expense of some loss of interpretation.

The basics of decision trees

Baseball player salary data Hitters.

Load Hitters data:

Hitters = read_csv("../data/Hitters.csv")
## Rows: 322 Columns: 20
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr  (3): League, Division, NewLeague
## dbl (17): AtBat, Hits, HmRun, Runs, RBI, Walks, Years, CAtBat, CHits, CHmRun...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
Hitters
## # A tibble: 322 × 20
##    AtBat  Hits HmRun  Runs   RBI Walks Years CAtBat CHits CHmRun CRuns  CRBI
##    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl>  <dbl> <dbl> <dbl>
##  1   293    66     1    30    29    14     1    293    66      1    30    29
##  2   315    81     7    24    38    39    14   3449   835     69   321   414
##  3   479   130    18    66    72    76     3   1624   457     63   224   266
##  4   496   141    20    65    78    37    11   5628  1575    225   828   838
##  5   321    87    10    39    42    30     2    396   101     12    48    46
##  6   594   169     4    74    51    35    11   4408  1133     19   501   336
##  7   185    37     1    23     8    21     2    214    42      1    30     9
##  8   298    73     0    24    24     7     3    509   108      0    41    37
##  9   323    81     6    26    32     8     2    341    86      6    32    34
## 10   401    92    17    49    66    65    13   5206  1332    253   784   890
## # ℹ 312 more rows
## # ℹ 8 more variables: CWalks <dbl>, League <chr>, Division <chr>,
## #   PutOuts <dbl>, Assists <dbl>, Errors <dbl>, Salary <dbl>, NewLeague <chr>

Visualize:

Hitters %>%
  arrange(desc(Salary)) %>%
  ggplot(mapping = aes(x = Years, y = Hits, color = Salary)) +
  scale_color_gradient(low = "blue", high = "red") +
  geom_point()

Who are the outliers with very few hits?

Hitters %>% filter(Hits < 10)
## # A tibble: 6 × 20
##   AtBat  Hits HmRun  Runs   RBI Walks Years CAtBat CHits CHmRun CRuns  CRBI
##   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl>  <dbl> <dbl> <dbl>
## 1    19     7     0     1     2     1     4     41    13      1     3     4
## 2    24     3     0     1     0     2     3    159    28      0    20    12
## 3    20     1     0     0     0     0     2     41     9      2     6     7
## 4    33     6     0     2     4     7     1     33     6      0     2     4
## 5    16     2     0     1     0     0     2     28     4      0     1     0
## 6    19     4     1     2     3     1     1     19     4      1     2     3
## # ℹ 8 more variables: CWalks <dbl>, League <chr>, Division <chr>,
## #   PutOuts <dbl>, Assists <dbl>, Errors <dbl>, Salary <dbl>, NewLeague <chr>

Tree-building process

Predictions

Top left: A partition of two-dimensional feature space that could not result from recursive binary splitting. Top right: The output of recursive binary splitting on a two-dimensional example. Bottom left: A tree corresponding to the partition in the top right panel. Bottom right: A perspective plot of the prediction surface corresponding to that tree.
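
As a hedged sketch of this tree-building process (assuming the rpart package; ISL fits the analogous tree with the tree package, so the exact splits may differ), a shallow regression tree of log salary on Years and Hits can be grown and plotted as follows:

library(rpart)

# Grow a shallow regression tree of log(Salary) on Years and Hits.
# rpart's default NA handling drops the rows with a missing Salary.
fit_small <- rpart(log(Salary) ~ Years + Hits, data = Hitters,
                   control = rpart.control(maxdepth = 2))

fit_small                      # print the splits and node predictions
plot(fit_small, margin = 0.1)  # draw the tree
text(fit_small, use.n = TRUE)  # label the splits and show node sizes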

Pruning

Summary: tree algorithm

  1. Use recursive binary splitting to grow a large tree on the training data, stopping only when each terminal node has fewer than some minimum number of observations.

  2. Apply cost complexity pruning to the large tree in order to obtain a sequence of best subtrees, as a function of \(\alpha\) (the criterion being minimized is displayed just after this list).

  3. Use \(K\)-fold cross-validation to choose \(\alpha\). For each \(k = 1,\ldots,K\):

    3.1 Repeat Steps 1 and 2 on the \((K-1)/K\) fraction of the training data, i.e., excluding the \(k\)th fold.

    3.2 Evaluate the mean squared prediction error on the data in the left-out \(k\)th fold, as a function of \(\alpha\).

    Average the results, and pick \(\alpha\) to minimize the average error.

  4. Return the subtree from Step 2 that corresponds to the chosen value of \(\alpha\).
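
For reference, the cost complexity pruning in Step 2 finds, for each \(\alpha \ge 0\), the subtree \(T \subset T_0\) of the large tree \(T_0\) that minimizes
\[
\sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 + \alpha |T|,
\]
where \(|T|\) is the number of terminal nodes of \(T\), \(R_m\) is the rectangle corresponding to the \(m\)th terminal node, and \(\hat{y}_{R_m}\) is the mean of the training responses in \(R_m\). Setting \(\alpha = 0\) returns \(T_0\) itself, while larger values of \(\alpha\) penalize tree size and select smaller subtrees.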

Hitters example (regression tree)

Workflow: pruning a regression tree.
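
A sketch of this workflow in R, assuming the rpart package (ISL uses the tree package with cv.tree; rpart instead performs the cross-validation internally and reports it in the cp table):

library(rpart)
set.seed(2)

# Build a modeling data set: drop players with missing Salary, work on the
# log scale, and store the qualitative predictors as factors.
hitters_cc <- Hitters %>%
  filter(!is.na(Salary)) %>%
  mutate(logSalary = log(Salary),
         across(where(is.character), as.factor)) %>%
  select(-Salary) %>%
  as.data.frame()

# Step 1: grow a large tree (small cp), with 10-fold cross-validation
# carried out internally by rpart.
fit_full <- rpart(logSalary ~ ., data = hitters_cc,
                  control = rpart.control(cp = 0.001, xval = 10))

# Steps 2-3: the cp table lists the cost complexity sequence together with
# the cross-validated error (xerror) of each subtree.
printcp(fit_full)

# Step 4: keep the subtree whose cp minimizes the cross-validated error.
best_cp    <- fit_full$cptable[which.min(fit_full$cptable[, "xerror"]), "CP"]
fit_pruned <- prune(fit_full, cp = best_cp)
plot(fit_pruned, margin = 0.1)
text(fit_pruned, use.n = TRUE)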

Classification trees

Tree versus linear models

Top row: true linear boundary; bottom row: true non-linear boundary. Left column: linear model; right column: tree-based model.

Pros and cons of decision trees

Advantages:

  1. Trees are very easy to explain to people. In fact, they are even easier to explain than linear regression!

  2. Some people believe that decision trees more closely mirror human decision-making than other regression and classification approaches we learnt in this course.

  3. Trees can be displayed graphically, and are easily interpreted even by a non-expert (especially if they are small).

  4. Trees can easily handle qualitative predictors without the need to create dummy variables (although some software implementations, such as scikit-learn's classical tree estimators, still require categorical predictors to be encoded numerically).

Disadvantages:

  1. Unfortunately, trees generally do not have the same level of predictive accuracy as some of the other regression and classification approaches.

  2. Additionally, trees can be very non-robust. In other words, a small change in the data can cause a large change in the final estimated tree.

Ensemble methods such as bagging, random forests, and boosting address both of these issues: by aggregating many trees they substantially improve predictive accuracy and stability, at some cost in interpretability.

Bagging

The test error (black and orange) is shown as a function of \(B\), the number of bootstrapped training sets used. Random forests were applied with \(m = \sqrt{p}\). The dashed line indicates the test error resulting from a single classification tree. The green and blue traces show the OOB error, which in this case is considerably lower.

Out-of-Bag (OOB) error estimation
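
As a sketch (assuming the randomForest package and the hitters_cc data frame constructed in the pruning example above), bagging is simply a random forest that considers all \(p\) predictors at every split, and its OOB error estimate comes for free:

library(randomForest)
set.seed(1)

# Number of predictors (all columns except the response logSalary).
p <- ncol(hitters_cc) - 1

# Bagging is a random forest with m = p. No validation set or CV is needed:
# each tree's out-of-bag observations provide an honest error estimate.
bag_fit <- randomForest(logSalary ~ ., data = hitters_cc,
                        mtry = p, ntree = 500, importance = TRUE)

bag_fit        # the printout includes the OOB estimate of the MSE
plot(bag_fit)  # OOB MSE as a function of the number of trees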

Random forests


Results from random forests for the fifteen-class gene expression data set with \(p=500\) predictors. The test error is displayed as a function of the number of trees. Each colored line corresponds to a different value of \(m\), the number of predictors available for splitting at each interior tree node. Random forests (\(m < p\)) lead to a slight improvement over bagging (\(m = p\)). A single classification tree has an error rate of 45.7%.
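
A companion sketch, under the same assumptions as above, decorrelates the trees by sampling \(m < p\) predictors at each split:

set.seed(1)

# Random forest: sample m < p predictors at each split. The randomForest
# default for regression is m = p/3; ISL's classification figures use m = sqrt(p).
rf_fit <- randomForest(logSalary ~ ., data = hitters_cc,
                       mtry = floor(p / 3), ntree = 500, importance = TRUE)

# Compare OOB mean squared errors: bagging versus the random forest.
c(bagging = tail(bag_fit$mse, 1), random_forest = tail(rf_fit$mse, 1))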

Boosting

Boosting for classification

See ESL Chapter 10.

Gene expression example (continued).

Results from performing boosting and random forests on the fifteen-class gene expression data set in order to predict cancer versus normal. The test error is displayed as a function of the number of trees. For the two boosted models, \(\lambda=0.01\). Depth-1 trees slightly outperform depth-2 trees, and both outperform the random forest, although the standard errors are around 0.02, making none of these differences significant. The test error rate for a single tree is 24%.

Tuning parameters for boosting

  1. The number of trees \(B\). Unlike bagging and random forests, boosting can overfit if \(B\) is too large, although this overfitting tends to occur slowly if at all. We use cross-validation to select \(B\).

  2. The shrinkage parameter or learning rate \(\lambda\), a small positive number. This controls the rate at which boosting learns. Typical values are 0.01 or 0.001, and the right choice can depend on the problem. Very small \(\lambda\) can require using a very large value of \(B\) in order to achieve good performance.

  3. The number of splits \(d\) in each tree, which controls the complexity of the boosted ensemble. Often \(d = 1\) or \(2\) works well. When \(d = 1\), each tree is a stump, consisting of a single split, and the boosted ensemble fits an additive model. More generally, \(d\) is the interaction depth and controls the interaction order of the boosted model, since \(d\) splits can involve at most \(d\) variables. (These three settings are illustrated in the code sketch after this list.)
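
A hedged sketch of these settings using the gbm package (as in the ISL boosting lab), again on the hitters_cc data frame from above; n.trees plays the role of \(B\), shrinkage of \(\lambda\), and interaction.depth of \(d\):

library(gbm)
set.seed(1)

# Boosted regression trees: B = n.trees, lambda = shrinkage, d = interaction.depth.
boost_fit <- gbm(logSalary ~ ., data = hitters_cc,
                 distribution = "gaussian",
                 n.trees = 5000,           # upper bound on B
                 shrinkage = 0.01,         # learning rate lambda
                 interaction.depth = 1,    # d = 1: stumps, an additive model
                 cv.folds = 5)

# Choose B by 5-fold cross-validation.
best_B <- gbm.perf(boost_fit, method = "cv")
best_B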

Variable importance (VI) measure

Variable importance plot for the Heart data.
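
The Heart data are not loaded in these notes; as a hedged illustration, the analogous importance measures for the Hitters fits above can be obtained as follows:

# Permutation-based (%IncMSE) and node-impurity importance from the random forest.
importance(rf_fit)
varImpPlot(rf_fit)

# Relative influence of each predictor in the boosted model.
summary(boost_fit, n.trees = best_B)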

Summary