In the previous chapters (Chapter 6, Chapter 7, and Chapter 8), we covered some classification methods that use mathematical formalism to address everyday prediction problems. In this chapter, we will focus on specific model-based statistical methods providing forecasting and classification functionality. Specifically, we will (1) demonstrate the predictive power of multiple linear regression, (2) show the foundations of regression trees and model trees, and (3) examine two complementary case-studies (Baseball Players and Heart Attack).
It may be helpful to first review Chapter 4 (Linear Algebra/Matrix Manipulations) and Chapter 6 (Introduction to Machine Learning).
Regression represents a model of the relationship between a dependent variable (the value to be predicted) and a group of independent variables (predictors or features), see Chapter 6. We assume that the relationship between the outcome (dependent) variable and the independent variables is linear.
First review the material in Chapter 4: Linear Algebra & Matrix Computing.
The most straightforward case of regression is simple linear regression, which involves a single predictor. \[y=a+bx.\]
This formula should look familiar, as we showed examples in previous chapters. In this slope-intercept formula, \(a\) is the model intercept and \(b\) is the model slope. Thus, simple linear regression may be expressed as a bivariate equation. If we know \(a\) and \(b\), then for any given \(x\) we can estimate, or predict, \(y\) via the regression formula. If we plot \(x\) against \(y\) in a 2D coordinate system, where the two variables are exactly linearly related, the result will be a straight line.
However, this is the ideal case. Bivariate scatterplots using real world data may show patterns that are not necessarily precisely linear, see Chapter 2. Let's look at a bivariate scatterplot and try to fit a simple linear regression line using two variables, e.g., hospital charges (CHARGES) as a dependent variable and length of stay in the hospital (LOS) as an independent predictor. The data is available in the DSPA Data folder as CaseStudy12_AdultsHeartAttack_Data. We can remove the pair of observations with missing values using the command heart_attack<-heart_attack[complete.cases(heart_attack), ].
library(plotly)
heart_attack <- read.csv("https://umich.instructure.com/files/1644953/download?download_frd=1", stringsAsFactors = F)
heart_attack$CHARGES <- as.numeric(heart_attack$CHARGES)
heart_attack <- heart_attack[complete.cases(heart_attack), ]

fit1 <- lm(CHARGES ~ LOS, data=heart_attack)
# par(cex=.8)
# plot(heart_attack$LOS, heart_attack$CHARGES, xlab="LOS", ylab = "CHARGES")
# abline(fit1, lwd=2, col="red")
plot_ly(heart_attack, x = ~LOS, y = ~CHARGES, type = 'scatter', mode = "markers", name="Data") %>%
add_trace(x=~mean(LOS), y=~mean(CHARGES), type="scatter", mode="markers",
name="(mean(LOS), mean(Charge))", marker=list(size=20, color='blue', line=list(color='yellow', width=2))) %>%
add_lines(x = ~LOS, y = fit1$fitted.values, mode = "lines", name="Linear Model") %>%
layout(title=paste0("lm(CHARGES ~ LOS), Cor(LOS,CHARGES) = ",
round(cor(heart_attack$LOS, heart_attack$CHARGES),3)))
As expected, longer hospital stays are associated with higher medical costs, or hospital charges. The scatterplot shows dots for each pair of observed measurements (\(x=LOS\) and \(y=CHARGES\)), and an increasing linear trend.
The estimated expression for this regression line is: \[\hat{y}=4582.70+212.29\times x,\] or equivalently \[CHARGES=4582.70+212.29\times LOS.\] Once the linear model is fit, i.e., its coefficients are estimated, we can make predictions using this explicit regression model. Assume we have a patient who spent 10 days in the hospital, i.e., LOS=10. The predicted charge is \(\$ 4582.70 + \$ 212.29 \times 10= \$ 6705.6\). Plugging \(x\) into the regression equation automatically gives us an estimated value of the outcome \(y\). This chapter of the Probability and Statistics EBook provides an introduction to linear modeling.
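The same prediction can also be obtained programmatically with the generic predict() method applied to the fitted fit1 object; here is a minimal sketch (the newdata column name must match the predictor name, LOS):
# predict hospital charges for a hypothetical patient with a 10-day stay
predict(fit1, newdata=data.frame(LOS=10))
# this should agree with the manual calculation 4582.70 + 212.29*10 = 6705.6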
How did we get the estimated expression? The most common estimation method in statistics is ordinary least squares (OLS). OLS estimators are obtained by minimizing the sum of squared errors, that is, the sum of squared vertical distances from each dot on the scatter plot to the regression line.
OLS minimizes the following expression: \[\sum_{i=1}^{n}(y_i-\hat{y}_i)^2=\sum_{i=1}^{n}\left (\underbrace{y_i}_{\text{observed outcome}}-\underbrace{(a+b\times x_i)}_{\text{predicted outcome}}\right )^2=\sum_{i=1}^{n}\underbrace{e_i^2}_{\text{squared residual}}.\] Some calculus-based calculations show that the value of \(b\) minimizing the squared error is: \[b=\frac{\sum(x_i-\bar x)(y_i-\bar y)}{\sum(x_i-\bar x)^2}.\] The corresponding constant term (y-intercept) \(a\) is then: \[a=\bar y-b\bar x.\]
These expressions would become apparent if you review the material in Chapter 2. Recall that the variance is obtained by averaging the sums of squared deviations (\(var(x)=\frac{1}{n}\sum^{n}_{i=1} (x_i-\mu)^2\)). When we use \(\bar{x}\) to estimate the mean of \(x\), we have the following formula for the sample variance: \(var(x)=\frac{1}{n-1}\sum^{n}_{i=1} (x_i-\bar{x})^2\). Note that this is \(\frac{1}{n-1}\) times the denominator of \(b\). Similar to the variance, the covariance of \(x\) and \(y\) measures the average product of the deviation of \(x\) and the deviation of \(y\): \[Cov(x, y)=\frac{1}{n}\sum^{n}_{i=1} (x_i-\mu_x)(y_i-\mu_y).\] If we utilize the sample averages (\(\bar{x}\), \(\bar{y}\)) as estimates of the corresponding population means, we have:
\[Cov(x, y)=\frac{1}{n-1}\sum^{n}_{i=1} (x_i-\bar{x})(y_i-\bar{y}).\]
This is \(\frac{1}{n-1}\) times the numerator of b. Thus, combining the above 2 expressions, we get an estimate of the slope coefficient (effect-size of LOS on Charge) expressed as: \[b=\frac{Cov(x, y)}{var(x)}.\] Let’s use the heart attack data to demonstrate these calculations.
b <- cov(heart_attack$LOS, heart_attack$CHARGES)/var(heart_attack$LOS)
b
## [1] 212.2869
a <- mean(heart_attack$CHARGES) - b*mean(heart_attack$LOS)
a
## [1] 4582.7
# compare to the lm() estimate:
fit1$coefficients[1]
## (Intercept)
##      4582.7
# we can do the same for the slope parameter (b == fit1$coefficients[2])
We can see that this is exactly the same as the previously computed intercept estimate obtained by lm().
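To complete the comparison suggested by the code comment above, here is a short sketch checking the slope estimate as well:
# compare the manually computed slope to the lm() slope coefficient
b
fit1$coefficients[2]
# both should be approximately 212.2869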
Regression modeling has five key assumptions: a linear relationship between the outcome and the predictors, multivariate normality of the residuals, no (or little) multicollinearity among the predictors, no auto-correlation (independent errors), and homoscedasticity (constant error variance).
Note: The SOCR Interactive Scatterplot Game (requires Java enabled browser) provides a dynamic interface demonstrating linear models, trends, correlations, slopes and residuals.
Based on the covariance we can calculate the correlation, which indicates how closely the relationship between two variables follows a straight line. \[\rho_{x, y}=Corr(x, y)=\frac{Cov(x, y)}{\sigma_x\sigma_y}=\frac{Cov(x, y)}{\sqrt{Var(x)Var(y)}}.\] In R, the correlation is given by cor(), while the square root of the variance, i.e., the standard deviation, is given by sd().
r <- cov(heart_attack$LOS, heart_attack$CHARGES)/(sd(heart_attack$LOS)*sd(heart_attack$CHARGES))
r
## [1] 0.2449743
cor(heart_attack$LOS, heart_attack$CHARGES)
## [1] 0.2449743
The two computations produce the same output. This correlation is a positive number that is relatively small, so we can say there is a weak positive linear association between these two variables. A negative number would indicate a negative linear association. We consider the association weak when \(0.1 \leq Cor < 0.3\), moderate when \(0.3 \leq Cor < 0.5\), and strong when \(0.5 \leq Cor \leq 1.0\). A correlation below \(0.1\) suggests little to no linear relation between the variables.
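For illustration, here is a small sketch (using a hypothetical helper function cor_strength() and the thresholds quoted above) that maps an absolute correlation to these descriptive categories:
# map |correlation| to the descriptive strength categories used above
cor_strength <- function(r) {
  cut(abs(r), breaks=c(0, 0.1, 0.3, 0.5, 1), right=FALSE, include.lowest=TRUE,
      labels=c("little/none", "weak", "moderate", "strong"))
}
cor_strength(cor(heart_attack$LOS, heart_attack$CHARGES))   # "weak" for r = 0.245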
In practice, we more often face situations with multiple predictors and one dependent variable, which may follow a multiple linear model. That is: \[y=\alpha+\beta_1x_1+\beta_2x_2+...+\beta_kx_k+\epsilon,\] or equivalently \[y=\beta_0+\beta_1x_1+\beta_2x_2+ ... +\beta_kx_k+\epsilon .\] We usually use the second notation in statistics. This equation represents the linear relationship between \(k\) predictors and a dependent variable. In total, we have \(k+1\) coefficients to estimate.
The matrix notation for the above equation is: \[Y=X\beta+\epsilon,\] where \[Y=\left(\begin{array}{c} y_1 \\ y_2\\ ...\\ y_n \end{array}\right)\]
\[X=\left(\begin{array}{ccccc} 1 & x_{11}&x_{21}&...&x_{k1} \\ 1 & x_{12}&x_{22}&...&x_{k2} \\ .&.&.&.&.\\ 1 & x_{1n}&x_{2n}&...&x_{kn} \end{array}\right) \] \[\beta=\left(\begin{array}{c} \beta_0 \\ \beta_1\\ ...\\ \beta_k \end{array}\right)\]
and
\[\epsilon=\left(\begin{array}{c} \epsilon_1 \\ \epsilon_2\\ ...\\ \epsilon_n \end{array}\right)\] is the error term.
Similar to simple linear regression, our goal is to minimize the sum of squared errors. Solving the matrix equation for \(\beta\), we get the OLS solution for the parameter vector: \[\hat{\beta}=(X^TX)^{-1}X^TY .\] The solution is presented in matrix form, where \((X^TX)^{-1}\) is the matrix inverse of \(X^TX\) and \(X^T\) is the transpose of the original design matrix \(X\).
Let’s make a function of our own using this matrix formula.
reg <- function(y, x){
  x <- as.matrix(x)
  x <- cbind(Intercept=1, x)
  solve(t(x)%*%x) %*% t(x) %*% y
}
We saw earlier that a clever use of matrix multiplication (%*%) and solve() can help with the explicit OLS solution.
Next, we will apply our function reg() to the heart attack data. To begin with, let's check whether the simple linear regression (lm()) output coincides with the reg() result.
reg(y=heart_attack$CHARGES, x=heart_attack$LOS)
## [,1]
## Intercept 4582.6997
## 212.2869
fit1
##
## Call:
## lm(formula = CHARGES ~ LOS, data = heart_attack)
##
## Coefficients:
## (Intercept) LOS
## 4582.7 212.3
The results agree, and we can now include additional variables as predictors. As an example, let's add AGE to the model.
str(heart_attack)
## 'data.frame': 148 obs. of 8 variables:
## $ Patient : int 1 2 3 4 5 6 7 8 9 10 ...
## $ DIAGNOSIS: int 41041 41041 41091 41081 41091 41091 41091 41091 41041 41041 ...
## $ SEX : chr "F" "F" "F" "F" ...
## $ DRG : int 122 122 122 122 122 121 121 121 121 123 ...
## $ DIED : int 0 0 0 0 0 0 0 0 0 1 ...
## $ CHARGES : num 4752 3941 3657 1481 1681 ...
## $ LOS : int 10 6 5 2 1 9 15 15 2 1 ...
## $ AGE : int 79 34 76 80 55 84 84 70 76 65 ...
reg(y=heart_attack$CHARGES, x=heart_attack[, c(7, 8)])
## [,1]
## Intercept 7280.55493
## LOS 259.67361
## AGE -43.67677
# and compare the result to lm()
fit2 <- lm(CHARGES ~ LOS+AGE, data=heart_attack); fit2
##
## Call:
## lm(formula = CHARGES ~ LOS + AGE, data = heart_attack)
##
## Coefficients:
## (Intercept) LOS AGE
## 7280.55 259.67 -43.68
We utilize the mlb data “01a_data.txt”. The dataset contains 1034 records of heights and weights for some current and recent Major League Baseball (MLB) Players. These data were obtained from different resources (e.g., IBM Many Eyes).
Variables: Name, Team, Position, Height (in inches), Weight (in pounds), and Age (in years).
Let's load this dataset first. We use as.is=T to keep non-numerical columns as character vectors (rather than converting them to factors). Also, we delete the Name variable because we don't need the players' names in this case study.
mlb <- read.table('https://umich.instructure.com/files/330381/download?download_frd=1', as.is=T, header=T)
str(mlb)
## 'data.frame': 1034 obs. of 6 variables:
## $ Name : chr "Adam_Donachie" "Paul_Bako" "Ramon_Hernandez" "Kevin_Millar" ...
## $ Team : chr "BAL" "BAL" "BAL" "BAL" ...
## $ Position: chr "Catcher" "Catcher" "Catcher" "First_Baseman" ...
## $ Height : int 74 74 72 72 73 69 69 71 76 71 ...
## $ Weight : int 180 215 210 210 188 176 209 200 231 180 ...
## $ Age : num 23 34.7 30.8 35.4 35.7 ...
mlb <- mlb[, -1]
Looking at the str() output, we notice that the variables Team and Position are represented as characters. To fix this, we can use the function as.factor(), which converts numerical or character vectors to factors.
mlb$Team <- as.factor(mlb$Team)
mlb$Position <- as.factor(mlb$Position)
The data is good to go. Let’s explore it using some summary statistics and plots.
summary(mlb$Weight)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 150.0 187.0 200.0 201.7 215.0 290.0
hist(mlb$Weight, main = "Histogram for Weights")
plot_ly(x = mlb$Weight, type = "histogram", name= "Histogram for Weights") %>%
layout(title="Baseball Players' Weight Histogram", bargap=0.1,
xaxis=list(title="Weight"), # control the y:x axes aspect ratio
yaxis = list(title="Frequency"))
The above plot illustrates our dependent variable Weight. As we learned in Chapter 2, this distribution is somewhat right-skewed.
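As a quick sanity check of the skewness claim, using only the summary statistics reported above, note that the mean exceeding the median is consistent with a right-skewed distribution:
# a right-skewed distribution typically has mean > median
mean(mlb$Weight) > median(mlb$Weight)    # TRUE, since 201.7 > 200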
Applying ggpairs (from the GGally package) to obtain a compact dataset summary, we can mark heavy and light players (according to \(light \lt median \lt heavy\)) with different colors in the plot:
# require(GGally)
# mlb_binary = mlb
# mlb_binary$bi_weight = as.factor(ifelse(mlb_binary$Weight>median(mlb_binary$Weight),1,0))
# g_weight <- ggpairs(data=mlb_binary[-1], title="MLB Light/Heavy Weights",
# mapping=ggplot2::aes(colour = bi_weight),
# lower=list(combo=wrap("facethist",binwidth=1)),
# # upper = list(continuous = wrap("cor", size = 4.75, alignPercent = 1))
# )
# g_weight
plot_ly(mlb) %>%
add_trace(type = 'splom', dimensions = list( list(label='Position', values=~Position),
list(label='Height', values=~Height), list(label='Weight', values=~Weight),
list(label='Age', values=~Age), list(label='Team', values=~Team)),
text=~Team,
marker = list(color = as.integer(mlb$Team),
size = 7, line = list(width = 1, color = 'rgb(230,230,230)')
            )
  ) %>%
  layout(title= 'MLB Pairs Plot', hovermode='closest', dragmode= 'select',
plot_bgcolor='rgba(240,240,240, 0.95)')
# We may also mark player positions by different colors in the ggpairs plot
# g_position <- ggpairs(data=mlb[-1], title="MLB by Position",
# mapping=ggplot2::aes(colour = Position),
# lower=list(combo=wrap("facethist",binwidth=1)))
# g_position
What about our potential predictors?
table(mlb$Team)
##
## ANA ARZ ATL BAL BOS CHC CIN CLE COL CWS DET FLA HOU KC LA MIN MLW NYM NYY OAK
## 35 28 37 35 36 36 36 35 35 33 37 32 34 35 33 33 35 38 32 37
## PHI PIT SD SEA SF STL TB TEX TOR WAS
## 36 35 33 34 34 32 33 35 34 36
table(mlb$Position)
##
## Catcher Designated_Hitter First_Baseman Outfielder
## 76 18 55 194
## Relief_Pitcher Second_Baseman Shortstop Starting_Pitcher
## 315 58 52 221
## Third_Baseman
## 45
summary(mlb$Height)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 67.0 72.0 74.0 73.7 75.0 83.0
summary(mlb$Age)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 20.90 25.44 27.93 28.74 31.23 48.52
Here we have two numerical predictors, two categorical predictors, and \(1034\) observations. Let's see how R treats these different types of variables.
Before fitting a model, let's examine the independence of our potential predictors and the dependent variable. Multiple linear regression assumes that the predictors are independent of one another. Is this assumption valid? As we mentioned earlier, the cor() function can answer this question in a pairwise manner. Note that we only examine the numerical variables here.
cor(mlb[c("Weight", "Height", "Age")])
## Weight Height Age
## Weight 1.0000000 0.53031802 0.15784706
## Height 0.5303180 1.00000000 -0.07367013
## Age 0.1578471 -0.07367013 1.00000000
Here we can see that \(cor(y, x)=cor(x, y)\) and \(cor(x, x)=1\). Also, our Height variable is only weakly (negatively) related to the players' age. This looks good and shouldn't cause a multicollinearity problem. If two of our predictors are highly correlated, they both provide almost the same information, which could cause multicollinearity. A common practice is to remove one of them from the model.
In general multivariate regression analysis, we can use the variance inflation factors (VIFs) to detect potential multicollinearity between many covariates. The variance inflation factor (VIF) quantifies the amount of artificial inflation of the variance due to multicollinearity of the covariates. The \(VIF_l\)'s represent the expected inflation of the corresponding estimated variances. In a simple linear regression model with a single predictor \(x_l\), \(y_i=\beta_0 + \beta_1 x_{i,l}+\epsilon_i,\) relative to the baseline error variance, \(\sigma^2\), the lower bound (minimum) of the variance of the estimated effect-size, \(\beta_l\), is:
\[Var(\beta_l)_{min}=\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}.\]
This allows us to track the inflation of the \(\beta_l\) variance in the presence of correlated predictors in the regression model. Suppose the linear model includes \(k\) covariates with some of them multicollinear/correlated:
\[y_i=\beta_o+\beta_1x_{i,1} + \beta_2 x_{i,2} + \cdots + \underbrace{\beta_l x_{i,l}} + \cdots + \beta_k x_{i,k} +\epsilon_i.\]
Assume some of the predictors are correlated with the feature \({x_l}\); then the variance of its effect, \({\beta_l}\), will be inflated as follows:
\[Var(\beta_l)=\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}\times \frac{1}{1-R_l^2},\]
where \(R_l^2\) is the \(R^2\)-value computed by regressing the \(l^{th}\) feature on the remaining \((k-1)\) predictors. The stronger the linear dependence between the \(l^{th}\) feature and the remaining predictors, the larger the corresponding \(R_l^2\) value will be, and the smaller the denominator in the inflation factor (\(VIF_l\)) and the larger the variance estimate of \(\beta_l\).
The variance inflation factor (\(VIF_l\)) is the ratio of the two variances:
\[VIF_l=\frac{Var(\beta_l)}{Var(\beta_l)_{min}}= \frac{\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}\times \frac{1}{1-R_l^2}}{\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}}=\frac{1}{1-R_l^2}.\]
The regression model's VIFs measure how much the variance of the estimated regression coefficients, \(\beta_l\), may be "inflated" by the unavoidable presence of multicollinearity among the model predictor features. \(VIF_l\sim 1\) implies that there is no multicollinearity involving the \(l^{th}\) predictor and the remaining features, and hence the variance estimate of \(\beta_l\) is not inflated. On the other hand, \(VIF_l > 4\) suggests potential multicollinearity, and \(VIF_l> 10\) suggests serious multicollinearity for which some model correction may be necessary, as the variance estimates may be significantly biased.
We can use the function car::vif() to compute and report the VIF factors:
car::vif(lm(Weight ~ Height + Age, data=mlb))
## Height Age
## 1.005457 1.005457
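We can also verify the reported values by hand using the formula \(VIF_l=\frac{1}{1-R_l^2}\). This sketch regresses Height on the remaining predictor (here just Age) and inverts \(1-R^2\):
# manual VIF check for Height: regress it on the other predictor(s) and invert 1 - R^2
r2_height <- summary(lm(Height ~ Age, data=mlb))$r.squared
1/(1 - r2_height)     # should match the car::vif() output for Height (about 1.005)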
In Chapters 16 (Feature Selection) and 17 (Controlled Variable Selection), we will discuss methods and computational strategies to identify salient features in high-dimensional datasets. Let's briefly identify some practical approaches to address multicollinearity problems and the challenges related to large numbers of inter-dependencies in the data. Data that contain a large number of predictors are likely to include completely unrelated variables with high sample correlation. To see the nature of this problem, assume we generate a random Gaussian \(n\times k\) matrix, \(X=(X_1,X_2, \cdots, X_k)\), of \(k\) feature vectors, \(X_i, 1\leq i\leq k\), using IID standard normal random samples. Then, the expected maximum correlation between any pair of columns, \(\rho(X_{i_1},X_{i_2})\), can be large when \(k\gg n\). Even in this IID sampling setting, we still expect a high rate of intrinsic and strong feature correlations. In general, this phenomenon is amplified for high-dimensional observational data, which is expected to exhibit a high degree of collinearity. This problem presents a number of computational challenges (e.g., function singularities and negative definite Hessian matrices), as well as challenges for model-fitting, model-interpretation, and selection of salient predictors.
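The following small simulation sketch illustrates this phenomenon: even with IID standard normal columns, the maximum absolute pairwise correlation tends to grow as the number of features k increases relative to the sample size n (the specific values of n and k below are chosen purely for illustration):
# spurious correlations among IID Gaussian features when k >> n
set.seed(1234)
n <- 20
max_abs_cor <- function(k) {
  X <- matrix(rnorm(n*k), nrow=n, ncol=k)   # n x k design of IID N(0,1) features
  C <- cor(X)
  max(abs(C[upper.tri(C)]))                 # largest off-diagonal |correlation|
}
sapply(c(10, 100, 1000), max_abs_cor)       # typically increases with k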
There are some techniques that allow us to resolve such multicollinearity issues in high-dimensional data. Let's denote \(n\) to be the number of cases (samples, subjects, units, etc.) and \(k\) the number of features. Using divide-and-conquer, we can split the problem into two special cases, depending on whether \(n\gg k\) or \(k\gg n\).
To visualize pairwise correlations, we could use a scatterplot matrix, e.g., pairs(), ggpairs(), or plot_ly().
# pairs(mlb[c("Weight", "Height", "Age")])
plot_ly(mlb) %>%
add_trace(type = 'splom', dimensions = list( list(label='Height', values=~Height),
list(label='Weight', values=~Weight), list(label='Age', values=~Age)),
text=~Position,
marker = list(color = as.integer(mlb$Team),
size = 7, line = list(width = 1, color = 'rgb(230,230,230)')
            )
  ) %>%
  layout(title= 'MLB Pairs Plot', hovermode='closest', dragmode= 'select',
plot_bgcolor='rgba(240,240,240, 0.95)')
You might get a sense of the relationships, but it is difficult to see any clear linear pattern. We can make a more sophisticated graph using pairs.panels() in the psych package.
# install.packages("psych")
library(psych)
##
## Attaching package: 'psych'
## The following objects are masked from 'package:ggplot2':
##
## %+%, alpha
pairs.panels(mlb[, c("Weight", "Height", "Age")])
This plot gives us much more information about the three variables. Above the diagonal, we have the correlation coefficients in numerical form. On the diagonal, there are histograms of the variables. Below the diagonal, additional visual information is presented to help us understand the trends. This specific graph shows that height and weight are strongly and positively correlated, while the relationships between age and height, and between age and weight, are very weak (the nearly horizontal red lines in the lower-diagonal panels indicate weak relationships).
The function we are going to use for this section is lm(); it is part of base R, so no extra package is needed. The lm() function has the following basic form:
m <- lm(dv ~ iv, data=mydata)
Here dv is the dependent variable, iv specifies the independent variables (using formula notation similar to what we used with OneR() in Chapter 8), and mydata is the dataset. If we use . as iv, all of the variables, except the dependent variable (dv), are included as predictors.
, all of the variables, except the dependent variable (\(dv\)), are included as predictors.<- lm(Weight ~ ., data=mlb)
fit fit
##
## Call:
## lm(formula = Weight ~ ., data = mlb)
##
## Coefficients:
## (Intercept) TeamARZ
## -164.9995 7.1881
## TeamATL TeamBAL
## -1.5631 -5.3128
## TeamBOS TeamCHC
## -0.2838 0.4026
## TeamCIN TeamCLE
## 2.1051 -1.3160
## TeamCOL TeamCWS
## -3.7836 4.2944
## TeamDET TeamFLA
## 2.3024 2.6985
## TeamHOU TeamKC
## -0.6808 -4.7664
## TeamLA TeamMIN
## 2.8598 2.1269
## TeamMLW TeamNYM
## 4.2897 -1.9736
## TeamNYY TeamOAK
## 1.7483 -0.5464
## TeamPHI TeamPIT
## -6.8486 4.3023
## TeamSD TeamSEA
## 2.6133 -0.9147
## TeamSF TeamSTL
## 0.8411 -1.1341
## TeamTB TeamTEX
## -2.6616 -0.7695
## TeamTOR TeamWAS
## 1.3943 -1.7555
## PositionDesignated_Hitter PositionFirst_Baseman
## 8.9037 2.4237
## PositionOutfielder PositionRelief_Pitcher
## -6.2636 -7.7695
## PositionSecond_Baseman PositionShortstop
## -13.0843 -16.9562
## PositionStarting_Pitcher PositionThird_Baseman
## -7.3599 -4.6035
## Height Age
## 4.7175 0.8906
As we can see from the output, factors are included in the model through a set of indicator variables, with one coefficient for each factor level except the reference level. Each numerical variable has a single coefficient.
As we did in previous case studies, let’s examine model performance.
summary(fit)
##
## Call:
## lm(formula = Weight ~ ., data = mlb)
##
## Residuals:
## Min 1Q Median 3Q Max
## -48.692 -10.909 -0.778 9.858 73.649
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -164.9995 19.3828 -8.513 < 2e-16 ***
## TeamARZ 7.1881 4.2590 1.688 0.091777 .
## TeamATL -1.5631 3.9757 -0.393 0.694278
## TeamBAL -5.3128 4.0193 -1.322 0.186533
## TeamBOS -0.2838 4.0034 -0.071 0.943492
## TeamCHC 0.4026 3.9949 0.101 0.919749
## TeamCIN 2.1051 3.9934 0.527 0.598211
## TeamCLE -1.3160 4.0356 -0.326 0.744423
## TeamCOL -3.7836 4.0287 -0.939 0.347881
## TeamCWS 4.2944 4.1022 1.047 0.295413
## TeamDET 2.3024 3.9725 0.580 0.562326
## TeamFLA 2.6985 4.1336 0.653 0.514028
## TeamHOU -0.6808 4.0634 -0.168 0.866976
## TeamKC -4.7664 4.0242 -1.184 0.236525
## TeamLA 2.8598 4.0817 0.701 0.483686
## TeamMIN 2.1269 4.0947 0.519 0.603579
## TeamMLW 4.2897 4.0243 1.066 0.286706
## TeamNYM -1.9736 3.9493 -0.500 0.617370
## TeamNYY 1.7483 4.1234 0.424 0.671655
## TeamOAK -0.5464 3.9672 -0.138 0.890474
## TeamPHI -6.8486 3.9949 -1.714 0.086778 .
## TeamPIT 4.3023 4.0210 1.070 0.284890
## TeamSD 2.6133 4.0915 0.639 0.523148
## TeamSEA -0.9147 4.0516 -0.226 0.821436
## TeamSF 0.8411 4.0520 0.208 0.835593
## TeamSTL -1.1341 4.1193 -0.275 0.783132
## TeamTB -2.6616 4.0944 -0.650 0.515798
## TeamTEX -0.7695 4.0283 -0.191 0.848556
## TeamTOR 1.3943 4.0681 0.343 0.731871
## TeamWAS -1.7555 4.0038 -0.438 0.661142
## PositionDesignated_Hitter 8.9037 4.4533 1.999 0.045842 *
## PositionFirst_Baseman 2.4237 3.0058 0.806 0.420236
## PositionOutfielder -6.2636 2.2784 -2.749 0.006084 **
## PositionRelief_Pitcher -7.7695 2.1959 -3.538 0.000421 ***
## PositionSecond_Baseman -13.0843 2.9638 -4.415 1.12e-05 ***
## PositionShortstop -16.9562 3.0406 -5.577 3.16e-08 ***
## PositionStarting_Pitcher -7.3599 2.2976 -3.203 0.001402 **
## PositionThird_Baseman -4.6035 3.1689 -1.453 0.146613
## Height 4.7175 0.2563 18.405 < 2e-16 ***
## Age 0.8906 0.1259 7.075 2.82e-12 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 16.78 on 994 degrees of freedom
## Multiple R-squared: 0.3858, Adjusted R-squared: 0.3617
## F-statistic: 16.01 on 39 and 994 DF, p-value: < 2.2e-16
#plot(fit, which = 1:2)
plot_ly(x=fit$fitted.values, y=fit$residuals, type="scatter", mode="markers") %>%
layout(title="LM: Fitted-values vs. Model-Residuals",
xaxis=list(title="Fitted"),
yaxis = list(title="Residuals"))
The summary shows how well the model fits the dataset.
Residuals: This section describes the residuals. If some observations have extremely large or extremely small residuals compared to the rest, they are either outliers (e.g., due to reporting error) or cases the model fits poorly. Here the maximum is \(73.649\) and the minimum is \(-48.692\); their extremeness can be examined with the residual diagnostic plot.
Coefficients: In this section, we look at the far right column containing the significance codes. Stars or dots next to a variable indicate that the variable is significant and should be included in the model; if nothing appears next to a variable, its estimated coefficient could plausibly be 0 in the linear model. We can also look at the Pr(>|t|) column: a value close to 0 indicates that the corresponding variable is significant, otherwise it could be dropped from the model.
Here, none of the team indicators are significant at the 0.05 level, several of the position indicators are, and Age and Height are both highly significant.
R-squared: the proportion of the variation in \(y\) explained by the included predictors. Here we have 38.58%, which indicates the model is not bad but could be improved. Usually, a well-fitted linear regression would have an \(R^2\) over 70%.
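These summary quantities can also be extracted programmatically. A brief sketch pulling the \(R^2\), adjusted \(R^2\), and coefficient p-values out of the fitted model object:
# extract goodness-of-fit and significance measures from the fitted model
summary(fit)$r.squared                           # multiple R-squared (0.3858)
summary(fit)$adj.r.squared                       # adjusted R-squared (0.3617)
head(summary(fit)$coefficients[ , "Pr(>|t|)"])   # first few coefficient p-values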
The diagnostic plots also help us understand the model fit.
Residual vs Fitted: This is the residual diagnostic plot. We can see that the residuals of observations indexed \(65\), \(160\) and \(237\) are relatively far apart from the rest. They are potential influential points or outliers.
Normal Q-Q: This plot examines the normality assumption of the model residuals. If the dots follow the reference line, the normality assumption is reasonable. In our case, the points lie relatively close to the line, so we can say the model is valid in terms of normality.
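A minimal base-R sketch of this normal quantile-quantile diagnostic for the model residuals (an interactive plotly quantile plot appears later in this section):
# normal Q-Q plot of the lm() residuals
qqnorm(residuals(fit), main="Normal Q-Q Plot of Model Residuals")
qqline(residuals(fit), col="red", lwd=2)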
We can employ the step function to perform forward or backward selection of important features/predictors. It works for both lm and glm models. In most cases, backward selection is preferable because it starts from the full model and tends to retain larger models. There are various criteria to evaluate a model; commonly used criteria include AIC, BIC, adjusted \(R^2\), etc. Let's compare the backward and forward model selection approaches. The step function argument direction controls the search strategy (the default is both, which considers adding or dropping terms at each step). Later, in Chapter 16 and Chapter 17, we will present details about alternative feature selection approaches.
step(fit,direction = "backward")
## Start: AIC=5871.04
## Weight ~ Team + Position + Height + Age
##
## Df Sum of Sq RSS AIC
## - Team 29 9468 289262 5847.4
## <none> 279793 5871.0
## - Age 1 14090 293883 5919.8
## - Position 8 20301 300095 5927.5
## - Height 1 95356 375149 6172.3
##
## Step: AIC=5847.45
## Weight ~ Position + Height + Age
##
## Df Sum of Sq RSS AIC
## <none> 289262 5847.4
## - Age 1 14616 303877 5896.4
## - Position 8 20406 309668 5901.9
## - Height 1 100435 389697 6153.6
##
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
##
## Coefficients:
## (Intercept) PositionDesignated_Hitter
## -168.0474 8.6968
## PositionFirst_Baseman PositionOutfielder
## 2.7780 -6.0457
## PositionRelief_Pitcher PositionSecond_Baseman
## -7.7782 -13.0267
## PositionShortstop PositionStarting_Pitcher
## -16.4821 -7.3961
## PositionThird_Baseman Height
## -4.1361 4.7639
## Age
## 0.8771
step(fit,direction = "forward")
## Start: AIC=5871.04
## Weight ~ Team + Position + Height + Age
##
## Call:
## lm(formula = Weight ~ Team + Position + Height + Age, data = mlb)
##
## Coefficients:
## (Intercept) TeamARZ
## -164.9995 7.1881
## TeamATL TeamBAL
## -1.5631 -5.3128
## TeamBOS TeamCHC
## -0.2838 0.4026
## TeamCIN TeamCLE
## 2.1051 -1.3160
## TeamCOL TeamCWS
## -3.7836 4.2944
## TeamDET TeamFLA
## 2.3024 2.6985
## TeamHOU TeamKC
## -0.6808 -4.7664
## TeamLA TeamMIN
## 2.8598 2.1269
## TeamMLW TeamNYM
## 4.2897 -1.9736
## TeamNYY TeamOAK
## 1.7483 -0.5464
## TeamPHI TeamPIT
## -6.8486 4.3023
## TeamSD TeamSEA
## 2.6133 -0.9147
## TeamSF TeamSTL
## 0.8411 -1.1341
## TeamTB TeamTEX
## -2.6616 -0.7695
## TeamTOR TeamWAS
## 1.3943 -1.7555
## PositionDesignated_Hitter PositionFirst_Baseman
## 8.9037 2.4237
## PositionOutfielder PositionRelief_Pitcher
## -6.2636 -7.7695
## PositionSecond_Baseman PositionShortstop
## -13.0843 -16.9562
## PositionStarting_Pitcher PositionThird_Baseman
## -7.3599 -4.6035
## Height Age
## 4.7175 0.8906
step(fit,direction = "both")
## Start: AIC=5871.04
## Weight ~ Team + Position + Height + Age
##
## Df Sum of Sq RSS AIC
## - Team 29 9468 289262 5847.4
## <none> 279793 5871.0
## - Age 1 14090 293883 5919.8
## - Position 8 20301 300095 5927.5
## - Height 1 95356 375149 6172.3
##
## Step: AIC=5847.45
## Weight ~ Position + Height + Age
##
## Df Sum of Sq RSS AIC
## <none> 289262 5847.4
## + Team 29 9468 279793 5871.0
## - Age 1 14616 303877 5896.4
## - Position 8 20406 309668 5901.9
## - Height 1 100435 389697 6153.6
##
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
##
## Coefficients:
## (Intercept) PositionDesignated_Hitter
## -168.0474 8.6968
## PositionFirst_Baseman PositionOutfielder
## 2.7780 -6.0457
## PositionRelief_Pitcher PositionSecond_Baseman
## -7.7782 -13.0267
## PositionShortstop PositionStarting_Pitcher
## -16.4821 -7.3961
## PositionThird_Baseman Height
## -4.1361 4.7639
## Age
## 0.8771
We can observe that forward selection retains the whole model, whereas backward step-wise selection yields the better (more parsimonious) model.
Both backward and forward feature selection methods use greedy algorithms and do not guarantee an optimal model. Identifying the best feature subset would require exploring every possible combination of the predictors, which is practically infeasible due to the computational complexity of evaluating all \(n \choose k\) combinations of features.
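A quick back-of-the-envelope sketch of why exhaustive search is infeasible: with \(k\) candidate predictors there are \(2^k\) possible subsets (the sum of all \({k \choose j}\) terms), which grows explosively:
# number of possible predictor subsets for k candidate features
k <- c(3, 10, 20, 40)
data.frame(k=k, n_models=2^k)   # 2^k = sum of (k choose j) over j = 0,...,k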
Alternatively, we can choose models based on various information criteria.
step(fit,k=2)
## Start: AIC=5871.04
## Weight ~ Team + Position + Height + Age
##
## Df Sum of Sq RSS AIC
## - Team 29 9468 289262 5847.4
## <none> 279793 5871.0
## - Age 1 14090 293883 5919.8
## - Position 8 20301 300095 5927.5
## - Height 1 95356 375149 6172.3
##
## Step: AIC=5847.45
## Weight ~ Position + Height + Age
##
## Df Sum of Sq RSS AIC
## <none> 289262 5847.4
## - Age 1 14616 303877 5896.4
## - Position 8 20406 309668 5901.9
## - Height 1 100435 389697 6153.6
##
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
##
## Coefficients:
## (Intercept) PositionDesignated_Hitter
## -168.0474 8.6968
## PositionFirst_Baseman PositionOutfielder
## 2.7780 -6.0457
## PositionRelief_Pitcher PositionSecond_Baseman
## -7.7782 -13.0267
## PositionShortstop PositionStarting_Pitcher
## -16.4821 -7.3961
## PositionThird_Baseman Height
## -4.1361 4.7639
## Age
## 0.8771
step(fit,k=log(nrow(mlb)))
## Start: AIC=6068.69
## Weight ~ Team + Position + Height + Age
##
## Df Sum of Sq RSS AIC
## - Team 29 9468 289262 5901.8
## <none> 279793 6068.7
## - Position 8 20301 300095 6085.6
## - Age 1 14090 293883 6112.5
## - Height 1 95356 375149 6365.0
##
## Step: AIC=5901.8
## Weight ~ Position + Height + Age
##
## Df Sum of Sq RSS AIC
## <none> 289262 5901.8
## - Position 8 20406 309668 5916.8
## - Age 1 14616 303877 5945.8
## - Height 1 100435 389697 6203.0
##
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
##
## Coefficients:
## (Intercept) PositionDesignated_Hitter
## -168.0474 8.6968
## PositionFirst_Baseman PositionOutfielder
## 2.7780 -6.0457
## PositionRelief_Pitcher PositionSecond_Baseman
## -7.7782 -13.0267
## PositionShortstop PositionStarting_Pitcher
## -16.4821 -7.3961
## PositionThird_Baseman Height
## -4.1361 4.7639
## Age
## 0.8771
\(k = 2\) yields the genuine AIC criterion, and \(k = log(n)\) refers to BIC. Let’s try to evaluate model performance again.
fit2 = step(fit, k=2, direction = "backward")
## Start: AIC=5871.04
## Weight ~ Team + Position + Height + Age
##
## Df Sum of Sq RSS AIC
## - Team 29 9468 289262 5847.4
## <none> 279793 5871.0
## - Age 1 14090 293883 5919.8
## - Position 8 20301 300095 5927.5
## - Height 1 95356 375149 6172.3
##
## Step: AIC=5847.45
## Weight ~ Position + Height + Age
##
## Df Sum of Sq RSS AIC
## <none> 289262 5847.4
## - Age 1 14616 303877 5896.4
## - Position 8 20406 309668 5901.9
## - Height 1 100435 389697 6153.6
summary(fit2)
##
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
##
## Residuals:
## Min 1Q Median 3Q Max
## -49.427 -10.855 -0.344 10.110 75.301
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -168.0474 19.0351 -8.828 < 2e-16 ***
## PositionDesignated_Hitter 8.6968 4.4258 1.965 0.049679 *
## PositionFirst_Baseman 2.7780 2.9942 0.928 0.353741
## PositionOutfielder -6.0457 2.2778 -2.654 0.008072 **
## PositionRelief_Pitcher -7.7782 2.1913 -3.550 0.000403 ***
## PositionSecond_Baseman -13.0267 2.9531 -4.411 1.14e-05 ***
## PositionShortstop -16.4821 3.0372 -5.427 7.16e-08 ***
## PositionStarting_Pitcher -7.3961 2.2959 -3.221 0.001316 **
## PositionThird_Baseman -4.1361 3.1656 -1.307 0.191647
## Height 4.7639 0.2528 18.847 < 2e-16 ***
## Age 0.8771 0.1220 7.190 1.25e-12 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 16.82 on 1023 degrees of freedom
## Multiple R-squared: 0.365, Adjusted R-squared: 0.3588
## F-statistic: 58.81 on 10 and 1023 DF, p-value: < 2.2e-16
#plot(fit2, which = 1:2)
plot_ly(x=fit2$fitted.values, y=fit2$residuals, type="scatter", mode="markers") %>%
layout(title="LM: Fitted-values vs. Model-Residuals",
xaxis=list(title="Fitted"),
yaxis = list(title="Residuals"))
# compute the quantiles
QQ <- qqplot(fit2$fitted.values, fit2$residuals, plot.it=FALSE)
# take a smaller sample size to expedite the viz
ind <- sample(1:length(QQ$x), 1000, replace = FALSE)
plot_ly() %>%
add_markers(x=~QQ$x, y=~QQ$y, name="Quantiles Scatter", type="scatter", mode="markers") %>%
add_trace(x = ~c(160,260), y = ~c(-50,80), type="scatter", mode="lines",
line = list(color = "red", width = 4), name="Line", showlegend=F) %>%
layout(title='Quantile plot',
xaxis = list(title="Fitted"),
yaxis = list(title="Residuals"),
legend = list(orientation = 'h'))
Sometimes, simpler models are preferable even at the cost of a slight loss of performance. In this case, we have a simpler model with \(R^2=0.365\), and the overall model is still highly significant. We can see that observations \(65\), \(160\) and \(237\) are relatively far from the other residuals; they are potential influential points or outliers.
We can also examine the leverage points. Leverage points may be outliers, influential points, or both. In a regression setting, the leverage of an observation reflects the distance of that observation from the mean of the explanatory variable: observations near the mean have low leverage, while those far from the mean have high leverage. Note that not all high-leverage points are necessarily influential.
# Half-normal plot for leverages
# install.packages("faraway")
library(faraway)
halfnorm(lm.influence(fit)$hat, nlab = 2, ylab="Leverages")
mlb[c(226,879), ]
## Team Position Height Weight Age
## 226 NYY Designated_Hitter 75 230 36.14
## 879 SD Designated_Hitter 73 200 25.60
summary(mlb)
## Team Position Height Weight
## NYM : 38 Relief_Pitcher :315 Min. :67.0 Min. :150.0
## ATL : 37 Starting_Pitcher:221 1st Qu.:72.0 1st Qu.:187.0
## DET : 37 Outfielder :194 Median :74.0 Median :200.0
## OAK : 37 Catcher : 76 Mean :73.7 Mean :201.7
## BOS : 36 Second_Baseman : 58 3rd Qu.:75.0 3rd Qu.:215.0
## CHC : 36 First_Baseman : 55 Max. :83.0 Max. :290.0
## (Other):813 (Other) :115
## Age
## Min. :20.90
## 1st Qu.:25.44
## Median :27.93
## Mean :28.74
## 3rd Qu.:31.23
## Max. :48.52
##
A deeper discussion of variable selection, controlling the false discovery rate, is provided in Chapters 16 and 17.
In linear regression, the relationship between independent and dependent variables is assumed to be linear. However, this might not be the case. The relationship between age and weight could be quadratic, since middle-aged people might gain weight dramatically.
mlb$age2 <- (mlb$Age)^2
fit2 <- lm(Weight ~ ., data=mlb)
summary(fit2)
##
## Call:
## lm(formula = Weight ~ ., data = mlb)
##
## Residuals:
## Min 1Q Median 3Q Max
## -49.068 -10.775 -1.021 9.922 74.693
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -209.07068 27.49529 -7.604 6.65e-14 ***
## TeamARZ 7.41943 4.25154 1.745 0.081274 .
## TeamATL -1.43167 3.96793 -0.361 0.718318
## TeamBAL -5.38735 4.01119 -1.343 0.179552
## TeamBOS -0.06614 3.99633 -0.017 0.986799
## TeamCHC 0.14541 3.98833 0.036 0.970923
## TeamCIN 2.24022 3.98571 0.562 0.574201
## TeamCLE -1.07546 4.02870 -0.267 0.789563
## TeamCOL -3.87254 4.02069 -0.963 0.335705
## TeamCWS 4.20933 4.09393 1.028 0.304111
## TeamDET 2.66990 3.96769 0.673 0.501160
## TeamFLA 3.14627 4.12989 0.762 0.446343
## TeamHOU -0.77230 4.05526 -0.190 0.849000
## TeamKC -4.90984 4.01648 -1.222 0.221837
## TeamLA 3.13554 4.07514 0.769 0.441820
## TeamMIN 2.09951 4.08631 0.514 0.607512
## TeamMLW 4.16183 4.01646 1.036 0.300363
## TeamNYM -1.25057 3.95424 -0.316 0.751870
## TeamNYY 1.67825 4.11502 0.408 0.683482
## TeamOAK -0.68235 3.95951 -0.172 0.863212
## TeamPHI -6.85071 3.98672 -1.718 0.086039 .
## TeamPIT 4.12683 4.01348 1.028 0.304086
## TeamSD 2.59525 4.08310 0.636 0.525179
## TeamSEA -0.67316 4.04471 -0.166 0.867853
## TeamSF 1.06038 4.04481 0.262 0.793255
## TeamSTL -1.38669 4.11234 -0.337 0.736037
## TeamTB -2.44396 4.08716 -0.598 0.550003
## TeamTEX -0.68740 4.02023 -0.171 0.864270
## TeamTOR 1.24439 4.06029 0.306 0.759306
## TeamWAS -1.87599 3.99594 -0.469 0.638835
## PositionDesignated_Hitter 8.94440 4.44417 2.013 0.044425 *
## PositionFirst_Baseman 2.55100 3.00014 0.850 0.395368
## PositionOutfielder -6.25702 2.27372 -2.752 0.006033 **
## PositionRelief_Pitcher -7.68904 2.19166 -3.508 0.000471 ***
## PositionSecond_Baseman -13.01400 2.95787 -4.400 1.20e-05 ***
## PositionShortstop -16.82243 3.03494 -5.543 3.81e-08 ***
## PositionStarting_Pitcher -7.08215 2.29615 -3.084 0.002096 **
## PositionThird_Baseman -4.66452 3.16249 -1.475 0.140542
## Height 4.71888 0.25578 18.449 < 2e-16 ***
## Age 3.82295 1.30621 2.927 0.003503 **
## age2 -0.04791 0.02124 -2.255 0.024327 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 16.74 on 993 degrees of freedom
## Multiple R-squared: 0.3889, Adjusted R-squared: 0.3643
## F-statistic: 15.8 on 40 and 993 DF, p-value: < 2.2e-16
This actually brought the overall \(R^2\) up to \(0.3889\).
As discussed earlier, middle-aged people might have a different weight-gain pattern compared to younger people. The overall pattern could be not cumulative, but rather two separate lines for young and middle-aged people. We assume 30 is the threshold: people over 30 may have a steeper line for weight increase than those under 30. Here we use the ifelse() function, mentioned in Chapter 7, to create an indicator for the threshold.
mlb$age30 <- ifelse(mlb$Age>=30, 1, 0)
fit3 <- lm(Weight~Team+Position+Age+age30+Height, data=mlb)
summary(fit3)
##
## Call:
## lm(formula = Weight ~ Team + Position + Age + age30 + Height,
## data = mlb)
##
## Residuals:
## Min 1Q Median 3Q Max
## -48.313 -11.166 -0.916 10.044 73.630
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -159.8884 19.8862 -8.040 2.54e-15 ***
## TeamARZ 7.4096 4.2627 1.738 0.082483 .
## TeamATL -1.4379 3.9765 -0.362 0.717727
## TeamBAL -5.3393 4.0187 -1.329 0.184284
## TeamBOS -0.1985 4.0034 -0.050 0.960470
## TeamCHC 0.4669 3.9947 0.117 0.906976
## TeamCIN 2.2124 3.9939 0.554 0.579741
## TeamCLE -1.1624 4.0371 -0.288 0.773464
## TeamCOL -3.6842 4.0290 -0.914 0.360717
## TeamCWS 4.1920 4.1025 1.022 0.307113
## TeamDET 2.4708 3.9746 0.622 0.534314
## TeamFLA 2.8563 4.1352 0.691 0.489903
## TeamHOU -0.4964 4.0659 -0.122 0.902846
## TeamKC -4.7138 4.0238 -1.171 0.241692
## TeamLA 2.9194 4.0814 0.715 0.474586
## TeamMIN 2.2885 4.0965 0.559 0.576528
## TeamMLW 4.4749 4.0269 1.111 0.266731
## TeamNYM -1.8173 3.9510 -0.460 0.645659
## TeamNYY 1.7074 4.1229 0.414 0.678867
## TeamOAK -0.3388 3.9707 -0.085 0.932012
## TeamPHI -6.6192 3.9993 -1.655 0.098220 .
## TeamPIT 4.6716 4.0332 1.158 0.247029
## TeamSD 2.8600 4.0965 0.698 0.485243
## TeamSEA -1.0121 4.0518 -0.250 0.802809
## TeamSF 1.0244 4.0545 0.253 0.800587
## TeamSTL -1.1094 4.1187 -0.269 0.787703
## TeamTB -2.4485 4.0980 -0.597 0.550312
## TeamTEX -0.6112 4.0300 -0.152 0.879485
## TeamTOR 1.3959 4.0674 0.343 0.731532
## TeamWAS -1.4189 4.0139 -0.354 0.723784
## PositionDesignated_Hitter 9.2378 4.4621 2.070 0.038683 *
## PositionFirst_Baseman 2.6074 3.0096 0.866 0.386501
## PositionOutfielder -6.0408 2.2863 -2.642 0.008367 **
## PositionRelief_Pitcher -7.5100 2.2072 -3.403 0.000694 ***
## PositionSecond_Baseman -12.8870 2.9683 -4.342 1.56e-05 ***
## PositionShortstop -16.8912 3.0406 -5.555 3.56e-08 ***
## PositionStarting_Pitcher -7.0825 2.3099 -3.066 0.002227 **
## PositionThird_Baseman -4.4307 3.1719 -1.397 0.162773
## Age 0.6904 0.2153 3.207 0.001386 **
## age30 2.2636 1.9749 1.146 0.251992
## Height 4.7113 0.2563 18.380 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 16.77 on 993 degrees of freedom
## Multiple R-squared: 0.3866, Adjusted R-squared: 0.3619
## F-statistic: 15.65 on 40 and 993 DF, p-value: < 2.2e-16
This model performs worse than the quadratic model in terms of \(R^2\). Moreover, age30 is not significant. So, we will stick with the quadratic model.
So far, only each feature's individual effect has been considered in our model. It is possible that features act in pairs to affect the dependent variable. Let's examine that more deeply.
An interaction is a combined effect of two features. If we are not sure whether two features interact, we can test this by adding an interaction term to the model. If the interaction is significant, then we have confirmed an interaction between the two features.
fit4 <- lm(Weight~Team+Height+Age*Position+age2, data=mlb)
summary(fit4)
##
## Call:
## lm(formula = Weight ~ Team + Height + Age * Position + age2,
## data = mlb)
##
## Residuals:
## Min 1Q Median 3Q Max
## -48.761 -11.049 -0.761 9.911 75.533
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -199.15403 29.87269 -6.667 4.35e-11 ***
## TeamARZ 8.10376 4.26339 1.901 0.0576 .
## TeamATL -0.81743 3.97899 -0.205 0.8373
## TeamBAL -4.64820 4.03972 -1.151 0.2502
## TeamBOS 0.37698 4.00743 0.094 0.9251
## TeamCHC 0.33104 3.99507 0.083 0.9340
## TeamCIN 2.56023 3.99603 0.641 0.5219
## TeamCLE -0.66254 4.03154 -0.164 0.8695
## TeamCOL -3.72098 4.03759 -0.922 0.3570
## TeamCWS 4.63266 4.10884 1.127 0.2598
## TeamDET 3.21380 3.98231 0.807 0.4199
## TeamFLA 3.56432 4.14902 0.859 0.3905
## TeamHOU -0.38733 4.07249 -0.095 0.9242
## TeamKC -4.66678 4.02384 -1.160 0.2464
## TeamLA 3.51766 4.09400 0.859 0.3904
## TeamMIN 2.31585 4.10502 0.564 0.5728
## TeamMLW 4.34793 4.02501 1.080 0.2803
## TeamNYM -0.28505 3.98537 -0.072 0.9430
## TeamNYY 1.87847 4.12774 0.455 0.6491
## TeamOAK -0.23791 3.97729 -0.060 0.9523
## TeamPHI -6.25671 3.99545 -1.566 0.1177
## TeamPIT 4.18719 4.01944 1.042 0.2978
## TeamSD 2.97028 4.08838 0.727 0.4677
## TeamSEA -0.07220 4.05922 -0.018 0.9858
## TeamSF 1.35981 4.07771 0.333 0.7388
## TeamSTL -1.23460 4.11960 -0.300 0.7645
## TeamTB -1.90885 4.09592 -0.466 0.6413
## TeamTEX -0.31570 4.03146 -0.078 0.9376
## TeamTOR 1.73976 4.08565 0.426 0.6703
## TeamWAS -1.43933 4.00274 -0.360 0.7192
## Height 4.70632 0.25646 18.351 < 2e-16 ***
## Age 3.32733 1.37088 2.427 0.0154 *
## PositionDesignated_Hitter -44.82216 30.68202 -1.461 0.1444
## PositionFirst_Baseman 23.51389 20.23553 1.162 0.2455
## PositionOutfielder -13.33140 15.92500 -0.837 0.4027
## PositionRelief_Pitcher -16.51308 15.01240 -1.100 0.2716
## PositionSecond_Baseman -26.56932 20.18773 -1.316 0.1884
## PositionShortstop -27.89454 20.72123 -1.346 0.1786
## PositionStarting_Pitcher -2.44578 15.36376 -0.159 0.8736
## PositionThird_Baseman -10.20102 23.26121 -0.439 0.6611
## age2 -0.04201 0.02170 -1.936 0.0531 .
## Age:PositionDesignated_Hitter 1.77289 1.00506 1.764 0.0780 .
## Age:PositionFirst_Baseman -0.71111 0.67848 -1.048 0.2949
## Age:PositionOutfielder 0.24147 0.53650 0.450 0.6527
## Age:PositionRelief_Pitcher 0.30374 0.50564 0.601 0.5482
## Age:PositionSecond_Baseman 0.46281 0.68281 0.678 0.4981
## Age:PositionShortstop 0.38257 0.70998 0.539 0.5901
## Age:PositionStarting_Pitcher -0.17104 0.51976 -0.329 0.7422
## Age:PositionThird_Baseman 0.18968 0.79561 0.238 0.8116
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 16.73 on 985 degrees of freedom
## Multiple R-squared: 0.3945, Adjusted R-squared: 0.365
## F-statistic: 13.37 on 48 and 985 DF, p-value: < 2.2e-16
Here we can see that the overall \(R^2\) improved and some of the interactions are significant at the 0.1 level.
As we saw in Chapter 8, a decision tree is built from multiple if-else logical decisions and can classify observations. We can incorporate regression into decision trees so that a decision tree can make numerical predictions.
Numeric prediction trees are built in the same way as classification trees. Recall the discussion in Chapter 8 where the data are partitioned first via a divide-and-conquer strategy based on features. The homogeneity of the resulting classification trees is measured by various metrics, e.g., entropy. In regression-tree prediction, node homogeneity is measured by various statistics such as variance, standard deviation or absolute deviation from the mean.
A common splitting criterion for decision trees is the standard deviation reduction (SDR).
\[SDR=sd(T)-\sum_{i=1}^n \left | \frac{T_i}{T} \right | \times sd(T_i),\]
where \(sd(T)\) is the standard deviation of the values in the original (parent) node \(T\), the summation runs over the segments \(T_i\) produced by the split, \(\left | \frac{T_i}{T} \right |\) is the proportion of observations falling in the \(i^{th}\) segment relative to the total number of observations, and \(sd(T_i)\) is the standard deviation within the \(i^{th}\) segment.
An example for this would be: \[Original\, data:\{1, 2, 3, 3, 4, 5, 6, 6, 7, 8\}\] \[Split\, method\, 1:\{1, 2, 3|3, 4, 5, 6, 6, 7, 8\}\] \[Split\, method\, 2:\{1, 2, 3, 3, 4, 5|6, 6, 7, 8\}\] In split method 1, \(T_1=\{1, 2, 3\}\), \(T_2=\{3, 4, 5, 6, 6, 7, 8\}\). In split method 2, \(T_1=\{1, 2, 3, 3, 4, 5\}\), \(T_2=\{6, 6, 7, 8\}\).
ori <- c(1, 2, 3, 3, 4, 5, 6, 6, 7, 8)
at1 <- c(1, 2, 3)
at2 <- c(3, 4, 5, 6, 6, 7, 8)
bt1 <- c(1, 2, 3, 3, 4, 5)
bt2 <- c(6, 6, 7, 8)
sdr_a <- sd(ori) - (length(at1)/length(ori)*sd(at1) + length(at2)/length(ori)*sd(at2))
sdr_b <- sd(ori) - (length(bt1)/length(ori)*sd(bt1) + length(bt2)/length(ori)*sd(bt2))
sdr_a
## [1] 0.7702557
sdr_b
## [1] 1.041531
The function length() is used above to get the number of elements in a vector.
A larger SDR indicates a greater reduction in standard deviation after splitting. Here split method 2 yields the greater SDR, so the tree-splitting decision would prefer the second method, which is expected to produce more homogeneous subsets (children nodes) than method 1.
Now, the tree will be split further under bt1 and bt2 following the same rule (greater SDR wins). Assume we cannot split further, so bt1 and bt2 are terminal nodes. Observations classified into bt1 will be predicted with \(mean(bt1)=3\) and those classified into bt2 with \(mean(bt2)=6.75\).
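Here is a tiny sketch of the implied terminal-node predictions (assuming bt1 and bt2 from the chunk above are the terminal nodes); in practice, packages such as rpart automate this splitting and prediction process:
# regression-tree prediction: each terminal node predicts its mean outcome
mean(bt1)    # prediction for cases routed to the first node: 3
mean(bt2)    # prediction for cases routed to the second node: 6.75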
Bayesian Additive Regression Trees (BART) represent sum-of-regression-trees models that rely on boosting the constituent Bayesian-regularized trees.
The R packages BayesTree and BART provide computational implementations for fitting BART models to data. In a supervised setting where \(x\) and \(y\) represent the predictors and the outcome, the BART model is mathematically represented as: \[y=f(x)+\epsilon = \sum_{j=1}^m{f_j(x)} +\epsilon.\] More specifically, \[y_i=f(x_i)+\epsilon = \sum_{j=1}^m{f_j(x_i)} +\epsilon,\ \forall 1\leq i\leq n.\] The residuals are typically assumed to be white noise, \(\epsilon_i \sim N(0,\sigma^2), \ iid\). The function \(f\) represents the boosted ensemble of weaker regression trees, \(f_j=g( \cdot | T_j, M_j)\), where \(T_j\) and \(M_j\) represent the \(j^{th}\) tree and the set of values, \(\mu_{k,j}\), assigned to each terminal node \(k\) in \(T_j\), respectively.
The BART model may be estimated via Gibbs sampling, e.g., a Bayesian backfitting Markov chain Monte Carlo (MCMC) algorithm that iteratively samples \((T_j ,M_j)\) and \(\sigma\), conditional on all other variables, \((x,y)\), for each \(j\), until a convergence criterion is met. Given \(\sigma\), conditional sampling of \((T_j ,M_j)\) may be accomplished via the partial residual \[\epsilon_{j_o} = \sum_{i=1}^n{\left (y_i - \sum_{j\not= j_o}^m{g( x_i | T_j, M_j)} \right )}.\]
For prediction on new data, \(X\), data-driven priors on \(\sigma\) and the parameters defining \((T_j ,M_j)\) may be used to allow sampling a model from the posterior distribution: \[p\left (\sigma, \{(T_1 ,M_1), (T_2 ,M_2), ..., (T_m ,M_m)\} | X, Y \right )=\] \[= p(\sigma) \prod_{(tree)j=1}^m{\left ( \left ( \prod_{(node)k} {p(\mu_{k,j}|T_j)} \right )p(T_j) \right )} .\] In this posterior factorization, model regularization is achieved by four criteria:
Using the BART model to forecast a response corresponding to newly observed data \(x\) is achieved by using one individual prediction model (or by ensembling/averaging multiple prediction models) near algorithmic convergence.
The BART algorithm involves three steps:
This example illustrates a simple 1D BART simulation, based on \(h(x)=x^3\sin(x)\), where we can nicely plot the performance of the BART classifier.
# simulate training data
sig = 0.2      # sigma
func = function(x) {
  return (sin(x) * (x^(3)))
}
set.seed(1234)
n = 300
x = sort(2*runif(n)-1)         # define the input
y = func(x) + sig * rnorm(n)   # define the output

# xtest: values we want to estimate func(x) at; this is also our prior prediction for y.
xtest = seq(-pi, pi, by=0.2)
##plot simulated data
# plot(x, y, cex=1.0, col="gray")
# points(xtest, rep(0, length(xtest)), col="blue", pch=1, cex=1.5)
# legend("top", legend=c("Data", "(Conjugate) Flat Prior"),
# col=c("gray", "blue"), lwd=c(2,2), lty=c(3,3), bty="n", cex=0.9, seg.len=3)
plot_ly(x=x, y=y, type="scatter", mode="markers", name="data") %>%
add_trace(x=xtest, y=rep(0, length(xtest)), mode="markers", name="prior") %>%
layout(title='(Conjugate) Flat Prior',
xaxis = list(title="X", range = c(-1.3, 1.3)),
yaxis = list(title="Y", range = c(-0.6, 1.6)),
legend = list(orientation = 'h'))
# run the weighted BART (BART::wbart) on the simulated data
# install.packages("BART")
library(BART)
set.seed(1234) # set seed for reproducibility of MCMC
# nskip=Number of burn-in MCMC iterations; ndpost=number of posterior draws to return
model_bart <- wbart(x.train=as.data.frame(x), y.train=y, x.test=as.data.frame(xtest),
                    nskip=300, ndpost=1000, printevery=1000)
## *****Into main of wbart
## *****Data:
## data:n,p,np: 300, 1, 32
## y1,yn: 0.572918, 1.073126
## x1,x[n*p]: -0.993435, 0.997482
## xp1,xp[np*p]: -3.141593, 3.058407
## *****Number of Trees: 200
## *****Number of Cut Points: 100 ... 100
## *****burn and ndpost: 300, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,0.032889,3.000000,0.018662
## *****sigma: 0.309524
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,1,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1300)
## done 1000 (out of 1300)
## time: 3s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
# result is a list containing the BART run
# explore the BART model fit
names(model_bart)
## [1] "sigma" "yhat.train.mean" "yhat.train" "yhat.test.mean"
## [5] "yhat.test" "varcount" "varprob" "treedraws"
## [9] "proc.time" "mu" "varcount.mean" "varprob.mean"
## [13] "rm.const"
dim(model_bart$yhat.test)
## [1] 1000 32
# The (l,j) element of the matrix `yhat.test` represents the l^{th} draw of `func` evaluated at the j^{th} value of x.test
# A matrix with ndpost rows and nrow(x.train) columns. Each row corresponds to a draw f* from the posterior of f
# and each column corresponds to a row of x.train. The (i,j) value is f*(x) for the i^th kept draw of f
# and the j^th row of x.train
# # plot the data, the BART Fit, and the uncertainty
# plot(x, y, cex=1.2, cex.axis=0.8, cex.lab=0.8, mgp=c(1.3,.3,0), tcl=-0.2, col="gray")
# lines(xtest, func(xtest), col="blue", lty=1, lwd=2)
# lines(xtest, apply(model_bart$yhat.test, 2, mean), col="green", lwd=2, lty=2) # show the mean of f(x_j)
# quant_marg = apply(model_bart$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # plot the 2.5% and 97.5% quantiles
# lines(xtest, quant_marg[1,], col="red", lty=1, lwd=2)
# lines(xtest, quant_marg[2,], col="red", lty=1,lwd=2)
# legend("top", legend=c("Data", "True Signal","Posterior Mean","95% CI"),
# col=c("black", "blue","green","red"), lwd=c(2, 2,2,2), lty=c(3, 1,1,1), bty="n", cex=0.9, seg.len=3)
quant_marg = apply(model_bart$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # plot the 2.5% and 97.5% quantiles
plot_ly(x=x, y=y, type="scatter", mode="markers", name="Data") %>%
add_trace(x=xtest, y=func(xtest), mode="markers", name="True Signal") %>%
add_trace(x=xtest, y=apply(model_bart$yhat.test, 2, mean), mode="lines", name="Posterior Mean") %>%
add_trace(x=xtest, y=apply(model_bart$yhat.test, 2, quantile, probs=c(0.025, 0.975)),
mode="markers", name="95% CI") %>%
add_trace(x=xtest, y=quant_marg[1,], mode="lines", name="Lower Band") %>%
add_trace(x=xtest, y=quant_marg[2,], mode="lines", name="Upper Band") %>%
layout(title='BART Model (n=300)',
xaxis = list(title="X", range = c(-1.3, 1.3)),
yaxis = list(title="Y", range = c(-0.6, 1.6)),
legend = list(orientation = 'h'))
names(model_bart)
## [1] "sigma" "yhat.train.mean" "yhat.train" "yhat.test.mean"
## [5] "yhat.test" "varcount" "varprob" "treedraws"
## [9] "proc.time" "mu" "varcount.mean" "varprob.mean"
## [13] "rm.const"
dim(model_bart$yhat.train)
## [1] 1000 300
summary(model_bart$yhat.train.mean-apply(model_bart$yhat.train, 2, mean))
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -4.441e-16 -8.327e-17 6.939e-18 -8.130e-18 9.714e-17 5.551e-16
summary(model_bart$yhat.test.mean-apply(model_bart$yhat.test, 2, mean))
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -2.776e-16 1.535e-16 3.331e-16 2.661e-16 4.441e-16 4.441e-16
# yhat.train(test).mean: Average the draws to get the estimate of the posterior mean of func(x)
If we increase the sample size, n, the computational complexity increases and the BART model bounds should get tighter.
n = 3000     # 300 --> 3,000
set.seed(1234)
x = sort(2*runif(n)-1)
y = func(x) + sig*rnorm(n)

model_bart_2 <- wbart(x.train=as.data.frame(x), y.train=y, x.test=as.data.frame(xtest),
                      nskip=300, ndpost=1000, printevery=1000)
## *****Into main of wbart
## *****Data:
## data:n,p,np: 3000, 1, 32
## y1,yn: 1.021931, 0.626872
## x1,x[n*p]: -0.999316, 0.998606
## xp1,xp[np*p]: -3.141593, 3.058407
## *****Number of Trees: 200
## *****Number of Cut Points: 100 ... 100
## *****burn and ndpost: 300, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,0.035536,3.000000,0.017828
## *****sigma: 0.302529
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,1,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1300)
## done 1000 (out of 1300)
## time: 15s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
# plot(x, y, cex=1.2, cex.axis=0.8, cex.lab=0.8, mgp=c(1.3,.3,0), tcl=-0.2, col="gray")
# lines(xtest, func(xtest), col="blue", lty=1, lwd=2)
# lines(xtest, apply(model_bart_2$yhat.test, 2, mean), col="green", lwd=2, lty=2) # show the mean of f(x_j)
# quant_marg = apply(model_bart_2$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # plot the 2.5% and 97.5% CI
# lines(xtest, quant_marg[1,], col="red", lty=1, lwd=2)
# lines(xtest, quant_marg[2,], col="red", lty=1,lwd=2)
# legend("top",legend=c("Data (n=3,000)", "True Signal", "Posterior Mean","95% CI"),
# col=c("gray", "blue","green","red"), lwd=c(2, 2,2,2), lty=c(3, 1,1,1), bty="n", cex=0.9, seg.len=3)
quant_marg2 = apply(model_bart_2$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # the 2.5% and 97.5% CI

plot_ly(x=x, y=y, type="scatter", mode="markers", name="Data") %>%
  add_trace(x=xtest, y=func(xtest), mode="markers", name="True Signal") %>%
  add_trace(x=xtest, y=apply(model_bart_2$yhat.test, 2, mean), mode="lines", name="Posterior Mean") %>%
  add_trace(x=xtest, y=quant_marg2[1,], mode="lines", name="Lower Band (95% CI)") %>%
  add_trace(x=xtest, y=quant_marg2[2,], mode="lines", name="Upper Band (95% CI)") %>%
  layout(title='BART Model (n=3,000)',
         xaxis = list(title="X", range = c(-1.3, 1.3)),
         yaxis = list(title="Y", range = c(-0.6, 1.6)),
         legend = list(orientation = 'h'))
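Since the band endpoints for both fits are stored in quant_marg (n=300) and quant_marg2 (n=3,000), we can quantify the tightening directly. The code below is a minimal sketch comparing the average 95% band widths over the test grid.
# average width of the 95% posterior bands over the test grid
mean(quant_marg[2, ] - quant_marg[1, ])    # n = 300
mean(quant_marg2[2, ] - quant_marg2[1, ])  # n = 3,000 (expected to be narrower)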
In this second example, we will simulate \(n=5,000\) cases and \(p=20\) features.
# simulate data
set.seed(1234)
n = 5000; p = 20
beta = 3*(1:p)/p
sig = 1.0
X = matrix(rnorm(n*p), ncol=p)                      # design matrix
y = 10 + X %*% matrix(beta, ncol=1) + sig*rnorm(n)  # outcome
y = as.double(y)

np = 100000
Xp = matrix(rnorm(np*p), ncol=p)

set.seed(1234)
t1 <- system.time(model_bart_MD <- wbart(x.train=as.data.frame(X), y.train=y, x.test=as.data.frame(Xp),
                  nkeeptrain=200, nkeeptest=100, nkeeptestmean=500, nkeeptreedraws=100, printevery=1000))
## *****Into main of wbart
## *****Data:
## data:n,p,np: 5000, 20, 100000
## y1,yn: -6.790191, 5.425931
## x1,x[n*p]: -1.207066, -3.452091
## xp1,xp[np*p]: -0.396587, -0.012973
## *****Number of Trees: 200
## *****Number of Cut Points: 100 ... 100
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,1.049775,3.000000,0.189323
## *****sigma: 0.985864
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,20,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 200,100,500,100
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 5,10,2,10
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 163s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 200,100,500,100
dim(model_bart_MD$yhat.train)
## [1] 200 5000
dim(model_bart_MD$yhat.test)
## [1] 100 100000
names(model_bart_MD$treedraws)
## [1] "cutpoints" "trees"
# str(model_bart_MD$treedraws$trees)
# The trees are stored in a long character string and there are 100 draws each consisting of 200 trees.
# To predict using the Multi-Dimensional BART model (MD)
t2 <- system.time({pred_model_bart_MD2 <- predict(model_bart_MD, as.data.frame(Xp), mc.cores=6)})
## *****In main of C++ for bart prediction
## tc (threadcount): 6
## number of bart draws: 100
## number of trees in bart sum: 200
## number of x columns: 20
## from x,np,p: 20, 100000
## ***using serial code
dim(pred_model_bart_MD2)
## [1] 100 100000
t1
## user system elapsed
## 163.16 0.06 163.61
t2
## user system elapsed
## 44.20 0.08 44.35
# pred_model_bart_MD2 has row dimension equal to the number of kept tree draws (100) and
# column dimension equal to the number of rows in Xp (100,000).
# Compare the test-set BART predictions from the original fit vs. those recomputed
# from the 100 kept tree draws (the results are very similar)
# plot(model_bart_MD$yhat.test.mean, apply(pred_model_bart_MD2, 2, mean),
# xlab="BART Prediction using 1,000 Trees", ylab="BART Prediction using 100 Kept Trees")
# abline(0,1, col="red", lwd=2)
plot_ly() %>%
add_trace(x = c(-30,50), y = c(-30,50), type="scatter", mode="lines",
line = list(width = 4), name="Consistent BART Prediction (1,000 vs. 100 Trees)") %>%
add_markers(x=model_bart_MD$yhat.test.mean, y=apply(pred_model_bart_MD2, 2, mean),
name="BART Prediction Mean Estimates", type="scatter", mode="markers") %>%
layout(title='Scatter of BART Predictions (1,000 vs. 100 Trees)',
xaxis = list(title="BART Prediction (1,000 Trees)"),
yaxis = list(title="BART Prediction (100 Trees)"),
legend = list(orientation = 'h'))
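To confirm numerically that the posterior means stored during the fit agree with those recomputed via predict(), a quick check (a sketch using the objects above; means_fit and means_pred are illustrative local names) is:
means_fit  <- model_bart_MD$yhat.test.mean            # test-set means saved during the fit
means_pred <- apply(pred_model_bart_MD2, 2, mean)     # means from the 100 kept tree draws
cor(means_fit, means_pred)                            # should be very close to 1
max(abs(means_fit - means_pred))                      # maximum absolute discrepancy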
# Compare BART Prediction to a linear fit
lm_func = lm(y ~ ., data.frame(X, y))
pred_lm = predict(lm_func, data.frame(Xp))
# plot(pred_lm, model_bart_MD$yhat.test.mean, xlab="Linear Model Predictions", ylab="BART Predictions",
# cex=0.5, cex.axis=1.0, cex.lab=0.8, mgp=c(1.3,.3,0), tcl=-0.2)
# abline(0,1, col="red", lwd=2)
plot_ly() %>%
add_markers(x=pred_lm, y=model_bart_MD$yhat.test.mean,
name="Consistent LM/BART Prediction", type="scatter", mode="markers") %>%
add_trace(x = c(-30,50), y = c(-30,50), type="scatter", mode="lines",
line = list(width = 4), name="LM vs. BART Prediction") %>%
layout(title='Scatter of Linear Model vs. BART Predictions',
xaxis = list(title="Linear Model Prediction"),
yaxis = list(title="BART Prediction"),
legend = list(orientation = 'h'))
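Because the test data were simulated, the noiseless signal at the new inputs is known, \(E(y|X_p)=10+X_p\beta\), so we can also gauge (as a rough sketch; truth_Xp is an illustrative local name) how closely each method recovers it.
truth_Xp <- as.double(10 + Xp %*% matrix(beta, ncol=1))    # true mean response at the test inputs
sqrt(mean((pred_lm - truth_Xp)^2))                         # RMSE of the linear model predictions
sqrt(mean((model_bart_MD$yhat.test.mean - truth_Xp)^2))    # RMSE of the BART posterior-mean predictions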
Let’s use BART to model the heart attack dataset (CaseStudy12_AdultsHeartAttack_Data.csv). The data include 150 observations and 8 features, including the hospital charges (CHARGES), which will serve as the response variable.
heart_attack <- read.csv("https://umich.instructure.com/files/1644953/download?download_frd=1",
                         stringsAsFactors = F)
str(heart_attack)
## 'data.frame': 150 obs. of 8 variables:
## $ Patient : int 1 2 3 4 5 6 7 8 9 10 ...
## $ DIAGNOSIS: int 41041 41041 41091 41081 41091 41091 41091 41091 41041 41041 ...
## $ SEX : chr "F" "F" "F" "F" ...
## $ DRG : int 122 122 122 122 122 121 121 121 121 123 ...
## $ DIED : int 0 0 0 0 0 0 0 0 0 1 ...
## $ CHARGES : chr "4752" "3941" "3657" "1481" ...
## $ LOS : int 10 6 5 2 1 9 15 15 2 1 ...
## $ AGE : int 79 34 76 80 55 84 84 70 76 65 ...
# convert CHARGES (the dependent variable) to numeric form;
# NAs are created, so retain only the complete cases
heart_attack$CHARGES <- as.numeric(heart_attack$CHARGES)
heart_attack <- heart_attack[complete.cases(heart_attack), ]
heart_attack$gender <- ifelse(heart_attack$SEX=="F", 1, 0)
heart_attack <- heart_attack[, -c(1,2,3)]   # drop Patient, DIAGNOSIS, SEX
dim(heart_attack); colnames(heart_attack)
## [1] 148 6
## [1] "DRG" "DIED" "CHARGES" "LOS" "AGE" "gender"
x.train <- as.matrix(heart_attack[ , -3])  # features, excluding CHARGES (the outcome)
y.train = heart_attack$CHARGES             # outcome for modeling (BART, lm, lasso, etc.)
# Data should be standardized for all model-based predictions (e.g., lm, lasso/glmnet), but
# this is not critical for BART
# We'll just do some random train/test splits and report the out of sample performance of BART and lasso
RMSE <- function(y, yhat) {
  return(sqrt(mean((y-yhat)^2)))
}

nd <- 10                 # number of train/test splits (a la CV)
n <- length(y.train)
ntrain <- floor(0.8*n)   # 80:20 train:test split each time

RMSE_BART <- rep(0, nd)  # initialize the BART and LASSO RMSE vectors
RMSE_LASSO <- rep(0, nd)

pred_BART <- matrix(0.0, n-ntrain, nd)   # initialize the BART and LASSO out-of-sample predictions
pred_LASSO <- matrix(0.0, n-ntrain, nd)
In Chapter 17, we will learn more about LASSO regularized linear modeling. Now, let’s use the glmnet::glmnet()
method to fit a LASSO model and compare it to BART using the Heart Attack hospitalization case-study.
library(glmnet)

for(i in 1:nd) {
  set.seed(1234*i)

  # train/test split index
  train_ind <- sample(1:n, ntrain)

  # Outcome (CHARGES)
  yTrain <- y.train[train_ind]; yTest <- y.train[-train_ind]

  # Features for BART
  xBTrain <- x.train[train_ind, ]; xBTest <- x.train[-train_ind, ]

  # Features for LASSO (scaled)
  xLTrain <- apply(x.train[train_ind, ], 2, scale)
  xLTest <- apply(x.train[-train_ind, ], 2, scale)

  # BART: mc.wbart is the parallel version of wbart (same arguments)
  # model_BART <- mc.wbart(xBTrain, yTrain, xBTest, mc.cores=6, keeptrainfits=FALSE)
  model_BART <- wbart(xBTrain, yTrain, xBTest, printevery=1000)

  # LASSO
  cv_LASSO <- cv.glmnet(xLTrain, yTrain, family="gaussian", standardize=TRUE)
  best_lambda <- cv_LASSO$lambda.min
  model_LASSO <- glmnet(xLTrain, yTrain, family="gaussian", lambda=c(best_lambda), standardize=TRUE)

  # get predictions on the testing data
  pred_BART1 <- model_BART$yhat.test.mean
  pred_LASSO1 <- predict(model_LASSO, xLTest, s=best_lambda, type="response")[, 1]

  # store the results
  RMSE_BART[i] <- RMSE(yTest, pred_BART1); pred_BART[, i] <- pred_BART1
  RMSE_LASSO[i] <- RMSE(yTest, pred_LASSO1); pred_LASSO[, i] <- pred_LASSO1
}
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 2105.135593, -2641.864407
## x1,x[n*p]: 121.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2228477.133405
## *****sigma: 3382.354604
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: -3870.449153, 5404.550847
## x1,x[n*p]: 122.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2261599.844262
## *****sigma: 3407.398506
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 2659.855932, -1749.144068
## x1,x[n*p]: 121.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2182151.820220
## *****sigma: 3347.013986
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 714.000000, -4103.000000
## x1,x[n*p]: 121.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2310883.132530
## *****sigma: 3444.324312
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 3229.305085, 1884.305085
## x1,x[n*p]: 122.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2169347.152131
## *****sigma: 3337.179552
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 2115.474576, 1998.474576
## x1,x[n*p]: 121.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2209447.645085
## *****sigma: 3367.882283
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: -1255.525424, -3624.525424
## x1,x[n*p]: 123.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2177441.668674
## *****sigma: 3343.399788
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 3795.457627, -388.542373
## x1,x[n*p]: 121.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,288.534922,3.000000,1972180.926129
## *****sigma: 3181.913893
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 3s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: -3943.779661, -524.779661
## x1,x[n*p]: 123.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2184631.918632
## *****sigma: 3348.915451
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 5245.610169, 4149.610169
## x1,x[n*p]: 121.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,289.984491,3.000000,2281761.566623
## *****sigma: 3422.552953
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
Next, compare the BART and LASSO out-of-sample RMSE values and predictions.
# compare the out of sample RMSE measures
# qqplot(RMSE_BART, RMSE_LASSO)
# abline(0, 1, col="red", lwd=2)
plot_ly() %>%
  add_markers(x=RMSE_BART, y=RMSE_LASSO,
              name="Paired RMSE (per split)", type="scatter", mode="markers") %>%
  add_trace(x = c(2800,4000), y = c(2800,4000), type="scatter", mode="lines",
            line = list(width = 4), name="Identity (equal RMSE)") %>%
  layout(title='Scatter of LASSO vs. BART Out-of-Sample RMSE',
         xaxis = list(title="RMSE (BART)"),
         yaxis = list(title="RMSE (LASSO)"),
         legend = list(orientation = 'h'))
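Beyond the scatterplot, we can summarize the two RMSE distributions and count how often BART outperforms LASSO across the ten splits (a minimal sketch using the vectors computed above).
summary(RMSE_BART)
summary(RMSE_LASSO)
mean(RMSE_BART < RMSE_LASSO)   # proportion of splits where BART has lower out-of-sample RMSE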
# Next compare the out-of-sample predictions
model_lm <- lm(B ~ L, data.frame(B=as.double(pred_BART), L=as.double(pred_LASSO)))
x1 <- c(2800,9000)
y1 <- 2373.0734628 + 0.5745497*x1   # regression line endpoints (intercept + slope * x1)
# plot(as.double(pred_BART), as.double(pred_LASSO),xlab="BART Predictions",ylab="LASSO Predictions", col="gray")
# abline(0, 1, col="red", lwd=2)
# abline(model_lm$coef, col="blue", lty=2, lwd=3)
# legend("topleft",legend=c("Scatterplot Predictions (BART vs. LASSO)", "Ideal Agreement", "LM (BART~LASSO)"),
# col=c("gray", "red","blue"), lwd=c(2,2,2), lty=c(3,1,1), bty="n", cex=0.9, seg.len=3)
plot_ly() %>%
  add_markers(x=as.double(pred_BART), y=as.double(pred_LASSO),
              name="BART vs. LASSO Predictions", type="scatter", mode="markers") %>%
add_trace(x = c(2800,9000), y = c(2800,9000), type="scatter", mode="lines",
line = list(width = 4), name="Ideal Agreement") %>%
add_trace(x = x1, y = y1, type="scatter", mode="lines",
line = list(width = 4), name="LM (BART ~ LASSO)") %>%
layout(title='Scatterplot Predictions (BART vs. LASSO)',
xaxis = list(title="BART Predictions"),
yaxis = list(title="LASSO Predictions"),
legend = list(orientation = 'h'))
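We can also report the fitted coefficients behind the regression line above and the overall agreement between the pooled BART and LASSO out-of-sample predictions (a brief sketch).
coef(model_lm)                                       # intercept and slope of the BART ~ LASSO fit
cor(as.double(pred_BART), as.double(pred_LASSO))     # correlation between the pooled predictions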
If the default prior estimate of the error variance (\(\sigma^2\)), sigest, which is based on the standard conditionally conjugate inverted chi-squared prior, yields reasonable results, we can try longer BART runs (ndpost=5000). Note that the returned draws of \(\hat{\sigma}\) (y-axis) remain stably distributed as the number of posterior draws increases (x-axis).
model_BART_long <- wbart(x.train, y.train, nskip=1000, ndpost=5000, printevery=5000)
## *****Into main of wbart
## *****Data:
## data:n,p,np: 148, 5, 0
## y1,yn: -712.837838, 2893.162162
## x1,x[n*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 1000, 5000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2178098.906176
## *****sigma: 3343.904335
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 5000,5000,5000,5000
## *****printevery: 5000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
##
## MCMC
## done 0 (out of 6000)
## done 5000 (out of 6000)
## time: 11s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 5000,0,0,5000
# plot(model_BART_long$sigma, xlab="Number of Posterior Draws Returned")
plot_ly() %>%
  add_markers(x=c(1:length(model_BART_long$sigma)), y=model_BART_long$sigma,
              name="Sigma Draws", type="scatter", mode="markers") %>%
  layout(title='Scatterplot of BART Sigma (post burn-in draws of sigma)',
         xaxis = list(title="Number of Posterior Draws Returned"),
         yaxis = list(title="model_BART_long$sigma"),
         legend = list(orientation = 'h'))
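As a rough stability check (a sketch using the returned sigma draws; sig_draws and half are illustrative local names), we can compare summaries of the first and second halves of the chain; similar values suggest the draws have settled around a stable level.
sig_draws <- model_BART_long$sigma
half <- floor(length(sig_draws)/2)
summary(sig_draws[1:half])                       # first half of the chain
summary(sig_draws[(half+1):length(sig_draws)])   # second half of the chain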
# plot(model_BART_long$yhat.train.mean, y.train, xlab="BART Predicted Charges", ylab="Observed Charges",
# main=sprintf("Correlation (Observed,Predicted)=%f",
# round(cor(model_BART_long$yhat.train.mean, y.train), digits=2)))
# abline(0, 1, col="red", lty=2)
# legend("topleft",legend=c("BART Predictions", "LM (BART~LASSO)"),
# col=c("gray", "red","blue"), lwd=c(2,2,2), lty=c(3,1,1), bty="n", cex=0.9, seg.len=3)
plot_ly() %>%
  add_markers(x=model_BART_long$yhat.train.mean, y=y.train,
              name="Observed vs. BART-Predicted", type="scatter", mode="markers") %>%
  add_trace(x = c(2800,9000), y = c(2800,9000), type="scatter", mode="lines",
            line = list(width = 4), name="Ideal Agreement") %>%
layout(title=sprintf("Observed vs. BART-Predicted Hospital Charges: Cor(BART,Observed)=%f",
round(cor(model_BART_long$yhat.train.mean, y.train), digits=2)),
xaxis = list(title="BART Predictions"),
yaxis = list(title="Observed Values"),
legend = list(orientation = 'h'))
ind <- order(model_BART_long$yhat.train.mean)

# boxplot(model_BART_long$yhat.train[ , ind], ylim=range(y.train), xlab="case",
#         ylab="BART Hospitalization Charge Prediction Range")

caseIDs <- paste0("Case", rownames(heart_attack))
rowIDs <- paste0("", c(1:dim(model_BART_long$yhat.train)[1]))
colnames(model_BART_long$yhat.train) <- caseIDs
rownames(model_BART_long$yhat.train) <- rowIDs

df1 <- as.data.frame(model_BART_long$yhat.train[ , ind])
df2_wide <- as.data.frame(cbind(index=c(1:dim(model_BART_long$yhat.train)[1]), df1))
# colnames(df2_wide); dim(df2_wide)

df_long <- tidyr::gather(df2_wide, case, measurement, Case138:Case8)
# str(df_long)
# 'data.frame': 74000 obs. of 3 variables:
#  $ index      : int 1 2 3 4 5 6 7 8 9 10 ...
#  $ case       : chr "Case138" "Case138" "Case138" "Case138" ...
#  $ measurement: num 5013 3958 4604 2602 2987 ...

actualCharges <- as.data.frame(cbind(cases=caseIDs, value=y.train))
plot_ly() %>%
add_trace(data=df_long, y = ~measurement, color = ~case, type = "box") %>%
add_trace(x=~actualCharges$cases, y=~actualCharges$value, type="scatter", mode="markers",
name="Observed Charge", marker=list(size=20, color='green', line=list(color='yellow', width=2))) %>%
layout(title="Box-and-whisker Plots across all 148 Cases (Highlighted True Charges)",
xaxis = list(title="Cases"),
yaxis = list(title="BART Hospitalization Charge Prediction Range"),
showlegend=F)
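Finally, as a rough calibration sketch, we can check how often each observed charge falls inside the central 95% posterior interval of its fitted values. Note that these intervals describe the posterior of f(x), not full predictive intervals that also include the noise term, so their empirical coverage of the raw observations need not reach 95%. The names post_int and covered below are illustrative.
post_int <- apply(model_BART_long$yhat.train, 2, quantile, probs=c(0.025, 0.975))
covered <- (y.train >= post_int[1, ]) & (y.train <= post_int[2, ])
mean(covered)    # empirical coverage of the observed charges across the 148 cases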