Credit: This note draws heavily on material from the books An Introduction to Statistical Learning: with Applications in R (ISL2) and The Elements of Statistical Learning: Data Mining, Inference, and Prediction (ESL2).
Display system information for reproducibility.
sessionInfo()
## R version 4.2.3 (2023-03-15)
## Platform: aarch64-apple-darwin20 (64-bit)
## Running under: macOS Ventura 13.3.1
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## loaded via a namespace (and not attached):
## [1] digest_0.6.31 R6_2.5.1 jsonlite_1.8.4 evaluate_0.21
## [5] cachem_1.0.8 rlang_1.1.1 cli_3.6.1 rstudioapi_0.14
## [9] jquerylib_0.1.4 bslib_0.4.2 rmarkdown_2.21 tools_4.2.3
## [13] xfun_0.39 yaml_2.3.7 fastmap_1.1.1 compiler_4.2.3
## [17] htmltools_0.5.5 knitr_1.42 sass_0.4.6
The truth is never linear! Or almost never! But often the linearity assumption is good enough. When it's not, polynomials, splines, local regression, and generalized additive models (GAMs) offer a lot of flexibility, without losing the ease and interpretability of linear models. As a running example, consider `wage` vs `age` in the `Wage` data:
library(gtsummary)
library(ISLR2)
library(tidyverse)
# Convert to tibble
Wage <- as_tibble(Wage) %>% print(width = Inf)
## # A tibble: 3,000 × 11
## year age maritl race education region
## <int> <int> <fct> <fct> <fct> <fct>
## 1 2006 18 1. Never Married 1. White 1. < HS Grad 2. Middle Atlantic
## 2 2004 24 1. Never Married 1. White 4. College Grad 2. Middle Atlantic
## 3 2003 45 2. Married 1. White 3. Some College 2. Middle Atlantic
## 4 2003 43 2. Married 3. Asian 4. College Grad 2. Middle Atlantic
## 5 2005 50 4. Divorced 1. White 2. HS Grad 2. Middle Atlantic
## 6 2008 54 2. Married 1. White 4. College Grad 2. Middle Atlantic
## 7 2009 44 2. Married 4. Other 3. Some College 2. Middle Atlantic
## 8 2008 30 1. Never Married 3. Asian 3. Some College 2. Middle Atlantic
## 9 2006 41 1. Never Married 2. Black 3. Some College 2. Middle Atlantic
## 10 2004 52 2. Married 1. White 2. HS Grad 2. Middle Atlantic
## jobclass health health_ins logwage wage
## <fct> <fct> <fct> <dbl> <dbl>
## 1 1. Industrial 1. <=Good 2. No 4.32 75.0
## 2 2. Information 2. >=Very Good 2. No 4.26 70.5
## 3 1. Industrial 1. <=Good 1. Yes 4.88 131.
## 4 2. Information 2. >=Very Good 1. Yes 5.04 155.
## 5 2. Information 1. <=Good 1. Yes 4.32 75.0
## 6 2. Information 2. >=Very Good 1. Yes 4.85 127.
## 7 1. Industrial 2. >=Very Good 1. Yes 5.13 170.
## 8 2. Information 1. <=Good 1. Yes 4.72 112.
## 9 2. Information 2. >=Very Good 1. Yes 4.78 119.
## 10 2. Information 2. >=Very Good 1. Yes 4.86 129.
## # ℹ 2,990 more rows
# Summary statistics
Wage %>% tbl_summary()
| Characteristic | N = 3,000¹ |
|---|---|
| year | |
| 2003 | 513 (17%) |
| 2004 | 485 (16%) |
| 2005 | 447 (15%) |
| 2006 | 392 (13%) |
| 2007 | 386 (13%) |
| 2008 | 388 (13%) |
| 2009 | 389 (13%) |
| age | 42 (34, 51) |
| maritl | |
| 1. Never Married | 648 (22%) |
| 2. Married | 2,074 (69%) |
| 3. Widowed | 19 (0.6%) |
| 4. Divorced | 204 (6.8%) |
| 5. Separated | 55 (1.8%) |
| race | |
| 1. White | 2,480 (83%) |
| 2. Black | 293 (9.8%) |
| 3. Asian | 190 (6.3%) |
| 4. Other | 37 (1.2%) |
| education | |
| 1. < HS Grad | 268 (8.9%) |
| 2. HS Grad | 971 (32%) |
| 3. Some College | 650 (22%) |
| 4. College Grad | 685 (23%) |
| 5. Advanced Degree | 426 (14%) |
| region | |
| 1. New England | 0 (0%) |
| 2. Middle Atlantic | 3,000 (100%) |
| 3. East North Central | 0 (0%) |
| 4. West North Central | 0 (0%) |
| 5. South Atlantic | 0 (0%) |
| 6. East South Central | 0 (0%) |
| 7. West South Central | 0 (0%) |
| 8. Mountain | 0 (0%) |
| 9. Pacific | 0 (0%) |
| jobclass | |
| 1. Industrial | 1,544 (51%) |
| 2. Information | 1,456 (49%) |
| health | |
| 1. <=Good | 858 (29%) |
| 2. >=Very Good | 2,142 (71%) |
| health_ins | |
| 1. Yes | 2,083 (69%) |
| 2. No | 917 (31%) |
| logwage | 4.65 (4.45, 4.86) |
| wage | 105 (85, 129) |

¹ n (%); Median (IQR)
# Plot wage ~ age; geom_smooth() uses a GAM fit for 1000+ observations
Wage %>%
ggplot(mapping = aes(x = age, y = wage)) +
geom_point() +
geom_smooth() +
labs(title = "Wage changes nonlinearly with age",
x = "Age",
y = "Wage (k$)")
Polynomial regression replaces the linear model with a degree-\(d\) polynomial in \(X\): \[ y_i = \beta_0 + \beta_1 x_i + \beta_2 x_i^2 + \cdots + \beta_d x_i^d + \epsilon_i. \]
# Plot wage ~ age, display order-4 polynomial fit
Wage %>%
ggplot(mapping = aes(x = age, y = wage)) +
geom_point() +
geom_smooth(
method = "lm",
formula = y ~ poly(x, degree = 4)
) +
labs(
title = "Degree-4 Polynomial",
x = "Age",
y = "Wage (k$)"
)
Create new variables \(X_1 = X\), \(X_2 = X^2\), …, and then treat as multiple linear regression.
Not really interested in the coefficients; more interested in the fitted function values at any value \(x_0\): \[ \hat f(x_0) = \hat{\beta}_0 + \hat{\beta}_1 x_0 + \hat{\beta}_2 x_0^2 + \hat{\beta}_3 x_0^3 + \hat{\beta}_4 x_0^4. \]
# poly(age, 4) constructs orthogonal polynomials of degree 1 to 4, all orthogonal to the constant
lmod <- lm(wage ~ poly(age, degree = 4), data = Wage)
summary(lmod)
##
## Call:
## lm(formula = wage ~ poly(age, degree = 4), data = Wage)
##
## Residuals:
## Min 1Q Median 3Q Max
## -98.707 -24.626 -4.993 15.217 203.693
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 111.7036 0.7287 153.283 < 2e-16 ***
## poly(age, degree = 4)1 447.0679 39.9148 11.201 < 2e-16 ***
## poly(age, degree = 4)2 -478.3158 39.9148 -11.983 < 2e-16 ***
## poly(age, degree = 4)3 125.5217 39.9148 3.145 0.00168 **
## poly(age, degree = 4)4 -77.9112 39.9148 -1.952 0.05104 .
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 39.91 on 2995 degrees of freedom
## Multiple R-squared: 0.08626, Adjusted R-squared: 0.08504
## F-statistic: 70.69 on 4 and 2995 DF, p-value: < 2.2e-16
# poly(age, 4, raw = TRUE) produces raw (non-orthogonal) polynomials, which match Python's output
lmod <- lm(wage ~ poly(age, degree = 4, raw = TRUE), data = Wage)
summary(lmod)
##
## Call:
## lm(formula = wage ~ poly(age, degree = 4, raw = TRUE), data = Wage)
##
## Residuals:
## Min 1Q Median 3Q Max
## -98.707 -24.626 -4.993 15.217 203.693
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -1.842e+02 6.004e+01 -3.067 0.002180 **
## poly(age, degree = 4, raw = TRUE)1 2.125e+01 5.887e+00 3.609 0.000312 ***
## poly(age, degree = 4, raw = TRUE)2 -5.639e-01 2.061e-01 -2.736 0.006261 **
## poly(age, degree = 4, raw = TRUE)3 6.811e-03 3.066e-03 2.221 0.026398 *
## poly(age, degree = 4, raw = TRUE)4 -3.204e-05 1.641e-05 -1.952 0.051039 .
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 39.91 on 2995 degrees of freedom
## Multiple R-squared: 0.08626, Adjusted R-squared: 0.08504
## F-statistic: 70.69 on 4 and 2995 DF, p-value: < 2.2e-16
Since \(\hat f(x_0)\) is a linear function of the \(\hat{\beta}_j\), we can get a simple expression for pointwise-variances \(\operatorname{Var}[\hat f(x_0)]\) at any value \(x_0\).
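For instance, `predict()` with `se.fit = TRUE` returns these pointwise standard errors (a quick sketch using the `lmod` fit from above):
# Fitted values and pointwise ~95% bands at a few ages
pred <- predict(lmod, newdata = data.frame(age = c(25, 45, 65)), se.fit = TRUE)
cbind(
  fit = pred$fit,
  lower = pred$fit - 2 * pred$se.fit,
  upper = pred$fit + 2 * pred$se.fit
)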
We either fix the degree \(d\) at some reasonably low value, or use cross-validation to choose \(d\).
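A minimal sketch of choosing \(d\) by 10-fold cross-validation with `cv.glm()` from the `boot` package:
library(boot)
set.seed(2023)
# 10-fold CV error for polynomial degrees 1 through 5
cv_error <- sapply(1:5, function(d) {
  fit <- glm(wage ~ poly(age, degree = d), data = Wage)
  cv.glm(Wage, fit, K = 10)$delta[1]
})
which.min(cv_error)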
Can do separately on several variables. Just stack the variables into one matrix, and separate out the pieces afterwards (see GAMs later).
Polynomial modeling can be done for generalized linear models (logistic regression, Poisson regression, etc.) as well.
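For example, a degree-4 polynomial logistic regression for the event of being a high earner (`wage > 250`), as in the ISL lab:
# Degree-4 polynomial logistic regression for P(wage > 250 | age)
logit_mod <- glm(
  I(wage > 250) ~ poly(age, degree = 4),
  family = binomial,
  data = Wage
)
summary(logit_mod)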
Caveat: polynomials have notorious tail behavior. Very bad for extrapolation.
library(splines)
# Plot wage ~ age
Wage %>%
ggplot(mapping = aes(x = age, y = wage)) +
geom_point(alpha = 0.25) +
# Polynomial regression with degree 14
geom_smooth(
method = "lm",
formula = y ~ poly(x, degree = 14),
color = "blue"
) +
# Natural cubic spline
geom_smooth(
method = "lm",
formula = y ~ ns(x, df = 14),
color = "red"
) +
labs(
title = "Natural cubic spline (red) vs polynomial regression (blue)",
subtitle = "Both have df=15",
x = "Age",
y = "Wage (k$)"
)
Instead of a single polynomial in \(X\) over its whole domain, we can rather use different polynomials in regions defined by knots. E.g., a piecewise cubic polynomial with a single knot at \(c\) takes the form \[ y_i = \begin{cases} \beta_{01} + \beta_{11} x_i + \beta_{21} x_i^2 + \beta_{31} x_i^3 + \epsilon_i & \text{if } x_i < c \\ \beta_{02} + \beta_{12} x_i + \beta_{22} x_i^2 + \beta_{32} x_i^3 + \epsilon_i & \text{if } x_i \ge c \end{cases}. \]
Better to add constraints to the polynomials, e.g., continuity.
Splines have the “maximum” amount of continuity.
A linear spline with knots at \(\xi_k\), \(k = 1,\ldots,K\), is a piecewise linear polynomial continuous at each knot.
We can represent this model as \[ y_i = \beta_0 + \beta_1 b_1(x_i) + \beta_2 b_2(x_i) + \cdots + \beta_{K+1} b_{K+1}(x_i) + \epsilon_i, \] where the \(b_k\) are basis functions: \[\begin{eqnarray*} b_1(x_i) &=& x_i \\ b_{k+1}(x_i) &=& (x_i - \xi_k)_+, \quad k=1,\ldots,K. \end{eqnarray*}\] Here \((\cdot)_+\) denotes the positive part: \[ (x_i - \xi_k)_+ = \begin{cases} x_i - \xi_k & \text{if } x_i > \xi_k \\ 0 & \text{otherwise} \end{cases}. \]
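To make the basis concrete, here is a minimal hand-built linear spline fit; the knot locations 35 and 50 are arbitrary choices for this sketch:
# Hand-built linear spline basis: b_1(x) = x, b_{k+1}(x) = (x - xi_k)_+
knots <- c(35, 50)  # illustrative knot locations
basis <- cbind(
  Wage$age,
  sapply(knots, function(xi) pmax(Wage$age - xi, 0))
)
linear_spline_fit <- lm(Wage$wage ~ basis)
coef(linear_spline_fit)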
A cubic spline with knots at \(\xi_k\), \(k = 1,\ldots,K\), is a piecewise cubic polynomial with continuous derivatives up to order 2 at each knot.
Again we can represent this model with truncated power basis functions \[ y_i = \beta_0 + \beta_1 b_1(x_i) + \beta_2 b_2(x_i) + \cdots + \beta_{K+3} b_{K+3}(x_i) + \epsilon_i, \] with \[\begin{eqnarray*} b_1(x_i) &=& x_i \\ b_2(x_i) &=& x_i^2 \\ b_3(x_i) &=& x_i^3 \\ b_{k+3}(x_i) &=& (x_i - \xi_k)_+^3, \quad k = 1,\ldots,K, \end{eqnarray*}\] where \[ (x_i - \xi_k)_+^3 = \begin{cases} (x_i - \xi_k)^3 & \text{if } x_i > \xi_k \\ 0 & \text{otherwise} \end{cases}. \]
A cubic spline with \(K\) knots costs \(K+4\) parameters or degrees of freedom. That is \(4(K+1)\) polynomial coefficients minus \(3K\) constraints.
While the truncated power basis is conceptually simple, it is not too attractive numerically: powers of large numbers can lead to severe rounding problems. In practice, B-spline basis functions are preferred for their computational efficiency. See ESL Chapter 5 Appendix.
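In R, `splines::bs()` generates a B-spline basis. For a cubic spline with the 3 interior knots used below, the basis has \(K + 3 = 6\) columns; with the intercept supplied by the model formula this gives the \(K + 4 = 7\) degrees of freedom noted above:
library(splines)
# B-spline basis for a cubic spline with interior knots at 25, 40, 60
dim(bs(Wage$age, knots = c(25, 40, 60)))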
Splines can have high variance at the outer range of the predictors.
A natural cubic spline extrapolates linearly beyond the boundary knots. This adds \(4 = 2 \times 2\) extra constraints, and allows us to put more internal knots for the same degrees of freedom as a regular cubic spline.
A natural spline with \(K\) knots has \(K\) degrees of freedom.
library(splines)
# Plot wage ~ age
Wage %>%
ggplot(mapping = aes(x = age, y = wage)) +
geom_point(alpha = 0.25) +
# Cubic spline
geom_smooth(
method = "lm",
formula = y ~ bs(x, knots = c(25, 40, 60)),
color = "blue"
) +
# Natural cubic spline
geom_smooth(
method = "lm",
formula = y ~ ns(x, knots = c(25, 40, 60)),
color = "red"
) +
labs(
title = "Natural cubic spline fit (red) vs cubic spline fit (blue)",
x = "Age",
y = "Wage (k$)"
)
One strategy is to decide \(K\), the number of knots, and then place them at appropriate quantiles of the observed \(X\).
In practice, users often specify the degrees of freedom and let the software choose the number of knots and their locations.
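For example, asking `ns()` for a given `df` places the interior knots at quantiles of `age`:
library(splines)
# ns() with df = 4 uses 3 interior knots placed at quantiles of age
attr(ns(Wage$age, df = 4), "knots")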
Consider this criterion for fitting a smooth function \(g(x)\) to some data: \[ \text{minimize} \quad \sum_{i=1}^n (y_i - g(x_i))^2 + \lambda \int g''(t)^2 \, dt. \]
The solution is a (shrunken) natural cubic spline, with a knot at every unique value of \(x_i\). The penalty term controls the roughness of the fit via \(\lambda\).
Smoothing splines avoid the knot-selection issue, leaving a single \(\lambda\) to be chosen.
The vector of \(n\) fitted values can be written as \(\hat{g}_\lambda = S_\lambda y\), where \(S_{\lambda}\) is an \(n \times n\) matrix (determined by the \(x_i\) and \(\lambda\)).
The effective degrees of freedom are given by \[ \text{df}_{\lambda} = \sum_{i=1}^n S_{\lambda,ii}. \] Thus we can specify \(\text{df}\) rather than \(\lambda\).
The leave-one-out (LOO) cross-validated error is given by \[ \text{RSS}_{\text{CV}}(\lambda) = \sum_{i=1}^n \left[ \frac{y_i - \hat{g}_\lambda(x_i)}{1 - S_{\lambda,ii}} \right]^2. \]
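In R, `smooth.spline()` implements this directly; with `cv = TRUE` it selects \(\lambda\) by LOO cross-validation (a minimal sketch):
# Smoothing spline with lambda chosen by LOO CV; age has ties, so R warns
# that LOO CV is doubtful (the same warning appears below)
ss_fit <- smooth.spline(Wage$age, Wage$wage, cv = TRUE)
ss_fit$df  # effective degrees of freedom at the selected lambda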
The `ggformula` package supplies the `geom_spline()` function for displaying smoothing spline fits.
library(ggformula)
## Loading required package: ggstance
##
## Attaching package: 'ggstance'
## The following objects are masked from 'package:ggplot2':
##
## geom_errorbarh, GeomErrorbarh
## Loading required package: scales
##
## Attaching package: 'scales'
## The following object is masked from 'package:purrr':
##
## discard
## The following object is masked from 'package:readr':
##
## col_factor
## Loading required package: ggridges
##
## New to ggformula? Try the tutorials:
## learnr::run_tutorial("introduction", package = "ggformula")
## learnr::run_tutorial("refining", package = "ggformula")
library(splines)
# Plot wage ~ age
Wage %>%
ggplot(mapping = aes(x = age, y = wage)) +
geom_point(alpha = 0.25) +
# Smoothing spline with df = 16
geom_spline(
df = 16,
color = "red"
) +
# Smoothing spline with GCV tuned df
geom_spline(
# df = 6.8,
cv = TRUE,
color = "blue"
) +
labs(
title = "Smoothing spline with df=16 (red) vs LOOCV tuned df=6.8 (blue)",
x = "Age",
y = "Wage (k$)"
)
## Warning in smooth.spline(data$x, data$y, w = weight, spar = spar, cv = cv, :
## cross-validation with non-unique 'x' values seems doubtful
With a sliding weight function, we fit separate linear fits over the range of \(X\) by weighted least squares.
At \(X=x_0\), \[ \text{minimize} \quad \sum_{i=1}^n K(x_i, x_0) (y_i - \beta_0 - \beta_1 x_i)^2, \] where \(K\) is a weighting function that assigns heavier weight for \(x_i\) close to \(x_0\) and zero weight for points furthest from \(x_0\).
Locally weighted linear regression: the `loess()` function in R and `lowess` in Python.
Anecdotally, loess gives better appearance, but is \(O(N^2)\) in memory, so does not work for larger data sets.
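A minimal `loess()` sketch; the span of 0.5 and local degree 1 are illustrative choices:
# Local linear regression; span controls the fraction of points in each neighborhood
loess_fit <- loess(wage ~ age, data = Wage, span = 0.5, degree = 1)
predict(loess_fit, newdata = data.frame(age = c(30, 50, 70)))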
Generalized additive models (GAMs) allow for flexible nonlinearities in several variables, while retaining the additive structure of linear models. \[ y_i = \beta_0 + f_1(x_{i1}) + f_2(x_{i2}) + \cdots + f_p (x_{ip}) + \epsilon_i. \]
We can fit a GAM simply using, e.g., natural splines.
Coefficients not that interesting; fitted functions are.
Can mix terms: some linear, some nonlinear, and use ANOVA to compare models (see the sketch at the end of this section).
Can use smoothing splines or local regression as well. In R: `gam(wage ~ s(year, df = 5) + lo(age, span = 0.5) + education)`.
GAMs are additive, although low-order interactions can be included in a natural way using, e.g., bivariate smoothers or interactions of the form (in R) `ns(age, df = 5):ns(year, df = 5)`.
Natural splines for `year` and `age`.
gam_mod <- lm(
wage ~ ns(year, df = 4) + ns(age, df = 5) + education,
data = Wage
)
summary(gam_mod)
##
## Call:
## lm(formula = wage ~ ns(year, df = 4) + ns(age, df = 5) + education,
## data = Wage)
##
## Residuals:
## Min 1Q Median 3Q Max
## -120.513 -19.608 -3.583 14.112 214.535
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 46.949 4.704 9.980 < 2e-16 ***
## ns(year, df = 4)1 8.625 3.466 2.488 0.01289 *
## ns(year, df = 4)2 3.762 2.959 1.271 0.20369
## ns(year, df = 4)3 8.127 4.211 1.930 0.05375 .
## ns(year, df = 4)4 6.806 2.397 2.840 0.00455 **
## ns(age, df = 5)1 45.170 4.193 10.771 < 2e-16 ***
## ns(age, df = 5)2 38.450 5.076 7.575 4.78e-14 ***
## ns(age, df = 5)3 34.239 4.383 7.813 7.69e-15 ***
## ns(age, df = 5)4 48.678 10.572 4.605 4.31e-06 ***
## ns(age, df = 5)5 6.557 8.367 0.784 0.43328
## education2. HS Grad 10.983 2.430 4.520 6.43e-06 ***
## education3. Some College 23.473 2.562 9.163 < 2e-16 ***
## education4. College Grad 38.314 2.547 15.042 < 2e-16 ***
## education5. Advanced Degree 62.554 2.761 22.654 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 35.16 on 2986 degrees of freedom
## Multiple R-squared: 0.293, Adjusted R-squared: 0.2899
## F-statistic: 95.2 on 13 and 2986 DF, p-value: < 2.2e-16
Smoothing splines for `year` and `age`.
library(gam)
## Loading required package: foreach
##
## Attaching package: 'foreach'
## The following objects are masked from 'package:purrr':
##
## accumulate, when
## Loaded gam 1.22-2
gam_mod <- gam(
wage ~ s(year, 4) + s(age, 5) + education,
data = Wage
)
summary(gam_mod)
##
## Call: gam(formula = wage ~ s(year, 4) + s(age, 5) + education, data = Wage)
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -119.43 -19.70 -3.33 14.17 213.48
##
## (Dispersion Parameter for gaussian family taken to be 1235.69)
##
## Null Deviance: 5222086 on 2999 degrees of freedom
## Residual Deviance: 3689770 on 2986 degrees of freedom
## AIC: 29887.75
##
## Number of Local Scoring Iterations: NA
##
## Anova for Parametric Effects
## Df Sum Sq Mean Sq F value Pr(>F)
## s(year, 4) 1 27162 27162 21.981 2.877e-06 ***
## s(age, 5) 1 195338 195338 158.081 < 2.2e-16 ***
## education 4 1069726 267432 216.423 < 2.2e-16 ***
## Residuals 2986 3689770 1236
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Anova for Nonparametric Effects
## Npar Df Npar F Pr(F)
## (Intercept)
## s(year, 4) 3 1.086 0.3537
## s(age, 5) 4 32.380 <2e-16 ***
## education
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
plot(gam_mod, se = TRUE, col = "red")
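Finally, the ANOVA comparison mentioned earlier, sketched on nested GAMs that exclude `year`, enter it linearly, or smooth it (`gam_m1` and `gam_m2` are names introduced here for illustration):
# Compare nested GAMs: no year, linear year, smooth year
gam_m1 <- gam(wage ~ s(age, 5) + education, data = Wage)
gam_m2 <- gam(wage ~ year + s(age, 5) + education, data = Wage)
anova(gam_m1, gam_m2, gam_mod, test = "F")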