Credit: This note draws heavily on material from the books An Introduction to Statistical Learning: with Applications in R (ISL2) and The Elements of Statistical Learning: Data Mining, Inference, and Prediction (ESL2).

Display system information for reproducibility.

sessionInfo()
## R version 4.2.3 (2023-03-15)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Ventura 13.3.1
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## loaded via a namespace (and not attached):
##  [1] digest_0.6.31   R6_2.5.1        jsonlite_1.8.4  evaluate_0.21  
##  [5] cachem_1.0.8    rlang_1.1.1     cli_3.6.1       rstudioapi_0.14
##  [9] jquerylib_0.1.4 bslib_0.4.2     rmarkdown_2.21  tools_4.2.3    
## [13] xfun_0.39       yaml_2.3.7      fastmap_1.1.1   compiler_4.2.3 
## [17] htmltools_0.5.5 knitr_1.42      sass_0.4.6

Overview

library(gtsummary)
library(ISLR2)
library(tidyverse)

# Convert to tibble
Wage <- as_tibble(Wage) %>% print(width = Inf)
## # A tibble: 3,000 × 11
##     year   age maritl           race     education       region            
##    <int> <int> <fct>            <fct>    <fct>           <fct>             
##  1  2006    18 1. Never Married 1. White 1. < HS Grad    2. Middle Atlantic
##  2  2004    24 1. Never Married 1. White 4. College Grad 2. Middle Atlantic
##  3  2003    45 2. Married       1. White 3. Some College 2. Middle Atlantic
##  4  2003    43 2. Married       3. Asian 4. College Grad 2. Middle Atlantic
##  5  2005    50 4. Divorced      1. White 2. HS Grad      2. Middle Atlantic
##  6  2008    54 2. Married       1. White 4. College Grad 2. Middle Atlantic
##  7  2009    44 2. Married       4. Other 3. Some College 2. Middle Atlantic
##  8  2008    30 1. Never Married 3. Asian 3. Some College 2. Middle Atlantic
##  9  2006    41 1. Never Married 2. Black 3. Some College 2. Middle Atlantic
## 10  2004    52 2. Married       1. White 2. HS Grad      2. Middle Atlantic
##    jobclass       health         health_ins logwage  wage
##    <fct>          <fct>          <fct>        <dbl> <dbl>
##  1 1. Industrial  1. <=Good      2. No         4.32  75.0
##  2 2. Information 2. >=Very Good 2. No         4.26  70.5
##  3 1. Industrial  1. <=Good      1. Yes        4.88 131. 
##  4 2. Information 2. >=Very Good 1. Yes        5.04 155. 
##  5 2. Information 1. <=Good      1. Yes        4.32  75.0
##  6 2. Information 2. >=Very Good 1. Yes        4.85 127. 
##  7 1. Industrial  2. >=Very Good 1. Yes        5.13 170. 
##  8 2. Information 1. <=Good      1. Yes        4.72 112. 
##  9 2. Information 2. >=Very Good 1. Yes        4.78 119. 
## 10 2. Information 2. >=Very Good 1. Yes        4.86 129. 
## # ℹ 2,990 more rows
# Summary statistics
Wage %>% tbl_summary()
Characteristic   N = 3,000¹
year
    2003 513 (17%)
    2004 485 (16%)
    2005 447 (15%)
    2006 392 (13%)
    2007 386 (13%)
    2008 388 (13%)
    2009 389 (13%)
age 42 (34, 51)
maritl
    1. Never Married 648 (22%)
    2. Married 2,074 (69%)
    3. Widowed 19 (0.6%)
    4. Divorced 204 (6.8%)
    5. Separated 55 (1.8%)
race
    1. White 2,480 (83%)
    2. Black 293 (9.8%)
    3. Asian 190 (6.3%)
    4. Other 37 (1.2%)
education
    1. < HS Grad 268 (8.9%)
    2. HS Grad 971 (32%)
    3. Some College 650 (22%)
    4. College Grad 685 (23%)
    5. Advanced Degree 426 (14%)
region
    1. New England 0 (0%)
    2. Middle Atlantic 3,000 (100%)
    3. East North Central 0 (0%)
    4. West North Central 0 (0%)
    5. South Atlantic 0 (0%)
    6. East South Central 0 (0%)
    7. West South Central 0 (0%)
    8. Mountain 0 (0%)
    9. Pacific 0 (0%)
jobclass
    1. Industrial 1,544 (51%)
    2. Information 1,456 (49%)
health
    1. <=Good 858 (29%)
    2. >=Very Good 2,142 (71%)
health_ins
    1. Yes 2,083 (69%)
    2. No 917 (31%)
logwage 4.65 (4.45, 4.86)
wage 105 (85, 129)
¹ n (%); Median (IQR)
# Plot wage ~ age; geom_smooth() fits a GAM by default when n >= 1000
Wage %>%
  ggplot(mapping = aes(x = age, y = wage)) + 
  geom_point() + 
  geom_smooth() +
  labs(title = "Wage changes nonlinearly with age",
       x = "Age",
       y = "Wage (k$)")

Polynomial regression

\[ y_i = \beta_0 + \beta_1 x_i + \beta_2 x_i^2 + \cdots + \beta_d x_i^d + \epsilon_i. \]

# Plot wage ~ age, display order-4 polynomial fit
Wage %>%
  ggplot(mapping = aes(x = age, y = wage)) + 
  geom_point() + 
  geom_smooth(
    method = "lm",
    formula = y ~ poly(x, degree = 4)
    ) +
  labs(
    title = "Degree-4 Polynomial",
    x = "Age",
    y = "Wage (k$)"
    )

# poly(age, 4) constructs orthogonal polynomials of degrees 1 to 4, all orthogonal to the constant
lmod <- lm(wage ~ poly(age, degree = 4), data = Wage)
summary(lmod)
## 
## Call:
## lm(formula = wage ~ poly(age, degree = 4), data = Wage)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -98.707 -24.626  -4.993  15.217 203.693 
## 
## Coefficients:
##                         Estimate Std. Error t value Pr(>|t|)    
## (Intercept)             111.7036     0.7287 153.283  < 2e-16 ***
## poly(age, degree = 4)1  447.0679    39.9148  11.201  < 2e-16 ***
## poly(age, degree = 4)2 -478.3158    39.9148 -11.983  < 2e-16 ***
## poly(age, degree = 4)3  125.5217    39.9148   3.145  0.00168 ** 
## poly(age, degree = 4)4  -77.9112    39.9148  -1.952  0.05104 .  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 39.91 on 2995 degrees of freedom
## Multiple R-squared:  0.08626,    Adjusted R-squared:  0.08504 
## F-statistic: 70.69 on 4 and 2995 DF,  p-value: < 2.2e-16
# poly(age, 4, raw = TRUE) produces the raw (non-orthogonal) polynomial basis, which matches the Python output
lmod <- lm(wage ~ poly(age, degree = 4, raw = TRUE), data = Wage)
summary(lmod)
## 
## Call:
## lm(formula = wage ~ poly(age, degree = 4, raw = TRUE), data = Wage)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -98.707 -24.626  -4.993  15.217 203.693 
## 
## Coefficients:
##                                      Estimate Std. Error t value Pr(>|t|)    
## (Intercept)                        -1.842e+02  6.004e+01  -3.067 0.002180 ** 
## poly(age, degree = 4, raw = TRUE)1  2.125e+01  5.887e+00   3.609 0.000312 ***
## poly(age, degree = 4, raw = TRUE)2 -5.639e-01  2.061e-01  -2.736 0.006261 ** 
## poly(age, degree = 4, raw = TRUE)3  6.811e-03  3.066e-03   2.221 0.026398 *  
## poly(age, degree = 4, raw = TRUE)4 -3.204e-05  1.641e-05  -1.952 0.051039 .  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 39.91 on 2995 degrees of freedom
## Multiple R-squared:  0.08626,    Adjusted R-squared:  0.08504 
## F-statistic: 70.69 on 4 and 2995 DF,  p-value: < 2.2e-16
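
The orthogonal and raw parameterizations give the same fit: the fitted values agree, and the columns of poly(age, 4) are orthonormal and orthogonal to the constant. A quick check (a sketch):

lmod_orth <- lm(wage ~ poly(age, degree = 4), data = Wage)
lmod_raw  <- lm(wage ~ poly(age, degree = 4, raw = TRUE), data = Wage)
# Fitted values from the orthogonal and raw bases agree up to rounding
max(abs(fitted(lmod_orth) - fitted(lmod_raw)))
# Cross-products: the poly() columns have unit norm, are mutually orthogonal,
# and are orthogonal to the constant column
round(crossprod(cbind(1, poly(Wage$age, degree = 4))), digits = 8)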


library(splines)

# Plot wage ~ age
Wage %>%
  ggplot(mapping = aes(x = age, y = wage)) + 
  geom_point(alpha = 0.25) + 
  # Polynomial regression with degree 14
  geom_smooth(
    method = "lm",
    formula = y ~ poly(x, degree = 14),
    color = "blue"
    ) +
  # Natural cubic spline
  geom_smooth(
    method = "lm",
    formula = y ~ ns(x, df = 14),
    color = "red"
    ) +  
  labs(
    title = "Natural cubic spline (red) vs polynomial regression (blue)",
    subtitle = "Both have df=15",
    x = "Age",
    y = "Wage (k$)"
    )

Piecewise polynomials (regression splines)

Linear spline

  • A linear spline with knots at \(\xi_k\), \(k = 1,\ldots,K\), is a piecewise linear polynomial continuous at each knot.

  • We can represent this model as \[ y_i = \beta_0 + \beta_1 b_1(x_i) + \beta_2 b_2(x_i) + \cdots + \beta_{K+1} b_{K+1}(x_i) + \epsilon_i, \] where \(b_k\) are basis functions:
    \[\begin{eqnarray*} b_1(x_i) &=& x_i \\ b_{k+1}(x_i) &=& (x_i - \xi_k)_+, \quad k=1,\ldots,K. \end{eqnarray*}\] Here \((\cdot)_+\) denotes the positive part \[ (x_i - \xi_k)_+ = \begin{cases} x_i - \xi_k & \text{if } x_i > \xi_k \\ 0 & \text{otherwise} \end{cases}. \]
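
A minimal sketch of this basis built by hand with pmax() (knots at ages 25, 40, and 60, chosen only for illustration):

# Truncated power basis of a linear spline: x plus (x - xi_k)_+ for each knot
lin_spline_mod <- lm(
  wage ~ age + pmax(age - 25, 0) + pmax(age - 40, 0) + pmax(age - 60, 0),
  data = Wage
  )
coef(lin_spline_mod)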

Cubic splines

  • A cubic spline with knots at \(\xi_k\), \(k = 1,\ldots,K\), is a piecewise cubic polynomial with continuous derivatives up to order 2 at each knot.

  • Again we can represent this model with truncated power basis functions \[ y_i = \beta_0 + \beta_1 b_1(x_i) + \beta_2 b_2(x_i) + \cdots + \beta_{K+3} b_{K+3}(x_i) + \epsilon_i, \] with \[\begin{eqnarray*} b_1(x_i) &=& x_i \\ b_2(x_i) &=& x_i^2 \\ b_3(x_i) &=& x_i^3 \\ b_{k+3}(x_i) &=& (x_i - \xi_k)_+^3, \quad k = 1,\ldots,K, \end{eqnarray*}\] where \[ (x_i - \xi_k)_+^3 = \begin{cases} (x_i - \xi_k)^3 & \text{if } x_i > \xi_k \\ 0 & \text{otherwise} \end{cases}. \]

  • A cubic spline with \(K\) knots costs \(K+4\) parameters or degrees of freedom, that is, \(4(K+1)\) polynomial coefficients minus \(3K\) constraints; see the quick check after this list.

  • While the truncated power basis is conceptually simple, it is not too attractive numerically: powers of large numbers can lead to severe rounding problems. In practice, B-spline basis functions are preferred for their computational efficiency. See ESL Chapter 5 Appendix.
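
A quick check of this parameter count with bs(), assuming knots at ages 25, 40, and 60 (the same knots used in the plot below):

library(splines)

# bs() returns K + 3 = 6 basis columns for K = 3 interior knots (cubic degree,
# no intercept column), so the fit has K + 4 = 7 coefficients with the intercept
dim(bs(Wage$age, knots = c(25, 40, 60)))
length(coef(lm(wage ~ bs(age, knots = c(25, 40, 60)), data = Wage)))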

Natural cubic splines

  • Splines can have high variance at the outer range of the predictors.

  • A natural cubic spline extrapolates linearly beyond the boundary knots. This adds \(4 = 2 \times 2\) extra constraints, and allows us to put more internal knots for the same degrees of freedom as a regular cubic spline.

  • A natural spline with \(K\) knots has \(K\) degrees of freedom.

library(splines)

# Plot wage ~ age
Wage %>%
  ggplot(mapping = aes(x = age, y = wage)) + 
  geom_point(alpha = 0.25) + 
  # Cubic spline
  geom_smooth(
    method = "lm",
    formula = y ~ bs(x, knots = c(25, 40, 60)),
    color = "blue"
    ) +
  # Natural cubic spline
  geom_smooth(
    method = "lm",
    formula = y ~ ns(x, knots = c(25, 40, 60)),
    color = "red"
    ) +  
  labs(
    title = "Natural cubic spline fit (red) vs cubic spline fit (blue)",
    x = "Age",
    y = "Wage (k$)"
    )

Knot placement

  • One strategy is to decide \(K\), the number of knots, and then place them at appropriate quantiles of the observed \(X\).

  • In practice, users often specify the degrees of freedom and let the software choose the number of knots and their locations, as illustrated below.
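
For example, when only df is supplied, ns() and bs() place the interior knots at quantiles of the predictor. A quick check (a sketch):

library(splines)

# ns() with df = 4 (and no intercept column) uses 3 interior knots,
# placed at the 25th, 50th, and 75th percentiles of age
attr(ns(Wage$age, df = 4), "knots")
quantile(Wage$age, probs = c(0.25, 0.50, 0.75))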

Smoothing splines

The ggformula package supplies the geom_spline() function for displaying smoothing spline fits.

library(ggformula)
## Loading required package: ggstance
## 
## Attaching package: 'ggstance'
## The following objects are masked from 'package:ggplot2':
## 
##     geom_errorbarh, GeomErrorbarh
## Loading required package: scales
## 
## Attaching package: 'scales'
## The following object is masked from 'package:purrr':
## 
##     discard
## The following object is masked from 'package:readr':
## 
##     col_factor
## Loading required package: ggridges
## 
## New to ggformula?  Try the tutorials: 
##  learnr::run_tutorial("introduction", package = "ggformula")
##  learnr::run_tutorial("refining", package = "ggformula")
library(splines)

# Plot wage ~ age
Wage %>%
  ggplot(mapping = aes(x = age, y = wage)) + 
  geom_point(alpha = 0.25) + 
  # Smoothing spline with df = 16
  geom_spline(
      df = 16,
      color = "red"
    ) +
  # Smoothing spline with df tuned by LOOCV (cv = TRUE)
  geom_spline(
    # df = 6.8,
    cv = TRUE,
    color = "blue"
    ) +
  labs(
    title = "Smoothing spline with df=16 (red) vs LOOCV tuned df=6.8 (blue)",
    x = "Age",
    y = "Wage (k$)"
    )
## Warning in smooth.spline(data$x, data$y, w = weight, spar = spar, cv = cv, :
## cross-validation with non-unique 'x' values seems doubtful
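
The df = 6.8 quoted in the plot title can be reproduced by calling smooth.spline() directly; a minimal sketch (it raises the same non-unique 'x' warning):

# Smoothing spline with the smoothing parameter chosen by leave-one-out CV
ss_fit <- smooth.spline(Wage$age, Wage$wage, cv = TRUE)
ss_fit$df  # effective degrees of freedom, roughly 6.8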

Local regression
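
Local regression (loess) fits can be displayed with geom_smooth(method = "loess"); a minimal sketch with two illustrative span values:

# Plot wage ~ age with loess fits using spans 0.2 and 0.5
Wage %>%
  ggplot(mapping = aes(x = age, y = wage)) + 
  geom_point(alpha = 0.25) + 
  geom_smooth(method = "loess", span = 0.2, se = FALSE, color = "red") +
  geom_smooth(method = "loess", span = 0.5, se = FALSE, color = "blue") +
  labs(
    title = "Local regression with span 0.2 (red) vs 0.5 (blue)",
    x = "Age",
    y = "Wage (k$)"
    )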

Generalized additive model (GAM)

A GAM with natural splines for year and age (education enters as a qualitative predictor), fit by ordinary least squares with lm().

gam_mod <- lm(
  wage ~ ns(year, df = 4) + ns(age, df = 5) + education,
  data = Wage
  )
summary(gam_mod)
## 
## Call:
## lm(formula = wage ~ ns(year, df = 4) + ns(age, df = 5) + education, 
##     data = Wage)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -120.513  -19.608   -3.583   14.112  214.535 
## 
## Coefficients:
##                             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)                   46.949      4.704   9.980  < 2e-16 ***
## ns(year, df = 4)1              8.625      3.466   2.488  0.01289 *  
## ns(year, df = 4)2              3.762      2.959   1.271  0.20369    
## ns(year, df = 4)3              8.127      4.211   1.930  0.05375 .  
## ns(year, df = 4)4              6.806      2.397   2.840  0.00455 ** 
## ns(age, df = 5)1              45.170      4.193  10.771  < 2e-16 ***
## ns(age, df = 5)2              38.450      5.076   7.575 4.78e-14 ***
## ns(age, df = 5)3              34.239      4.383   7.813 7.69e-15 ***
## ns(age, df = 5)4              48.678     10.572   4.605 4.31e-06 ***
## ns(age, df = 5)5               6.557      8.367   0.784  0.43328    
## education2. HS Grad           10.983      2.430   4.520 6.43e-06 ***
## education3. Some College      23.473      2.562   9.163  < 2e-16 ***
## education4. College Grad      38.314      2.547  15.042  < 2e-16 ***
## education5. Advanced Degree   62.554      2.761  22.654  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 35.16 on 2986 degrees of freedom
## Multiple R-squared:  0.293,  Adjusted R-squared:  0.2899 
## F-statistic:  95.2 on 13 and 2986 DF,  p-value: < 2.2e-16
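
Although gam_mod was fit with lm(), its per-term contributions can still be displayed with plot.Gam() from the gam package (loaded in the next chunk); a sketch:

# Per-term plots of the lm()-based GAM, with pointwise standard error bands
gam::plot.Gam(gam_mod, se = TRUE, col = "red")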

A GAM with smoothing splines for year and age, fit with the gam package.

library(gam)
## Loading required package: foreach
## 
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
## 
##     accumulate, when
## Loaded gam 1.22-2
gam_mod <- gam(
  wage ~ s(year, 4) + s(age, 5) + education,
  data = Wage
  )
summary(gam_mod)
## 
## Call: gam(formula = wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
## Deviance Residuals:
##     Min      1Q  Median      3Q     Max 
## -119.43  -19.70   -3.33   14.17  213.48 
## 
## (Dispersion Parameter for gaussian family taken to be 1235.69)
## 
##     Null Deviance: 5222086 on 2999 degrees of freedom
## Residual Deviance: 3689770 on 2986 degrees of freedom
## AIC: 29887.75 
## 
## Number of Local Scoring Iterations: NA 
## 
## Anova for Parametric Effects
##              Df  Sum Sq Mean Sq F value    Pr(>F)    
## s(year, 4)    1   27162   27162  21.981 2.877e-06 ***
## s(age, 5)     1  195338  195338 158.081 < 2.2e-16 ***
## education     4 1069726  267432 216.423 < 2.2e-16 ***
## Residuals  2986 3689770    1236                      
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Anova for Nonparametric Effects
##             Npar Df Npar F  Pr(F)    
## (Intercept)                          
## s(year, 4)        3  1.086 0.3537    
## s(age, 5)         4 32.380 <2e-16 ***
## education                            
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
plot(gam_mod, se = TRUE, col = "red")
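
A sketch of model comparison in the style of the ISL lab: nested GAMs that exclude year, include year linearly, and include a smoothing spline in year (gam_mod above), compared with F tests.

gam_m1 <- gam(wage ~ s(age, 5) + education, data = Wage)
gam_m2 <- gam(wage ~ year + s(age, 5) + education, data = Wage)
anova(gam_m1, gam_m2, gam_mod, test = "F")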