
1 Motivation

In the previous Chapter 6, Chapter 7, and Chapter 8, we covered some classification methods that use mathematical formalism to address everyday-life prediction problems. In this chapter, we will focus on specific model-based statistical methods providing forecasting and classification functionality. Specifically, we will (1) demonstrate the predictive power of multiple linear regression, (2) show the foundation of regression trees and model trees, and (3) examine two complementary case-studies (Baseball Players and Heart Attack).

It may be helpful to first review Chapter 4 (Linear Algebra/Matrix Manipulations) and Chapter 6 (Introduction to Machine Learning).

2 Understanding Regression

Regression represents a model of a relationship between a dependent variable (the value to be predicted) and a group of independent variables (predictors or features), see Chapter 6. We assume that the relationship between the outcome (dependent) variable and the independent variables is linear.

2.1 Simple linear regression

First review the material in Chapter 4: Linear Algebra & Matrix Computing.

The simplest case of regression is simple linear regression, which involves a single predictor: \[y=a+bx.\]

This formula should be familiar, as we showed examples in previous chapters. In this slope-intercept formula, a is the model intercept and b is the model slope. Thus, simple linear regression may be expressed as a bivariate equation. If we know a and b, for any given x we can estimate, or predict, y via the regression formula. If we plot y against x in a 2D coordinate system, where the two variables are exactly linearly related, the result will be a straight line.

However, this is the ideal case. Bivariate scatterplots using real world data may show patterns that are not necessarily precisely linear, see Chapter 2. Let’s look at a bivariate scatterplot and try to fit a simple linear regression line using two variables, e.g., hospital charges or CHARGES as a dependent variable, and length of stay in the hospital or LOS as an independent predictor. The data is available in the DSPA Data folder as CaseStudy12_AdultsHeartAttack_Data. We can remove the pair of observations with missing values using the command heart_attack<-heart_attack[complete.cases(heart_attack), ].

library(plotly)
heart_attack<-read.csv("https://umich.instructure.com/files/1644953/download?download_frd=1", stringsAsFactors = F)
heart_attack$CHARGES<-as.numeric(heart_attack$CHARGES)
heart_attack<-heart_attack[complete.cases(heart_attack), ]

fit1<-lm(CHARGES ~ LOS, data=heart_attack)
# par(cex=.8)
# plot(heart_attack$LOS, heart_attack$CHARGES, xlab="LOS", ylab = "CHARGES")
# abline(fit1, lwd=2, col="red")

plot_ly(heart_attack, x = ~LOS, y = ~CHARGES, type = 'scatter', mode = "markers", name="Data") %>%
    add_trace(x=~mean(LOS), y=~mean(CHARGES), type="scatter", mode="markers",
            name="(mean(LOS), mean(Charge))", marker=list(size=20, color='blue', line=list(color='yellow', width=2))) %>%
    add_lines(x = ~LOS, y = fit1$fitted.values, mode = "lines", name="Linear Model") %>%
    layout(title=paste0("lm(CHARGES ~ LOS), Cor(LOS,CHARGES) = ", 
                        round(cor(heart_attack$LOS, heart_attack$CHARGES),3)))

As expected, longer hospital stays tend to be associated with higher medical costs, or hospital charges. The scatterplot shows a dot for each pair of observed measurements (\(x=LOS\) and \(y=CHARGES\)), and an increasing linear trend.

The estimated expression for this regression line is: \[\hat{y}=4582.70+212.29\times x\] or equivalently \[CHARGES=4582.70+212.29\times LOS\] Once the linear model is fit, i.e., its coefficients are estimated, we can make predictions using this explicit regression model. Assume we have a patient that spent 10 days in hospital, then we have LOS=10. The predicted charge is \(\$ 4582.70 + \$ 212.29 \times 10= \$ 6705.6\). Plugging a value of x into the regression equation gives us an estimated value of the outcome y. This chapter of the Probability and statistics EBook provides an introduction to linear modeling.
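
The prediction step described above can be reproduced in a few lines. This is a minimal sketch that hard-codes the two rounded coefficients reported in the text; the helper name predict_charges() is hypothetical and introduced only for illustration.

```r
# Rounded coefficients of the fitted line reported in the text (CHARGES ~ LOS)
a <- 4582.70   # intercept
b <- 212.29    # slope: estimated increase in charges per extra hospital day

# Hypothetical helper: predicted charges for a given length of stay (in days)
predict_charges <- function(los) a + b * los

predict_charges(10)            # 6705.6
predict_charges(c(1, 5, 10))   # vectorized over several stays
```

When the fitted object is available, the call predict(fit1, newdata = data.frame(LOS = 10)) should give a very similar value, computed from the unrounded coefficients.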

2.2 Ordinary least squares estimation

How did we get the estimated expression? The most common estimation method in statistics is ordinary least squares (OLS). OLS estimators are obtained by minimizing the sum of the squared errors, that is, the sum of the squared vertical distances from each dot on the scatter plot to the regression line.

OLS minimizes the following expression: \[\sum_{i=1}^{n}(y_i-\hat{y}_i)^2=\sum_{i=1}^{n}\left (\underbrace{y_i}_{\text{observed outcome}}-\underbrace{(a+b\times x_i)}_{\text{predicted outcome}}\right )^2=\sum_{i=1}^{n}\underbrace{e_i^2}_{\text{squared residual}}.\] Setting the partial derivatives of this sum with respect to a and b to zero shows that the value of b minimizing the squared error is: \[b=\frac{\sum(x_i-\bar x)(y_i-\bar y)}{\sum(x_i-\bar x)^2}.\] Then, the corresponding constant term (y-intercept) a is: \[a=\bar y-b\bar x.\]
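
As a numerical sanity check of the closed-form expressions above, we can minimize the sum of squared errors with a generic optimizer and compare the two answers. This is only a sketch on simulated data (the data, the seed, and the helper name sse() are assumptions made for illustration, not part of the heart attack case-study).

```r
set.seed(1)                       # simulated data, for illustration only
x <- runif(50, 1, 20)
y <- 3 + 2 * x + rnorm(50, sd = 4)

# Sum of squared errors as a function of the parameter vector p = c(a, b)
sse <- function(p) sum((y - (p[1] + p[2] * x))^2)

# Closed-form OLS solution from the formulas in the text
b_hat <- sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))^2)
a_hat <- mean(y) - b_hat * mean(x)

# Generic numerical minimizer started from an arbitrary point
num <- optim(c(0, 0), sse, method = "BFGS")

rbind(closed_form = c(a = a_hat, b = b_hat), numerical = num$par)
```

The two rows should agree up to the optimizer's tolerance, which illustrates that the formulas for a and b are the minimizers of the squared-error criterion.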

These expressions will become apparent if you review the material in Chapter 2. Recall that the variance is obtained by averaging the squared deviations (\(var(x)=\frac{1}{n}\sum^{n}_{i=1} (x_i-\mu)^2\)). When we use \(\bar{x}\) to estimate the mean of \(x\), we have the following formula for the sample variance: \(var(x)=\frac{1}{n-1}\sum^{n}_{i=1} (x_i-\bar{x})^2\). Note that this is \(\frac{1}{n-1}\) times the denominator of b. Similar to the variance, the covariance of x and y measures the average product of the deviations of x and the deviations of y: \[Cov(x, y)=\frac{1}{n}\sum^{n}_{i=1} (x_i-\mu_x)(y_i-\mu_y).\] If we utilize the sample averages (\(\bar{x}\), \(\bar{y}\)) as estimates of the corresponding population means, we have:
\[Cov(x, y)=\frac{1}{n-1}\sum^{n}_{i=1} (x_i-\bar{x})(y_i-\bar{y}).\]

This is \(\frac{1}{n-1}\) times the numerator of b. Thus, combining the above 2 expressions, we get an estimate of the slope coefficient (effect-size of LOS on Charge) expressed as: \[b=\frac{Cov(x, y)}{var(x)}.\] Let’s use the heart attack data to demonstrate these calculations.

b<-cov(heart_attack$LOS, heart_attack$CHARGES)/var(heart_attack$LOS)
b
## [1] 212.2869
a<-mean(heart_attack$CHARGES)-b*mean(heart_attack$LOS)
a
## [1] 4582.7
# compare to the lm() estimate:
fit1$coefficients[1]
## (Intercept) 
##      4582.7
# we can do the same for the slope parameter (b == fit1$coefficients[2])

We can see that this is exactly the same as the previously computed estimate of the constant intercept term using lm().

2.3 Model Assumptions

Regression modeling has five key assumptions:

2.4 Correlations

Note: The SOCR Interactive Scatterplot Game (requires Java enabled browser) provides a dynamic interface demonstrating linear models, trends, correlations, slopes and residuals.

Based on covariance, we can calculate the correlation, which indicates how closely the relationship between two variables follows a straight line. \[\rho_{x, y}=Corr(x, y)=\frac{Cov(x, y)}{\sigma_x\sigma_y}=\frac{Cov(x, y)}{\sqrt{Var(x)Var(y)}}.\] In R, correlation is given by cor(), while the standard deviation (the square root of the variance) is given by sd().

r<-cov(heart_attack$LOS, heart_attack$CHARGES)/(sd(heart_attack$LOS)*sd(heart_attack$CHARGES))
r
## [1] 0.2449743
cor(heart_attack$LOS, heart_attack$CHARGES)
## [1] 0.2449743

The two computations agree. This correlation is positive and relatively small, so we can say there is a weak positive linear association between these two variables. A negative value would indicate a negative linear association. As a rule of thumb, we have a weak association when \(0.1 \leq |Cor| < 0.3\), a moderate association for \(0.3 \leq |Cor| < 0.5\), and a strong association for \(0.5 \leq |Cor| \leq 1.0\). If \(|Cor|\) is below \(0.1\), then it suggests little to no linear relation between the variables.
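
The rule-of-thumb cut-offs above can be wrapped in a small helper. This is a sketch only: the function name cor_strength() and the label "negligible" are hypothetical, and the cut-offs are applied to the absolute value of the correlation so that negative associations are graded by their magnitude.

```r
# Hypothetical helper: label the strength and direction of a linear
# association using the rule-of-thumb cut-offs quoted in the text
cor_strength <- function(r) {
  stopifnot(is.numeric(r), all(abs(r) <= 1))
  strength <- cut(abs(r), breaks = c(0, 0.1, 0.3, 0.5, 1),
                  labels = c("negligible", "weak", "moderate", "strong"),
                  right = FALSE, include.lowest = TRUE)  # [0,.1) ... [.5,1]
  direction <- ifelse(r > 0, "positive", ifelse(r < 0, "negative", "none"))
  data.frame(r = r, strength = as.character(strength),
             direction = direction, stringsAsFactors = FALSE)
}

# Correlations reported in this chapter: LOS-CHARGES, Height-Age, Weight-Height
cor_strength(c(0.2449743, -0.07367013, 0.53031802))
```

Under these cut-offs, the LOS and CHARGES correlation is graded as weak and positive, in line with the interpretation given above.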

2.5 Multiple Linear Regression

In practice, we usually encounter situations with multiple predictors and one dependent variable, which may follow a multiple linear model. That is: \[y=\alpha+\beta_1x_1+\beta_2x_2+...+\beta_kx_k+\epsilon,\] or equivalently \[y=\beta_0+\beta_1x_1+\beta_2x_2+ ... +\beta_kx_k+\epsilon .\] We usually use the second notation in statistics. This equation shows the linear relationship between k predictors and a dependent variable. In total we have k+1 coefficients to estimate.

The matrix notation for the above equation is: \[Y=X\beta+\epsilon,\] where \[Y=\left(\begin{array}{c} y_1 \\ y_2\\ ...\\ y_n \end{array}\right)\]

\[X=\left(\begin{array}{ccccc} 1 & x_{11}&x_{21}&...&x_{k1} \\ 1 & x_{12}&x_{22}&...&x_{k2} \\ .&.&.&.&.\\ 1 & x_{1n}&x_{2n}&...&x_{kn} \end{array}\right) \] \[\beta=\left(\begin{array}{c} \beta_0 \\ \beta_1\\ ...\\ \beta_k \end{array}\right)\]

and

\[\epsilon=\left(\begin{array}{c} \epsilon_1 \\ \epsilon_2\\ ...\\ \epsilon_n \end{array}\right)\] is the error term.

Similar to simple linear regression, our goal is to minimize the sum of squared errors. Solving the matrix equation for \(\beta\), we get the OLS solution for the parameter vector: \[\hat{\beta}=(X^TX)^{-1}X^TY .\] The solution is presented in a matrix form, where \(X^T\) is the transpose of the design matrix \(X\) and \((X^TX)^{-1}\) is the inverse of the square matrix \(X^TX\) (the design matrix \(X\) itself is generally not square and has no inverse).

Let’s make a function of our own using this matrix formula.

reg<-function(y, x){
  x<-as.matrix(x)
  x<-cbind(Intercept=1, x)
  solve(t(x)%*%x)%*%t(x)%*%y
}

We saw earlier that a clever use of matrix multiplication (%*%) and solve() can help with the explicit OLS solution.
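
The explicit inverse in reg() mirrors the textbook formula, but the same estimate can be obtained without forming \((X^TX)^{-1}\), which tends to be numerically safer. The sketch below compares three routes on simulated data; the data, the seed, and the variable names are assumptions made for illustration.

```r
set.seed(2)                                  # simulated data, for illustration
X <- cbind(Intercept = 1, x1 = rnorm(100), x2 = rnorm(100))
y <- drop(X %*% c(1, 2, -3)) + rnorm(100)

# (1) Textbook formula: explicit inverse of X'X
beta_inv <- solve(t(X) %*% X) %*% t(X) %*% y

# (2) Solve the normal equations X'X beta = X'y without forming the inverse
beta_ne <- solve(crossprod(X), crossprod(X, y))

# (3) QR-based least squares; lm() also relies on a QR decomposition
beta_qr <- qr.solve(X, y)

cbind(inverse = drop(beta_inv), normal_eq = drop(beta_ne), qr = beta_qr)
```

For well-conditioned problems such as this one the three columns agree to many decimal places; the differences only matter when the predictors are nearly collinear.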

Next, we will apply our function reg() to the heart attack data. To begin with, let’s check if the simple linear regression (lm()) output coincides with the reg() result.

reg(y=heart_attack$CHARGES, x=heart_attack$LOS)
##                [,1]
## Intercept 4582.6997
##            212.2869
fit1
## 
## Call:
## lm(formula = CHARGES ~ LOS, data = heart_attack)
## 
## Coefficients:
## (Intercept)          LOS  
##      4582.7        212.3

The results agree and we can now include additional variables as predictors. As an example, we just add age into the model.

str(heart_attack)
## 'data.frame':    148 obs. of  8 variables:
##  $ Patient  : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ DIAGNOSIS: int  41041 41041 41091 41081 41091 41091 41091 41091 41041 41041 ...
##  $ SEX      : chr  "F" "F" "F" "F" ...
##  $ DRG      : int  122 122 122 122 122 121 121 121 121 123 ...
##  $ DIED     : int  0 0 0 0 0 0 0 0 0 1 ...
##  $ CHARGES  : num  4752 3941 3657 1481 1681 ...
##  $ LOS      : int  10 6 5 2 1 9 15 15 2 1 ...
##  $ AGE      : int  79 34 76 80 55 84 84 70 76 65 ...
reg(y=heart_attack$CHARGES, x=heart_attack[, c(7, 8)])
##                 [,1]
## Intercept 7280.55493
## LOS        259.67361
## AGE        -43.67677
# and compare the result to lm()
fit2<-lm(CHARGES ~ LOS+AGE, data=heart_attack); fit2
## 
## Call:
## lm(formula = CHARGES ~ LOS + AGE, data = heart_attack)
## 
## Coefficients:
## (Intercept)          LOS          AGE  
##     7280.55       259.67       -43.68

3 Case Study 1: Baseball Players

3.1 Step 1 - collecting data

We utilize the mlb data “01a_data.txt”. The dataset contains 1034 records of heights and weights for some current and recent Major League Baseball (MLB) Players. These data were obtained from different resources (e.g., IBM Many Eyes).

Variables:

  • Name: MLB Player Name
  • Team: The Baseball team the player was a member of at the time the data was acquired
  • Position: Player field position
  • Height: Player height in inches
  • Weight: Player weight in pounds
  • Age: Player age at time of record.

3.2 Step 2 - exploring and preparing the data

Let’s load this dataset first. We use as.is=T to keep non-numerical columns as character vectors rather than converting them to factors. Also, we delete the Name variable because we don’t need players’ names in this case study.

mlb<- read.table('https://umich.instructure.com/files/330381/download?download_frd=1', as.is=T, header=T)
str(mlb)
## 'data.frame':    1034 obs. of  6 variables:
##  $ Name    : chr  "Adam_Donachie" "Paul_Bako" "Ramon_Hernandez" "Kevin_Millar" ...
##  $ Team    : chr  "BAL" "BAL" "BAL" "BAL" ...
##  $ Position: chr  "Catcher" "Catcher" "Catcher" "First_Baseman" ...
##  $ Height  : int  74 74 72 72 73 69 69 71 76 71 ...
##  $ Weight  : int  180 215 210 210 188 176 209 200 231 180 ...
##  $ Age     : num  23 34.7 30.8 35.4 35.7 ...
mlb<-mlb[, -1]

By looking at the str() output, we notice that the variables Team and Position are stored as characters. To fix this, we can use the function as.factor(), which converts numerical or character vectors to factors.

mlb$Team<-as.factor(mlb$Team)
mlb$Position<-as.factor(mlb$Position)

The data is good to go. Let’s explore it using some summary statistics and plots.

summary(mlb$Weight)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   150.0   187.0   200.0   201.7   215.0   290.0
hist(mlb$Weight, main = "Histogram for Weights")

plot_ly(x = mlb$Weight, type = "histogram", name= "Histogram for Weights") %>%
  layout(title="Baseball Players' Weight Histogram", bargap=0.1,
         xaxis=list(title="Weight"),  # x-axis title
         yaxis = list(title="Frequency"))

The above plot illustrates our dependent variable Weight. As we learned from Chapter 2, we know this is somewhat right-skewed.

Applying GGpairs to obtain a compact dataset summary we can mark heavy weight and light weight players (according to \(light \lt median \lt heavy\)) by different colors in the plot:

# require(GGally)
# mlb_binary = mlb
# mlb_binary$bi_weight = as.factor(ifelse(mlb_binary$Weight>median(mlb_binary$Weight),1,0))
# g_weight <- ggpairs(data=mlb_binary[-1], title="MLB Light/Heavy Weights",
#            mapping=ggplot2::aes(colour = bi_weight),
#            lower=list(combo=wrap("facethist",binwidth=1)),
#            # upper = list(continuous = wrap("cor", size = 4.75, alignPercent = 1))
#            )
# g_weight

plot_ly(mlb) %>%
    add_trace(type = 'splom', dimensions = list( list(label='Position', values=~Position), 
    list(label='Height', values=~Height), list(label='Weight', values=~Weight), 
    list(label='Age', values=~Age), list(label='Team', values=~Team)),
        text=~Team,
        marker = list(color = as.integer(mlb$Team),
            size = 7, line = list(width = 1, color = 'rgb(230,230,230)')
        )
    ) %>%
    layout(title= 'MLB Pairs Plot', hovermode='closest', dragmode= 'select',
        plot_bgcolor='rgba(240,240,240, 0.95)')
# We may also mark player positions by different colors in the ggpairs plot
# g_position <- ggpairs(data=mlb[-1], title="MLB by Position",
#               mapping=ggplot2::aes(colour = Position),
#               lower=list(combo=wrap("facethist",binwidth=1)))
# g_position

What about our potential predictors?

table(mlb$Team)
## 
## ANA ARZ ATL BAL BOS CHC CIN CLE COL CWS DET FLA HOU  KC  LA MIN MLW NYM NYY OAK 
##  35  28  37  35  36  36  36  35  35  33  37  32  34  35  33  33  35  38  32  37 
## PHI PIT  SD SEA  SF STL  TB TEX TOR WAS 
##  36  35  33  34  34  32  33  35  34  36
table(mlb$Position)
## 
##           Catcher Designated_Hitter     First_Baseman        Outfielder 
##                76                18                55               194 
##    Relief_Pitcher    Second_Baseman         Shortstop  Starting_Pitcher 
##               315                58                52               221 
##     Third_Baseman 
##                45
summary(mlb$Height)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##    67.0    72.0    74.0    73.7    75.0    83.0
summary(mlb$Age)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   20.90   25.44   27.93   28.74   31.23   48.52

Here we have two numerical predictors, two categorical predictors and \(1034\) observations. Let’s see how R treats these different types of variables.

3.3 Exploring relationships among features - the correlation matrix

Before fitting a model, let’s examine the independence of our potential predictors and the dependent variable. Multiple linear regression assumes that the predictors are not strongly correlated with each other. Is this assumption valid? As we mentioned earlier, the cor() function can answer this question in a pairwise manner. Note that we only look at numerical variables.

cor(mlb[c("Weight", "Height", "Age")])
##           Weight      Height         Age
## Weight 1.0000000  0.53031802  0.15784706
## Height 0.5303180  1.00000000 -0.07367013
## Age    0.1578471 -0.07367013  1.00000000

Here we can see \(cor(y, x)=cor(x, y)\) and \(cor(x, x)=1\). Also, our Height variable is weakly (negatively) related to the players’ age. This looks very good and should not cause any multi-collinearity problem. If two of our predictors are highly correlated, they both provide almost the same information, which could cause multi-collinearity. A common practice is to delete one of them from the model.

In general multivariate regression analysis, we can use the variance inflation factors (VIFs) to detect potential multicollinearity between many covariates. The variance inflation factor (VIF) quantifies the amount of artificial inflation of the variance due to multicollinearity of the covariates. The \(VIF_l\)’s represent the expected inflation of the corresponding estimated variances. In a simple linear regression model with a single predictor \(x_l\), \(y_i=\beta_0 + \beta_l x_{i,l}+\epsilon_i,\) relative to the baseline variance, \(\sigma^2\), the lower bound (min) of the variance of the estimated effect-size, \(\beta_l\), is:

\[Var(\beta_l)_{min}=\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}.\]

This allows us to track the inflation of the \(\beta_l\) variance in the presence of correlated predictors in the regression model. Suppose the linear model includes \(k\) covariates with some of them multicollinear/correlated:

\[y_i=\beta_0+\beta_1x_{i,1} + \beta_2 x_{i,2} + \cdots + \underbrace{\beta_l x_{i,l}} + \cdots + \beta_k x_{i,k} +\epsilon_i.\]

Assume some of the predictors are correlated with the feature \({x_l}\), then the variance of its effect, \({\beta_l}\), will be inflated as follows:

\[Var(\beta_l)=\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}\times \frac{1}{1-R_l^2},\]

where \(R_l^2\) is the \(R^2\)-value computed by regressing the \(l^{th}\) feature on the remaining \((k-1)\) predictors. The stronger the linear dependence between the \(l^{th}\) feature and the remaining predictors, the larger the corresponding \(R_l^2\) value will be, and the smaller the denominator in the inflation factor (\(VIF_l\)) and the larger the variance estimate of \(\beta_l\).

The variance inflation factor (\(VIF_l\)) is the ratio of the two variances:

\[VIF_l=\frac{Var(\beta_l)}{Var(\beta_l)_{min}}= \frac{\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}\times \frac{1}{1-R_l^2}}{\frac{\sigma^2}{\sum_{i=1}^n{\left ( x_{i,l}-\bar{x}_l\right )^2}}}=\frac{1}{1-R_l^2}.\]

The regression model’s VIFs measure how much the variance of the estimated regression coefficients, \(\beta_l\), may be “inflated” by the presence of multicollinearity among the model predictor features. \(VIF_l\sim 1\) implies that there is no multicollinearity involving the \(l^{th}\) predictor and the remaining features, and hence the variance estimate of \(\beta_l\) is not inflated. On the other hand, \(VIF_l > 4\) suggests potential multicollinearity, and \(VIF_l> 10\) suggests serious multicollinearity that may require some model correction, as the variance estimates may be substantially inflated.

We can use the function car::vif() to compute and report the VIF factors:

car::vif(lm(Weight ~ Height + Age, data=mlb))
##   Height      Age 
## 1.005457 1.005457
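
The definition \(VIF_l=\frac{1}{1-R_l^2}\) can also be evaluated by hand, by regressing each predictor on the remaining ones. The following is a minimal sketch on simulated predictors; the data, the seed, and the helper name manual_vif() are assumptions made for illustration, and the helper handles numeric predictors only.

```r
set.seed(3)                          # simulated predictors, for illustration
n  <- 200
x1 <- rnorm(n)
x2 <- 0.8 * x1 + rnorm(n, sd = 0.6)  # deliberately correlated with x1
x3 <- rnorm(n)                       # unrelated to x1 and x2
X  <- data.frame(x1, x2, x3)

# Hypothetical helper: VIF_l = 1 / (1 - R_l^2), where R_l^2 comes from
# regressing the l-th predictor on all remaining predictors
manual_vif <- function(df) {
  sapply(names(df), function(v) {
    f  <- reformulate(setdiff(names(df), v), response = v)
    r2 <- summary(lm(f, data = df))$r.squared
    1 / (1 - r2)
  })
}

manual_vif(X)   # x1 and x2 should be clearly inflated, x3 close to 1
```

With only two predictors, \(R_l^2\) is just the squared pairwise correlation. For the MLB model this gives \(1/(1-(-0.07367)^2)\approx 1.0055\), in agreement with the car::vif() output above.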

3.4 Multi-colinearity and feature-selection in high-dimensional data

In Chapters 16 (Feature Selection) and 17 (Controlled variable Selection), we will discuss the methods and computational strategies to identify salient features in high-dimensional datasets. Let’s briefly identify some practical approaches to address multi-colinearity problems and challenges related to large numbers of inter-dependencies in the data. Data that contains a large number of predictors is likely to include completely unrelated variables having high sample correlation. To see the nature of this problem, assume we are generating a random Gaussian \(n\times k\) matrix, \(X=(X_1,X_2, \cdots, X_k)\), of \(k\) feature vectors, \(X_i, 1\leq i\leq k\), using IID standard normal random samples. Then, the expected maximum sample correlation between any pair of columns, \(\rho(X_{i_1},X_{i_2})\), can be large when \(k\gg n\). Even in this IID sampling problem, we still expect a high rate of intrinsic and strong feature correlations. In general, this phenomenon is amplified for high-dimensional observational data, which would be expected to have a high degree of colinearity. This problem presents a number of challenges: computational (e.g., function singularities and negative definite Hessian matrices), model-fitting, model-interpretation, and selection of salient predictors.
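
The spurious-correlation effect described above is easy to see in a small simulation. This is a sketch under stated assumptions: the sample size, the values of \(k\), the seed, and the helper name max_abs_cor() are all chosen for illustration, and the exact numbers will change with the seed.

```r
set.seed(4)   # simulation for illustration; exact values depend on the seed

# Largest absolute off-diagonal sample correlation among k IID N(0,1) columns
max_abs_cor <- function(n, k) {
  X <- matrix(rnorm(n * k), nrow = n, ncol = k)
  R <- cor(X)
  max(abs(R[upper.tri(R)]))
}

n   <- 20
ks  <- c(5, 50, 500)
res <- sapply(ks, function(k) max_abs_cor(n, k))
names(res) <- paste0("k=", ks)
res
```

Although every column is generated independently, the largest observed pairwise correlation typically grows as \(k\) increases relative to \(n\), simply because there are many more pairs to choose from.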

There are some techniques that allow us to resolve such multi-colinearity issues in high-dimensional data. Let’s denote \(n\) to be the number of cases (samples, subjects, units, etc.) and \(k\) be the number of features. Using divide-and-conquer, we can split the problem into two special cases:

  • When \(n \geq k\), we can use the VIF, to solve the problem parametrically,
  • When \(n \ll k\), VIF does not apply and we need to be more creative. In this case, some solutions include:
    • Use a dimensionality-reduction (PCA, ICA, FA, SVD, PLSR, t-SNE) to reduce the problem to \(n \geq k'\) (using only the top \(k'\) bases, functions, or directions).
    • Compute the (Spearman’s rank-order based) pair correlation (matrix) and do some kind of feature selection, e.g., choosing only features with lower paired-correlations.
    • The Sure Independence Screening (SIS) technique is based on correlation learning utilizing the sample correlation between a response and a given predictor. SIS reduces the feature-dimension (\(k\)) to a moderate dimension \(O(n)\).
    • The basic SIS method estimates marginal linear correlations between predictor and responses, which can be done by fitting a simple linear model. Non-parametric Independence Screening (NIS) expands this model-based SIS strategy to use non-parametric models and allow more flexibility for the predictor ranking. Models’ diagnostics for predictor-ranking may use the magnitude of the marginal estimators, non-parametric marginal-correlations, or marginal residual sum of squares.
    • Generalized Correlation screening employs an empirical sample-driven estimate of a generalized correlation to rank the individual predictors.
    • Forward Regression using best subset regression is computationally very expensive because of the large combinatoric space, as the utility of each predictor depends on many other predictors. It generates a nested sequence of models, each having one more predictor than the prior model. The model expansion adds new variables to the model based on how much they improve the model quality, e.g., the largest decrease of the residual sum of squares, compared to the prior model.
    • Model-Free Screening strategy basically uses empirical estimates for conditional densities of the response given the predictors. Most methods have consistency in ranking (CIR) property, which ensures that the objective utility function ranks unimportant predictors lower than important predictors with high probability (\(p \rightarrow 1\)).
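
To make the screening idea concrete, here is a deliberately simplified sketch of the basic SIS step from the list above: rank the predictors by absolute marginal correlation with the response and keep the top \(d\). The simulated data, the helper name sis_screen(), and the default \(d=\lfloor n/\log(n)\rfloor\) are assumptions made for illustration; this is not a full implementation of any published SIS procedure.

```r
set.seed(5)                       # simulated data, for illustration only
n <- 100; k <- 1000               # far more features than cases
X <- matrix(rnorm(n * k), n, k, dimnames = list(NULL, paste0("x", 1:k)))
y <- 3 * X[, 1] - 3 * X[, 2] + rnorm(n)   # only x1 and x2 are truly relevant

# Hypothetical helper: keep the d predictors with the largest absolute
# marginal correlation with the response
sis_screen <- function(X, y, d = floor(nrow(X) / log(nrow(X)))) {
  marginal <- abs(drop(cor(X, y)))          # one correlation per column of X
  names(sort(marginal, decreasing = TRUE))[seq_len(d)]
}

kept <- sis_screen(X, y)
length(kept)               # reduced dimension d = floor(100 / log(100)) = 21
c("x1", "x2") %in% kept    # were the two true signals retained?
```

After screening, the problem has \(d < n\) candidate features, so model-based tools such as lm() and the VIF become applicable again.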

3.5 Visualizing relationships among features - the scatterplot matrix

To visualize pairwise correlations, we could use a scatterplot matrix, e.g., via pairs(), ggpairs(), or plot_ly().

# pairs(mlb[c("Weight", "Height", "Age")])
plot_ly(mlb) %>%
    add_trace(type = 'splom', dimensions = list( list(label='Height', values=~Height),
              list(label='Weight', values=~Weight), list(label='Age', values=~Age)),
        text=~Position,
        marker = list(color = as.integer(mlb$Team),
            size = 7, line = list(width = 1, color = 'rgb(230,230,230)')
        )
    ) %>%
    layout(title= 'MLB Pairs Plot', hovermode='closest', dragmode= 'select',
        plot_bgcolor='rgba(240,240,240, 0.95)')

You might get a sense of the relationships, but it is difficult to see any linear pattern. We can make a more sophisticated graph using pairs.panels() in the psych package.

# install.packages("psych")
library(psych)
## 
## Attaching package: 'psych'
## The following objects are masked from 'package:ggplot2':
## 
##     %+%, alpha
pairs.panels(mlb[, c("Weight", "Height", "Age")])

This plot gives us much more information about the three variables. Above the diagonal, we have our correlation coefficients in numerical form. On the diagonal, there are histograms of the variables. Below the diagonal, more visual information is presented to help us understand the trend. This specific graph shows us that height and weight are positively and strongly correlated. Also, the relationships between age and height, and between age and weight, are very weak (a nearly horizontal red line in the below-diagonal graphs indicates a weak relationship).

3.6 Step 3 - training a model on the data

The function we are going to use for this section is lm(). No extra package is needed when using this function.

The lm() function has the following components:

m<-lm(dv ~ iv, data=mydata)

  • dv: dependent variable
  • iv: independent variables. Just like OneR() in Chapter 8. If we use . as iv, all of the variables, except the dependent variable (\(dv\)), are included as predictors.
  • data: specifies the data containing both the dependent variable and the independent variables

fit <- lm(Weight ~ ., data=mlb)
fit
## 
## Call:
## lm(formula = Weight ~ ., data = mlb)
## 
## Coefficients:
##               (Intercept)                    TeamARZ  
##                 -164.9995                     7.1881  
##                   TeamATL                    TeamBAL  
##                   -1.5631                    -5.3128  
##                   TeamBOS                    TeamCHC  
##                   -0.2838                     0.4026  
##                   TeamCIN                    TeamCLE  
##                    2.1051                    -1.3160  
##                   TeamCOL                    TeamCWS  
##                   -3.7836                     4.2944  
##                   TeamDET                    TeamFLA  
##                    2.3024                     2.6985  
##                   TeamHOU                     TeamKC  
##                   -0.6808                    -4.7664  
##                    TeamLA                    TeamMIN  
##                    2.8598                     2.1269  
##                   TeamMLW                    TeamNYM  
##                    4.2897                    -1.9736  
##                   TeamNYY                    TeamOAK  
##                    1.7483                    -0.5464  
##                   TeamPHI                    TeamPIT  
##                   -6.8486                     4.3023  
##                    TeamSD                    TeamSEA  
##                    2.6133                    -0.9147  
##                    TeamSF                    TeamSTL  
##                    0.8411                    -1.1341  
##                    TeamTB                    TeamTEX  
##                   -2.6616                    -0.7695  
##                   TeamTOR                    TeamWAS  
##                    1.3943                    -1.7555  
## PositionDesignated_Hitter      PositionFirst_Baseman  
##                    8.9037                     2.4237  
##        PositionOutfielder     PositionRelief_Pitcher  
##                   -6.2636                    -7.7695  
##    PositionSecond_Baseman          PositionShortstop  
##                  -13.0843                   -16.9562  
##  PositionStarting_Pitcher      PositionThird_Baseman  
##                   -7.3599                    -4.6035  
##                    Height                        Age  
##                    4.7175                     0.8906

As we can see from the output, factors are included in the model by creating several indicators, with one coefficient for each factor level except the reference level of each factor. Each numerical variable has just one coefficient.
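
The indicator (dummy) coding that lm() applies to factors can be inspected directly with model.matrix(). The toy data frame below uses made-up values and is only meant to illustrate the mechanism.

```r
# Toy data frame (hypothetical values) with one factor and one numeric predictor
toy <- data.frame(
  Position = factor(c("Catcher", "Outfielder", "Shortstop", "Catcher")),
  Height   = c(74, 72, 71, 73)
)

levels(toy$Position)[1]   # reference level: the first level, "Catcher"

# Design matrix: an intercept, one 0/1 indicator per non-reference level,
# and a single column for the numeric predictor
model.matrix(~ Position + Height, data = toy)
```

This matches the MLB output above, where the first levels alphabetically (team ANA and position Catcher) have no coefficient of their own and act as the reference levels. The base function relevel() can be used to choose a different reference level.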

3.7 Step 4 - evaluating model performance

As we did in previous case studies, let’s examine model performance.

summary(fit)
## 
## Call:
## lm(formula = Weight ~ ., data = mlb)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -48.692 -10.909  -0.778   9.858  73.649 
## 
## Coefficients:
##                            Estimate Std. Error t value Pr(>|t|)    
## (Intercept)               -164.9995    19.3828  -8.513  < 2e-16 ***
## TeamARZ                      7.1881     4.2590   1.688 0.091777 .  
## TeamATL                     -1.5631     3.9757  -0.393 0.694278    
## TeamBAL                     -5.3128     4.0193  -1.322 0.186533    
## TeamBOS                     -0.2838     4.0034  -0.071 0.943492    
## TeamCHC                      0.4026     3.9949   0.101 0.919749    
## TeamCIN                      2.1051     3.9934   0.527 0.598211    
## TeamCLE                     -1.3160     4.0356  -0.326 0.744423    
## TeamCOL                     -3.7836     4.0287  -0.939 0.347881    
## TeamCWS                      4.2944     4.1022   1.047 0.295413    
## TeamDET                      2.3024     3.9725   0.580 0.562326    
## TeamFLA                      2.6985     4.1336   0.653 0.514028    
## TeamHOU                     -0.6808     4.0634  -0.168 0.866976    
## TeamKC                      -4.7664     4.0242  -1.184 0.236525    
## TeamLA                       2.8598     4.0817   0.701 0.483686    
## TeamMIN                      2.1269     4.0947   0.519 0.603579    
## TeamMLW                      4.2897     4.0243   1.066 0.286706    
## TeamNYM                     -1.9736     3.9493  -0.500 0.617370    
## TeamNYY                      1.7483     4.1234   0.424 0.671655    
## TeamOAK                     -0.5464     3.9672  -0.138 0.890474    
## TeamPHI                     -6.8486     3.9949  -1.714 0.086778 .  
## TeamPIT                      4.3023     4.0210   1.070 0.284890    
## TeamSD                       2.6133     4.0915   0.639 0.523148    
## TeamSEA                     -0.9147     4.0516  -0.226 0.821436    
## TeamSF                       0.8411     4.0520   0.208 0.835593    
## TeamSTL                     -1.1341     4.1193  -0.275 0.783132    
## TeamTB                      -2.6616     4.0944  -0.650 0.515798    
## TeamTEX                     -0.7695     4.0283  -0.191 0.848556    
## TeamTOR                      1.3943     4.0681   0.343 0.731871    
## TeamWAS                     -1.7555     4.0038  -0.438 0.661142    
## PositionDesignated_Hitter    8.9037     4.4533   1.999 0.045842 *  
## PositionFirst_Baseman        2.4237     3.0058   0.806 0.420236    
## PositionOutfielder          -6.2636     2.2784  -2.749 0.006084 ** 
## PositionRelief_Pitcher      -7.7695     2.1959  -3.538 0.000421 ***
## PositionSecond_Baseman     -13.0843     2.9638  -4.415 1.12e-05 ***
## PositionShortstop          -16.9562     3.0406  -5.577 3.16e-08 ***
## PositionStarting_Pitcher    -7.3599     2.2976  -3.203 0.001402 ** 
## PositionThird_Baseman       -4.6035     3.1689  -1.453 0.146613    
## Height                       4.7175     0.2563  18.405  < 2e-16 ***
## Age                          0.8906     0.1259   7.075 2.82e-12 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 16.78 on 994 degrees of freedom
## Multiple R-squared:  0.3858, Adjusted R-squared:  0.3617 
## F-statistic: 16.01 on 39 and 994 DF,  p-value: < 2.2e-16
#plot(fit, which = 1:2)

plot_ly(x=fit$fitted.values, y=fit$residuals, type="scatter", mode="markers") %>%
  layout(title="LM: Fitted-values vs. Model-Residuals",
         xaxis=list(title="Fitted"), 
         yaxis = list(title="Residuals"))

The summary shows how well the model fits the dataset.

Residuals: This part summarizes the distribution of the residuals. If some observations have extremely large or extremely small residuals compared to the rest, either they are outliers (e.g., due to reporting errors) or the model fits the data poorly. Here, the maximum residual is \(73.649\) and the minimum is \(-48.692\). Their extremeness can be examined using the residual diagnostic plot.

Coefficients: In this section, we look at the rightmost column, which contains the significance codes. Stars or dots next to a variable indicate that its coefficient is significantly different from zero at the corresponding level, which supports keeping the variable in the model. If nothing appears next to a variable, the data are compatible with its coefficient being 0 in the linear model. We can also look at the Pr(>|t|) column directly. A value close to 0 in this column indicates that the row variable is significant; otherwise, it is a candidate for removal from the model.

Here, most of the team indicators and a few of the position indicators are not significant, whereas Age and Height are highly significant.
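
The t value and Pr(>|t|) columns can be reproduced by hand from the Estimate and Std. Error columns. The following is a minimal sketch that uses the rounded Age row and the 994 residual degrees of freedom reported in the summary, so it agrees with the printed values only up to rounding.

```r
# Rounded values copied from the Age row of the summary table above
estimate <- 0.8906
std_err  <- 0.1259
df_resid <- 994        # residual degrees of freedom reported by summary(fit)

# t statistic for the null hypothesis that the coefficient equals 0
t_value <- estimate / std_err

# two-sided p-value from the t distribution
p_value <- 2 * pt(-abs(t_value), df = df_resid)

t_value            # close to the reported 7.075
p_value < 0.001    # TRUE, which corresponds to the '***' significance code
```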

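R-squared: The proportion of the variation in y that is explained by the included predictors. Here we have 38.58%, which indicates that the model captures part of the signal but leaves room for improvement. As a rough rule of thumb, a well-fitted linear regression would often have an \(R^2\) over 70%, although the appropriate benchmark depends on the application.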
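
The adjusted \(R^2\) and the overall F-statistic in the summary are both functions of the multiple \(R^2\) and the degrees of freedom. As a sketch, the reported values can be recovered from the printed (rounded) numbers; the variable names below are illustrative only.

```r
# Values read off the model summary above
r2       <- 0.3858   # Multiple R-squared
df_model <- 39       # numerator degrees of freedom of the F-statistic
df_resid <- 994      # residual degrees of freedom
n        <- df_model + df_resid + 1   # number of observations (1034)

# Adjusted R-squared penalizes R-squared for the number of estimated terms
adj_r2 <- 1 - (1 - r2) * (n - 1) / df_resid

# Overall F-statistic expressed through R-squared
f_stat <- (r2 / df_model) / ((1 - r2) / df_resid)

round(adj_r2, 4)   # 0.3617
round(f_stat, 2)   # 16.01
```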

The diagnostic plots also help us understand the situation.

Residual vs Fitted: This is the residual diagnostic plot. We can see that the residuals of observations indexed \(65\), \(160\) and \(237\) are relatively far apart from the rest. They are potential influential points or outliers.

Normal Q-Q: This plot examines the normality assumption of the model. If the dots follow the reference line on the graph, the normality assumption is plausible. In our case, the points are relatively close to the line, so the normality assumption appears reasonable for this model.
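
The quantities behind a normal Q-Q plot can also be inspected numerically. The sketch below uses simulated stand-in data (an assumption; it is not the MLB dataset) and the base functions qqnorm() and shapiro.test() to illustrate what "following the line" means.

```r
set.seed(1234)
# Simulated stand-in data (hypothetical; not the MLB dataset)
x <- runif(200, 60, 80)
y <- -150 + 4.7 * x + rnorm(200, sd = 15)
sim_fit <- lm(y ~ x)

# Theoretical normal quantiles (x) paired with the residuals (y), no plotting
qq <- qqnorm(residuals(sim_fit), plot.it = FALSE)

# For normally distributed residuals the points fall near a straight line,
# so the correlation between the paired quantiles is close to 1
qq_cor <- cor(qq$x, qq$y)
qq_cor

# Shapiro-Wilk test as a complementary numeric check of normality
shapiro.test(residuals(sim_fit))$p.value
```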

3.8 Step 5 - improving model performance

We can employ the step function to perform forward or backward selection of important features/predictors. It works for both lm and glm models. Backward selection is often preferred because it starts from the full model, so each predictor is assessed in the presence of all the others, and it tends to retain larger models. There are various criteria to evaluate a model; commonly used ones include AIC, BIC, and adjusted \(R^2\). Let’s compare the backward and forward model selection approaches. The step function argument direction controls the search mode: "backward" only drops terms, "forward" only adds terms, and "both" considers dropping and adding terms at every step. When no scope argument is supplied, the default direction is backward and no terms beyond those in the starting model can be added. Later, in Chapter 16 and Chapter 17, we will present details about alternative feature selection approaches.

step(fit,direction = "backward")
## Start:  AIC=5871.04
## Weight ~ Team + Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## - Team     29      9468 289262 5847.4
## <none>                  279793 5871.0
## - Age       1     14090 293883 5919.8
## - Position  8     20301 300095 5927.5
## - Height    1     95356 375149 6172.3
## 
## Step:  AIC=5847.45
## Weight ~ Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## <none>                  289262 5847.4
## - Age       1     14616 303877 5896.4
## - Position  8     20406 309668 5901.9
## - Height    1    100435 389697 6153.6
## 
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
## 
## Coefficients:
##               (Intercept)  PositionDesignated_Hitter  
##                 -168.0474                     8.6968  
##     PositionFirst_Baseman         PositionOutfielder  
##                    2.7780                    -6.0457  
##    PositionRelief_Pitcher     PositionSecond_Baseman  
##                   -7.7782                   -13.0267  
##         PositionShortstop   PositionStarting_Pitcher  
##                  -16.4821                    -7.3961  
##     PositionThird_Baseman                     Height  
##                   -4.1361                     4.7639  
##                       Age  
##                    0.8771
step(fit,direction = "forward")
## Start:  AIC=5871.04
## Weight ~ Team + Position + Height + Age
## 
## Call:
## lm(formula = Weight ~ Team + Position + Height + Age, data = mlb)
## 
## Coefficients:
##               (Intercept)                    TeamARZ  
##                 -164.9995                     7.1881  
##                   TeamATL                    TeamBAL  
##                   -1.5631                    -5.3128  
##                   TeamBOS                    TeamCHC  
##                   -0.2838                     0.4026  
##                   TeamCIN                    TeamCLE  
##                    2.1051                    -1.3160  
##                   TeamCOL                    TeamCWS  
##                   -3.7836                     4.2944  
##                   TeamDET                    TeamFLA  
##                    2.3024                     2.6985  
##                   TeamHOU                     TeamKC  
##                   -0.6808                    -4.7664  
##                    TeamLA                    TeamMIN  
##                    2.8598                     2.1269  
##                   TeamMLW                    TeamNYM  
##                    4.2897                    -1.9736  
##                   TeamNYY                    TeamOAK  
##                    1.7483                    -0.5464  
##                   TeamPHI                    TeamPIT  
##                   -6.8486                     4.3023  
##                    TeamSD                    TeamSEA  
##                    2.6133                    -0.9147  
##                    TeamSF                    TeamSTL  
##                    0.8411                    -1.1341  
##                    TeamTB                    TeamTEX  
##                   -2.6616                    -0.7695  
##                   TeamTOR                    TeamWAS  
##                    1.3943                    -1.7555  
## PositionDesignated_Hitter      PositionFirst_Baseman  
##                    8.9037                     2.4237  
##        PositionOutfielder     PositionRelief_Pitcher  
##                   -6.2636                    -7.7695  
##    PositionSecond_Baseman          PositionShortstop  
##                  -13.0843                   -16.9562  
##  PositionStarting_Pitcher      PositionThird_Baseman  
##                   -7.3599                    -4.6035  
##                    Height                        Age  
##                    4.7175                     0.8906
step(fit,direction = "both")
## Start:  AIC=5871.04
## Weight ~ Team + Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## - Team     29      9468 289262 5847.4
## <none>                  279793 5871.0
## - Age       1     14090 293883 5919.8
## - Position  8     20301 300095 5927.5
## - Height    1     95356 375149 6172.3
## 
## Step:  AIC=5847.45
## Weight ~ Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## <none>                  289262 5847.4
## + Team     29      9468 279793 5871.0
## - Age       1     14616 303877 5896.4
## - Position  8     20406 309668 5901.9
## - Height    1    100435 389697 6153.6
## 
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
## 
## Coefficients:
##               (Intercept)  PositionDesignated_Hitter  
##                 -168.0474                     8.6968  
##     PositionFirst_Baseman         PositionOutfielder  
##                    2.7780                    -6.0457  
##    PositionRelief_Pitcher     PositionSecond_Baseman  
##                   -7.7782                   -13.0267  
##         PositionShortstop   PositionStarting_Pitcher  
##                  -16.4821                    -7.3961  
##     PositionThird_Baseman                     Height  
##                   -4.1361                     4.7639  
##                       Age  
##                    0.8771

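We can observe that forward selection retains the whole model. This is expected: the search starts from the full model and, without a scope argument, there are no additional terms that could be added. Backward step-wise selection (and the "both" search, which follows the same path here) drops Team and yields a more parsimonious model with a lower AIC (5847.45 vs. 5871.04).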
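
A forward search only has something to add when it starts from a smaller model and is given a scope of candidate terms. The following is a minimal sketch, assuming the built-in mtcars data as a stand-in for mlb; the variables and model names are illustrative and not part of the case study.

```r
# Built-in mtcars data used as a stand-in (assumption; not the MLB data)
null_model <- lm(mpg ~ 1, data = mtcars)                      # intercept only
full_model <- lm(mpg ~ wt + hp + qsec + drat, data = mtcars)  # all candidates

# Forward search: start from the null model, allow terms up to the full model
fwd <- step(null_model,
            scope = list(lower = ~ 1, upper = ~ wt + hp + qsec + drat),
            direction = "forward", trace = 0)

# Backward search: start from the full model and drop terms
bwd <- step(full_model, direction = "backward", trace = 0)

formula(fwd)
formula(bwd)
```

Because both searches are greedy, the two formulas need not coincide in general.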

Both backward and forward feature selection methods utilize greedy algorithms and do not guarantee an optimal model selection result. Identifying the best feature subset requires exploring every possible combination of the predictors, which is practically infeasible for large numbers of predictors due to the computational complexity of examining \(n \choose k\) combinations of \(k\) out of \(n\) features for every subset size \(k\).
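
To get a feel for the size of the exhaustive search, we can count the candidate models. The sketch below treats the four model terms (Team, Position, Height, Age) as the candidate features; summing the binomial coefficients over all subset sizes gives \(2^p\) models.

```r
p <- 4                              # Team, Position, Height, Age
subsets_by_size <- choose(p, 0:p)   # number of models with k = 0, ..., p terms
subsets_by_size                     # 1 4 6 4 1
sum(subsets_by_size)                # 16, i.e., 2^p candidate models

# The count doubles with every additional candidate feature
2^c(10, 20, 30)
```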

Alternatively, we can choose models based on various information criteria.

step(fit,k=2)
## Start:  AIC=5871.04
## Weight ~ Team + Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## - Team     29      9468 289262 5847.4
## <none>                  279793 5871.0
## - Age       1     14090 293883 5919.8
## - Position  8     20301 300095 5927.5
## - Height    1     95356 375149 6172.3
## 
## Step:  AIC=5847.45
## Weight ~ Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## <none>                  289262 5847.4
## - Age       1     14616 303877 5896.4
## - Position  8     20406 309668 5901.9
## - Height    1    100435 389697 6153.6
## 
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
## 
## Coefficients:
##               (Intercept)  PositionDesignated_Hitter  
##                 -168.0474                     8.6968  
##     PositionFirst_Baseman         PositionOutfielder  
##                    2.7780                    -6.0457  
##    PositionRelief_Pitcher     PositionSecond_Baseman  
##                   -7.7782                   -13.0267  
##         PositionShortstop   PositionStarting_Pitcher  
##                  -16.4821                    -7.3961  
##     PositionThird_Baseman                     Height  
##                   -4.1361                     4.7639  
##                       Age  
##                    0.8771
step(fit,k=log(nrow(mlb)))
## Start:  AIC=6068.69
## Weight ~ Team + Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## - Team     29      9468 289262 5901.8
## <none>                  279793 6068.7
## - Position  8     20301 300095 6085.6
## - Age       1     14090 293883 6112.5
## - Height    1     95356 375149 6365.0
## 
## Step:  AIC=5901.8
## Weight ~ Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## <none>                  289262 5901.8
## - Position  8     20406 309668 5916.8
## - Age       1     14616 303877 5945.8
## - Height    1    100435 389697 6203.0
## 
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
## 
## Coefficients:
##               (Intercept)  PositionDesignated_Hitter  
##                 -168.0474                     8.6968  
##     PositionFirst_Baseman         PositionOutfielder  
##                    2.7780                    -6.0457  
##    PositionRelief_Pitcher     PositionSecond_Baseman  
##                   -7.7782                   -13.0267  
##         PositionShortstop   PositionStarting_Pitcher  
##                  -16.4821                    -7.3961  
##     PositionThird_Baseman                     Height  
##                   -4.1361                     4.7639  
##                       Age  
##                    0.8771

\(k = 2\) yields the genuine AIC criterion, and \(k = \log(n)\) corresponds to BIC (the step output still labels the column AIC in both cases). Let’s try to evaluate model performance again.
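
For lm fits, the value reported by step is \(n\log(RSS/n) + k \times edf\), where \(edf\) is the number of estimated regression coefficients. As a sketch, the two starting values printed above can be recovered from the residual sum of squares of the full model; since the printed RSS is rounded, the agreement is approximate.

```r
# Values read off the step() output and the model summary above
n   <- 1034      # observations: 994 residual df + 40 estimated coefficients
rss <- 279793    # residual sum of squares of the full model (<none> row)
edf <- 40        # number of estimated regression coefficients

# Criterion used by step() for lm fits: n * log(RSS / n) + k * edf
crit <- function(k) n * log(rss / n) + k * edf

aic_full <- crit(2)        # close to the reported 5871.04 (AIC, k = 2)
bic_full <- crit(log(n))   # close to the reported 6068.69 (BIC, k = log(n))
c(AIC = aic_full, BIC = bic_full)
```

The BIC penalty per coefficient, \(\log(1034)\approx 6.94\), is more than three times the AIC penalty of 2, which is why BIC tends to favor smaller models.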

fit2 = step(fit,k=2,direction = "backward")
## Start:  AIC=5871.04
## Weight ~ Team + Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## - Team     29      9468 289262 5847.4
## <none>                  279793 5871.0
## - Age       1     14090 293883 5919.8
## - Position  8     20301 300095 5927.5
## - Height    1     95356 375149 6172.3
## 
## Step:  AIC=5847.45
## Weight ~ Position + Height + Age
## 
##            Df Sum of Sq    RSS    AIC
## <none>                  289262 5847.4
## - Age       1     14616 303877 5896.4
## - Position  8     20406 309668 5901.9
## - Height    1    100435 389697 6153.6
summary(fit2)
## 
## Call:
## lm(formula = Weight ~ Position + Height + Age, data = mlb)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -49.427 -10.855  -0.344  10.110  75.301 
## 
## Coefficients:
##                            Estimate Std. Error t value Pr(>|t|)    
## (Intercept)               -168.0474    19.0351  -8.828  < 2e-16 ***
## PositionDesignated_Hitter    8.6968     4.4258   1.965 0.049679 *  
## PositionFirst_Baseman        2.7780     2.9942   0.928 0.353741    
## PositionOutfielder          -6.0457     2.2778  -2.654 0.008072 ** 
## PositionRelief_Pitcher      -7.7782     2.1913  -3.550 0.000403 ***
## PositionSecond_Baseman     -13.0267     2.9531  -4.411 1.14e-05 ***
## PositionShortstop          -16.4821     3.0372  -5.427 7.16e-08 ***
## PositionStarting_Pitcher    -7.3961     2.2959  -3.221 0.001316 ** 
## PositionThird_Baseman       -4.1361     3.1656  -1.307 0.191647    
## Height                       4.7639     0.2528  18.847  < 2e-16 ***
## Age                          0.8771     0.1220   7.190 1.25e-12 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 16.82 on 1023 degrees of freedom
## Multiple R-squared:  0.365,  Adjusted R-squared:  0.3588 
## F-statistic: 58.81 on 10 and 1023 DF,  p-value: < 2.2e-16
#plot(fit2, which = 1:2)
plot_ly(x=fit2$fitted.values, y=fit2$residuals, type="scatter", mode="markers") %>%
  layout(title="LM: Fitted-values vs. Model-Residuals",
         xaxis=list(title="Fitted"), 
         yaxis = list(title="Residuals"))
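# compute the paired quantiles of the fitted values and the residuals (two-sample Q-Q)
QQ <- qqplot(fit2$fitted.values, fit2$residuals, plot.it=FALSE)
# indices of a random subsample that could be used to expedite the viz (not used below)
ind <-  sample(1:length(QQ$x), 1000, replace = FALSE)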
plot_ly() %>%
  add_markers(x=~QQ$x, y=~QQ$y, name="Quantiles Scatter", type="scatter", mode="markers") %>%
  add_trace(x = ~c(160,260), y = ~c(-50,80), type="scatter", mode="lines",
        line = list(color = "red", width = 4), name="Line", showlegend=F) %>%
  layout(title='Quantile plot',
           xaxis = list(title="Fitted"),
           yaxis = list(title="Residuals"),
           legend = list(orientation = 'h'))

Sometimes, simpler models are preferable even if there is a small loss of performance. In this case, we have a simpler model with \(R^2=0.365\) (compared to \(0.3858\) for the full model). The overall model is still highly significant. We can see that observations \(65\), \(160\) and \(237\) are relatively far from the other residuals. They are potential influential points or outliers.

Also, we can examine the leverage points. High-leverage points are observations with unusual predictor values; they may also be outliers, influential points, or both. In a regression model setting, observation leverage is the relative distance of the observation (data point) from the mean of the explanatory variable. Observations near the mean of the explanatory variable have low leverage and those far from the mean have high leverage. Yet, not all points of high leverage are necessarily influential.

# Half-normal plot for leverages
# install.packages("faraway")
library(faraway)
halfnorm(lm.influence(fit)$hat, nlab = 2, ylab="Leverages")

mlb[c(226,879),]
##     Team          Position Height Weight   Age
## 226  NYY Designated_Hitter     75    230 36.14
## 879   SD Designated_Hitter     73    200 25.60
summary(mlb)
##       Team                 Position       Height         Weight     
##  NYM    : 38   Relief_Pitcher  :315   Min.   :67.0   Min.   :150.0  
##  ATL    : 37   Starting_Pitcher:221   1st Qu.:72.0   1st Qu.:187.0  
##  DET    : 37   Outfielder      :194   Median :74.0   Median :200.0  
##  OAK    : 37   Catcher         : 76   Mean   :73.7   Mean   :201.7  
##  BOS    : 36   Second_Baseman  : 58   3rd Qu.:75.0   3rd Qu.:215.0  
##  CHC    : 36   First_Baseman   : 55   Max.   :83.0   Max.   :290.0  
##  (Other):813   (Other)         :115                                 
##       Age       
##  Min.   :20.90  
##  1st Qu.:25.44  
##  Median :27.93  
##  Mean   :28.74  
##  3rd Qu.:31.23  
##  Max.   :48.52  
## 
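
Leverages can also be extracted and screened numerically. The sketch below assumes the built-in mtcars data as a stand-in for mlb and uses the common \(2p/n\) cutoff, which is a rule of thumb rather than a formal test; Cook's distance is included to illustrate that influence depends on both leverage and residual size.

```r
# Built-in mtcars data used as a stand-in (assumption; not the MLB data)
m <- lm(mpg ~ wt + hp, data = mtcars)

h <- hatvalues(m)       # leverages = diagonal of the hat matrix
p <- length(coef(m))    # number of estimated coefficients
n <- nobs(m)            # number of observations

# The leverages always sum to the number of coefficients
sum(h)

# Rule-of-thumb cutoff (a convention, not a formal test): h > 2p/n
high_leverage <- names(h)[h > 2 * p / n]
high_leverage

# Cook's distance combines leverage and residual size to gauge influence
head(sort(cooks.distance(m), decreasing = TRUE), 3)
```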

A deeper discussion of variable selection, including control of the false discovery rate, is provided in Chapters 16 and 17.

3.8.1 Model specification - adding non-linear relationships

In linear regression, the relationship between independent and dependent variables is assumed to be linear. However, this might not be the case. The relationship between age and weight could be quadratic, since middle-aged people might gain weight dramatically.

mlb$age2<-(mlb$Age)^2
fit2<-lm(Weight ~ ., data=mlb)
summary(fit2)
## 
## Call:
## lm(formula = Weight ~ ., data = mlb)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -49.068 -10.775  -1.021   9.922  74.693 
## 
## Coefficients:
##                             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)               -209.07068   27.49529  -7.604 6.65e-14 ***
## TeamARZ                      7.41943    4.25154   1.745 0.081274 .  
## TeamATL                     -1.43167    3.96793  -0.361 0.718318    
## TeamBAL                     -5.38735    4.01119  -1.343 0.179552    
## TeamBOS                     -0.06614    3.99633  -0.017 0.986799    
## TeamCHC                      0.14541    3.98833   0.036 0.970923    
## TeamCIN                      2.24022    3.98571   0.562 0.574201    
## TeamCLE                     -1.07546    4.02870  -0.267 0.789563    
## TeamCOL                     -3.87254    4.02069  -0.963 0.335705    
## TeamCWS                      4.20933    4.09393   1.028 0.304111    
## TeamDET                      2.66990    3.96769   0.673 0.501160    
## TeamFLA                      3.14627    4.12989   0.762 0.446343    
## TeamHOU                     -0.77230    4.05526  -0.190 0.849000    
## TeamKC                      -4.90984    4.01648  -1.222 0.221837    
## TeamLA                       3.13554    4.07514   0.769 0.441820    
## TeamMIN                      2.09951    4.08631   0.514 0.607512    
## TeamMLW                      4.16183    4.01646   1.036 0.300363    
## TeamNYM                     -1.25057    3.95424  -0.316 0.751870    
## TeamNYY                      1.67825    4.11502   0.408 0.683482    
## TeamOAK                     -0.68235    3.95951  -0.172 0.863212    
## TeamPHI                     -6.85071    3.98672  -1.718 0.086039 .  
## TeamPIT                      4.12683    4.01348   1.028 0.304086    
## TeamSD                       2.59525    4.08310   0.636 0.525179    
## TeamSEA                     -0.67316    4.04471  -0.166 0.867853    
## TeamSF                       1.06038    4.04481   0.262 0.793255    
## TeamSTL                     -1.38669    4.11234  -0.337 0.736037    
## TeamTB                      -2.44396    4.08716  -0.598 0.550003    
## TeamTEX                     -0.68740    4.02023  -0.171 0.864270    
## TeamTOR                      1.24439    4.06029   0.306 0.759306    
## TeamWAS                     -1.87599    3.99594  -0.469 0.638835    
## PositionDesignated_Hitter    8.94440    4.44417   2.013 0.044425 *  
## PositionFirst_Baseman        2.55100    3.00014   0.850 0.395368    
## PositionOutfielder          -6.25702    2.27372  -2.752 0.006033 ** 
## PositionRelief_Pitcher      -7.68904    2.19166  -3.508 0.000471 ***
## PositionSecond_Baseman     -13.01400    2.95787  -4.400 1.20e-05 ***
## PositionShortstop          -16.82243    3.03494  -5.543 3.81e-08 ***
## PositionStarting_Pitcher    -7.08215    2.29615  -3.084 0.002096 ** 
## PositionThird_Baseman       -4.66452    3.16249  -1.475 0.140542    
## Height                       4.71888    0.25578  18.449  < 2e-16 ***
## Age                          3.82295    1.30621   2.927 0.003503 ** 
## age2                        -0.04791    0.02124  -2.255 0.024327 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 16.74 on 993 degrees of freedom
## Multiple R-squared:  0.3889, Adjusted R-squared:  0.3643 
## F-statistic:  15.8 on 40 and 993 DF,  p-value: < 2.2e-16

This brought the overall \(R^2\) up to \(0.3889\), and the quadratic term age2 is significant at the 0.05 level (\(p=0.024\)).
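
A quadratic term can also be specified directly in the model formula with I(), without adding a column to the data frame. The following sketch uses simulated stand-in data (hypothetical values; not the MLB dataset) to compare a linear and a quadratic specification.

```r
set.seed(2024)
# Simulated stand-in data (hypothetical values; not the MLB dataset)
age    <- runif(300, 20, 45)
weight <- 100 + 6 * age - 0.08 * age^2 + rnorm(300, sd = 10)
sim    <- data.frame(age, weight)

lin_fit  <- lm(weight ~ age, data = sim)
quad_fit <- lm(weight ~ age + I(age^2), data = sim)   # I() protects the arithmetic

# R-squared can never decrease when a term is added to a nested OLS model
c(linear = summary(lin_fit)$r.squared, quadratic = summary(quad_fit)$r.squared)

# A partial F-test accounts for the extra parameter
anova(lin_fit, quad_fit)
```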

3.9 Transformation - converting a numeric variable to a binary indicator

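As discussed earlier, middle-aged people might have a different pattern of weight increase compared to younger people. The overall pattern might not be a single line, but rather two separate lines for young and middle-aged people. We assume 30 is the threshold, i.e., people aged 30 and over may follow a different line than those under 30. Here we use the ifelse() function that we mentioned in Chapter 7 to create the indicator of the threshold. Note that adding the indicator as a main effect only shifts the level of the fitted line at the threshold; allowing a different slope would additionally require an interaction between Age and the indicator.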

mlb$age30<-ifelse(mlb$Age>=30, 1, 0)
fit3<-lm(Weight~Team+Position+Age+age30+Height, data=mlb)
summary(fit3)
## 
## Call:
## lm(formula = Weight ~ Team + Position + Age + age30 + Height, 
##     data = mlb)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -48.313 -11.166  -0.916  10.044  73.630 
## 
## Coefficients:
##                            Estimate Std. Error t value Pr(>|t|)    
## (Intercept)               -159.8884    19.8862  -8.040 2.54e-15 ***
## TeamARZ                      7.4096     4.2627   1.738 0.082483 .  
## TeamATL                     -1.4379     3.9765  -0.362 0.717727    
## TeamBAL                     -5.3393     4.0187  -1.329 0.184284    
## TeamBOS                     -0.1985     4.0034  -0.050 0.960470    
## TeamCHC                      0.4669     3.9947   0.117 0.906976    
## TeamCIN                      2.2124     3.9939   0.554 0.579741    
## TeamCLE                     -1.1624     4.0371  -0.288 0.773464    
## TeamCOL                     -3.6842     4.0290  -0.914 0.360717    
## TeamCWS                      4.1920     4.1025   1.022 0.307113    
## TeamDET                      2.4708     3.9746   0.622 0.534314    
## TeamFLA                      2.8563     4.1352   0.691 0.489903    
## TeamHOU                     -0.4964     4.0659  -0.122 0.902846    
## TeamKC                      -4.7138     4.0238  -1.171 0.241692    
## TeamLA                       2.9194     4.0814   0.715 0.474586    
## TeamMIN                      2.2885     4.0965   0.559 0.576528    
## TeamMLW                      4.4749     4.0269   1.111 0.266731    
## TeamNYM                     -1.8173     3.9510  -0.460 0.645659    
## TeamNYY                      1.7074     4.1229   0.414 0.678867    
## TeamOAK                     -0.3388     3.9707  -0.085 0.932012    
## TeamPHI                     -6.6192     3.9993  -1.655 0.098220 .  
## TeamPIT                      4.6716     4.0332   1.158 0.247029    
## TeamSD                       2.8600     4.0965   0.698 0.485243    
## TeamSEA                     -1.0121     4.0518  -0.250 0.802809    
## TeamSF                       1.0244     4.0545   0.253 0.800587    
## TeamSTL                     -1.1094     4.1187  -0.269 0.787703    
## TeamTB                      -2.4485     4.0980  -0.597 0.550312    
## TeamTEX                     -0.6112     4.0300  -0.152 0.879485    
## TeamTOR                      1.3959     4.0674   0.343 0.731532    
## TeamWAS                     -1.4189     4.0139  -0.354 0.723784    
## PositionDesignated_Hitter    9.2378     4.4621   2.070 0.038683 *  
## PositionFirst_Baseman        2.6074     3.0096   0.866 0.386501    
## PositionOutfielder          -6.0408     2.2863  -2.642 0.008367 ** 
## PositionRelief_Pitcher      -7.5100     2.2072  -3.403 0.000694 ***
## PositionSecond_Baseman     -12.8870     2.9683  -4.342 1.56e-05 ***
## PositionShortstop          -16.8912     3.0406  -5.555 3.56e-08 ***
## PositionStarting_Pitcher    -7.0825     2.3099  -3.066 0.002227 ** 
## PositionThird_Baseman       -4.4307     3.1719  -1.397 0.162773    
## Age                          0.6904     0.2153   3.207 0.001386 ** 
## age30                        2.2636     1.9749   1.146 0.251992    
## Height                       4.7113     0.2563  18.380  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 16.77 on 993 degrees of freedom
## Multiple R-squared:  0.3866, Adjusted R-squared:  0.3619 
## F-statistic: 15.65 on 40 and 993 DF,  p-value: < 2.2e-16

This model performs worse than the quadratic model in terms of \(R^2\). Moreover, age30 is not significant. So, we will stick with the quadratic model.
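
The indicator enters fit3 only as a main effect, which allows a jump in the fitted line at age 30 but keeps the same slope on both sides. As a sketch of the difference, the simulated example below (hypothetical data; not the MLB dataset) contrasts the level-shift model with one that adds an Age-by-indicator interaction to allow a different slope.

```r
set.seed(30)
# Simulated stand-in data (hypothetical; not the MLB dataset)
Age    <- runif(400, 20, 45)
age30  <- ifelse(Age >= 30, 1, 0)
# Simulated truth: slope 0.5 below 30 and 1.5 from 30 on, continuous at 30
Weight <- 180 + 0.5 * Age + 1.0 * pmax(Age - 30, 0) + rnorm(400, sd = 5)
sim    <- data.frame(Age, age30, Weight)

shift_fit <- lm(Weight ~ Age + age30, data = sim)   # parallel lines, jump at 30
slope_fit <- lm(Weight ~ Age * age30, data = sim)   # separate slope for Age >= 30

coef(shift_fit)
coef(slope_fit)   # "Age:age30" estimates the change in slope from age 30 on
```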

3.10 Model specification - adding interaction effects

So far, only each feature’s individual effect has been considered in our model. It is possible that features act in pairs to affect our dependent variable. Let’s examine that more closely.

An interaction is a combined effect of two features. If we are not sure whether two features interact, we can test this by adding an interaction term to the model. If the interaction term is significant, we have evidence of an interaction between the two features.

fit4<-lm(Weight~Team+Height+Age*Position+age2, data=mlb)
summary(fit4)
## 
## Call:
## lm(formula = Weight ~ Team + Height + Age * Position + age2, 
##     data = mlb)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -48.761 -11.049  -0.761   9.911  75.533 
## 
## Coefficients:
##                                 Estimate Std. Error t value Pr(>|t|)    
## (Intercept)                   -199.15403   29.87269  -6.667 4.35e-11 ***
## TeamARZ                          8.10376    4.26339   1.901   0.0576 .  
## TeamATL                         -0.81743    3.97899  -0.205   0.8373    
## TeamBAL                         -4.64820    4.03972  -1.151   0.2502    
## TeamBOS                          0.37698    4.00743   0.094   0.9251    
## TeamCHC                          0.33104    3.99507   0.083   0.9340    
## TeamCIN                          2.56023    3.99603   0.641   0.5219    
## TeamCLE                         -0.66254    4.03154  -0.164   0.8695    
## TeamCOL                         -3.72098    4.03759  -0.922   0.3570    
## TeamCWS                          4.63266    4.10884   1.127   0.2598    
## TeamDET                          3.21380    3.98231   0.807   0.4199    
## TeamFLA                          3.56432    4.14902   0.859   0.3905    
## TeamHOU                         -0.38733    4.07249  -0.095   0.9242    
## TeamKC                          -4.66678    4.02384  -1.160   0.2464    
## TeamLA                           3.51766    4.09400   0.859   0.3904    
## TeamMIN                          2.31585    4.10502   0.564   0.5728    
## TeamMLW                          4.34793    4.02501   1.080   0.2803    
## TeamNYM                         -0.28505    3.98537  -0.072   0.9430    
## TeamNYY                          1.87847    4.12774   0.455   0.6491    
## TeamOAK                         -0.23791    3.97729  -0.060   0.9523    
## TeamPHI                         -6.25671    3.99545  -1.566   0.1177    
## TeamPIT                          4.18719    4.01944   1.042   0.2978    
## TeamSD                           2.97028    4.08838   0.727   0.4677    
## TeamSEA                         -0.07220    4.05922  -0.018   0.9858    
## TeamSF                           1.35981    4.07771   0.333   0.7388    
## TeamSTL                         -1.23460    4.11960  -0.300   0.7645    
## TeamTB                          -1.90885    4.09592  -0.466   0.6413    
## TeamTEX                         -0.31570    4.03146  -0.078   0.9376    
## TeamTOR                          1.73976    4.08565   0.426   0.6703    
## TeamWAS                         -1.43933    4.00274  -0.360   0.7192    
## Height                           4.70632    0.25646  18.351  < 2e-16 ***
## Age                              3.32733    1.37088   2.427   0.0154 *  
## PositionDesignated_Hitter      -44.82216   30.68202  -1.461   0.1444    
## PositionFirst_Baseman           23.51389   20.23553   1.162   0.2455    
## PositionOutfielder             -13.33140   15.92500  -0.837   0.4027    
## PositionRelief_Pitcher         -16.51308   15.01240  -1.100   0.2716    
## PositionSecond_Baseman         -26.56932   20.18773  -1.316   0.1884    
## PositionShortstop              -27.89454   20.72123  -1.346   0.1786    
## PositionStarting_Pitcher        -2.44578   15.36376  -0.159   0.8736    
## PositionThird_Baseman          -10.20102   23.26121  -0.439   0.6611    
## age2                            -0.04201    0.02170  -1.936   0.0531 .  
## Age:PositionDesignated_Hitter    1.77289    1.00506   1.764   0.0780 .  
## Age:PositionFirst_Baseman       -0.71111    0.67848  -1.048   0.2949    
## Age:PositionOutfielder           0.24147    0.53650   0.450   0.6527    
## Age:PositionRelief_Pitcher       0.30374    0.50564   0.601   0.5482    
## Age:PositionSecond_Baseman       0.46281    0.68281   0.678   0.4981    
## Age:PositionShortstop            0.38257    0.70998   0.539   0.5901    
## Age:PositionStarting_Pitcher    -0.17104    0.51976  -0.329   0.7422    
## Age:PositionThird_Baseman        0.18968    0.79561   0.238   0.8116    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 16.73 on 985 degrees of freedom
## Multiple R-squared:  0.3945, Adjusted R-squared:  0.365 
## F-statistic: 13.37 on 48 and 985 DF,  p-value: < 2.2e-16

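Here we can see that the overall \(R^2\) improved slightly (from \(0.3889\) to \(0.3945\)), while the adjusted \(R^2\) barely changed (\(0.3643\) vs. \(0.365\)). Only one of the interaction terms, Age:PositionDesignated_Hitter, is significant at the 0.1 level.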
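
Because Position has nine levels, the interaction adds eight coefficients, and a joint (partial F) test is more informative than eight separate t-tests. With the fitted objects in the session, anova(fit2, fit4) would carry out this comparison directly. The sketch below approximates the same test from the rounded \(R^2\) values and residual degrees of freedom printed in the two summaries.

```r
# R-squared and residual df read off the two summaries above
r2_quad <- 0.3889; df_quad <- 993   # quadratic model (no interaction)
r2_int  <- 0.3945; df_int  <- 985   # model with the Age:Position interaction

q <- df_quad - df_int               # 8 extra interaction coefficients

# Partial F-statistic for the joint contribution of the interaction terms
f_stat  <- ((r2_int - r2_quad) / q) / ((1 - r2_int) / df_int)
p_value <- pf(f_stat, df1 = q, df2 = df_int, lower.tail = FALSE)

c(F = f_stat, p = p_value)
```

An F value close to 1 suggests that, taken together, the interaction terms add little beyond what would be expected by chance.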

4 Understanding regression trees and model trees

4.1 Motivation

As we saw in Chapter 8, a decision tree is built from multiple if-else logical decisions and can classify observations. We can add regression to decision trees so that a tree can also make numerical predictions.

4.2 Adding regression to trees

Numeric prediction trees are built in the same way as classification trees. Recall the discussion in Chapter 8 where the data are partitioned first via a divide-and-conquer strategy based on features. The homogeneity of the resulting classification trees is measured by various metrics, e.g., entropy. In regression-tree prediction, node homogeneity is measured by various statistics such as variance, standard deviation or absolute deviation from the mean.

A common splitting criterion for decision trees is the standard deviation reduction (SDR).

\[SDR=sd(T)-\sum_{i=1}^n \frac{|T_i|}{|T|} \times sd(T_i),\]

where \(sd(T)\) is the standard deviation of the original data, the sum runs over the \(n\) segments produced by the split, \(\frac{|T_i|}{|T|}\) is the proportion of observations in the \(i^{th}\) segment relative to the total number of observations, and \(sd(T_i)\) is the standard deviation of the \(i^{th}\) segment.

An example for this would be: \[Original\, data:\{1, 2, 3, 3, 4, 5, 6, 6, 7, 8\}\] \[Split\, method\, 1:\{1, 2, 3|3, 4, 5, 6, 6, 7, 8\}\] \[Split\, method\, 2:\{1, 2, 3, 3, 4, 5|6, 6, 7, 8\}\] In split method 1, \(T_1=\{1, 2, 3\}\), \(T_2=\{3, 4, 5, 6, 6, 7, 8\}\). In split method 2, \(T_1=\{1, 2, 3, 3, 4, 5\}\), \(T_2=\{6, 6, 7, 8\}\).

ori<-c(1, 2, 3, 3, 4, 5, 6, 6, 7, 8)
at1<-c(1, 2, 3)
at2<-c(3, 4, 5, 6, 6, 7, 8)
bt1<-c(1, 2, 3, 3, 4, 5)
bt2<-c(6, 6, 7, 8)
sdr_a<-sd(ori)-(length(at1)/length(ori)*sd(at1)+length(at2)/length(ori)*sd(at2))
sdr_b<-sd(ori)-(length(bt1)/length(ori)*sd(bt1)+length(bt2)/length(ori)*sd(bt2))
sdr_a
## [1] 0.7702557
sdr_b
## [1] 1.041531

The function length() is used above to get the number of elements in a vector.

Larger SDR indicates greater reduction in standard deviation after splitting. Here we have split method 2 yielding greater SDR, so the tree splitting decision would prefer the second method, which is expected to produce more homogeneous sub-sets (children nodes), compared to method 1.

Next, the tree would be split further under bt1 and bt2 following the same rule (greater SDR wins). Assume we cannot split further (bt1 and bt2 are terminal nodes). The observations classified into bt1 will be predicted with \(mean(bt1)=3\) and those classified as bt2 with \(mean(bt2)=6.75\).
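
The SDR calculation generalizes to a small helper that can score any candidate split. The sketch below (the names sdr and sdr_by_position are illustrative) reproduces the two values above, scans all split positions of the sorted toy data that leave at least two values on each side, and reports the terminal-node means used for prediction.

```r
# SDR of splitting `parent` into the list of vectors `children`
sdr <- function(parent, children) {
  weighted_sd <- sapply(children, function(ch) length(ch) / length(parent) * sd(ch))
  sd(parent) - sum(weighted_sd)
}

ori <- c(1, 2, 3, 3, 4, 5, 6, 6, 7, 8)

# The two splits considered in the text
sdr(ori, list(ori[1:3], ori[4:10]))   # split method 1: 0.7702557
sdr(ori, list(ori[1:6], ori[7:10]))   # split method 2: 1.041531

# Scan every split position that leaves at least two values on each side
positions <- 2:(length(ori) - 2)
sdr_by_position <- sapply(positions, function(k) {
  sdr(ori, list(ori[1:k], ori[(k + 1):length(ori)]))
})
names(sdr_by_position) <- positions
round(sdr_by_position, 3)

# Terminal-node predictions for split method 2 are the node means
c(mean(ori[1:6]), mean(ori[7:10]))    # 3.00 and 6.75
```

A tree-growing algorithm would evaluate such a scan for every feature and choose the split with the largest SDR.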

5 Bayesian Additive Regression Trees (BART)

Bayesian Additive Regression Trees (BART) are sum-of-regression-trees models that rely on boosting the constituent Bayesian regularized trees.

The R packages BayesTree and BART provide computational implementations for fitting BART models to data. In a supervised setting where \(x\) and \(y\) represent the predictors and the outcome, the BART model is mathematically represented as: \[y=f(x)+\epsilon = \sum_{j=1}^m{f_j(x)} +\epsilon.\] More specifically, \[y_i=f(x_i)+\epsilon_i = \sum_{j=1}^m{f_j(x_i)} +\epsilon_i,\ \forall 1\leq i\leq n.\] The errors are typically assumed to be white noise, \(\epsilon_i \sim N(0,\sigma^2), \ iid\). The function \(f\) represents the boosted ensemble of weaker regression trees, \(f_j=g( . | T_j, M_j)\), where \(T_j\) and \(M_j\) represent the \(j^{th}\) tree and the set of values, \(\mu_{k,j}\), assigned to each terminal node \(k\) in \(T_j\), respectively.
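
To make the sum-of-trees notation concrete, the sketch below evaluates \(f(x)=\sum_j g(x|T_j,M_j)\) for three hand-specified single-split trees. All split points and leaf values are hypothetical, and the code only evaluates a given ensemble; it does not fit a BART model.

```r
# A "stump" g(x | T_j, M_j): T_j is a single split point, M_j the two leaf values
g <- function(x, tree) ifelse(x < tree$split, tree$mu[1], tree$mu[2])

# m = 3 hypothetical weak learners (illustrative values only)
trees <- list(
  list(split = 0.3, mu = c(-1.0, 0.5)),
  list(split = 0.5, mu = c(-0.5, 0.5)),
  list(split = 0.8, mu = c( 0.0, 1.0))
)

# f(x) = sum over j of g(x | T_j, M_j)
f <- function(x) Reduce(`+`, lapply(trees, function(tr) g(x, tr)))

set.seed(1)
x <- runif(100)
y <- f(x) + rnorm(100, mean = 0, sd = 0.1)   # y_i = f(x_i) + epsilon_i

f(c(0.1, 0.4, 0.6, 0.9))   # -1.5  0.0  1.0  2.0
```

Each stump contributes only a small piecewise-constant adjustment, and the ensemble sum produces a finer step function than any single tree.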

The BART model may be estimated via Gibbs sampling, e.g., the Bayesian backfitting Markov chain Monte Carlo (MCMC) algorithm. This algorithm iteratively samples \((T_j ,M_j)\) and \(\sigma\), conditional on all other variables and the data \((x,y)\), for each \(j\) until a convergence criterion is met. Given \(\sigma\), conditional sampling of \((T_{j_o} ,M_{j_o})\) may be accomplished via the partial residuals, i.e., the part of the outcome that is not explained by the remaining \(m-1\) trees, \[\epsilon_{j_o, i} = y_i - \sum_{j\not= j_o}{g( x_i | T_j, M_j)},\ \forall 1\leq i\leq n.\]
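The partial-residual idea can be sketched in a few lines, keeping one residual per observation. The matrix fits is a simulated stand-in for the current tree fits and the helper partial_residual() is hypothetical; neither comes from the BART package.

```r
set.seed(1234)
n <- 8; m <- 4
y <- rnorm(n)
# Hypothetical current fits: column j holds g(x_i | T_j, M_j) for i = 1..n
fits <- matrix(rnorm(n * m, sd = 0.3), nrow = n, ncol = m)

# Residual left for tree j0 after subtracting the fits of all other trees
partial_residual <- function(y, fits, j0) {
  y - rowSums(fits[, -j0, drop = FALSE])
}

r2 <- partial_residual(y, fits, j0 = 2)
# Equivalent form: full residual plus the fit of tree j0 itself
all.equal(r2, (y - rowSums(fits)) + fits[, 2])
```

In a backfitting sweep, tree \(j_o\) would be redrawn against these residuals before moving on to the next tree.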

To support prediction on new data, data-driven priors on \(\sigma\) and the parameters defining \((T_j ,M_j)\) may be used to allow sampling a model from the posterior distribution \(p\left (\sigma, \{(T_1 ,M_1), (T_2 ,M_2), ..., (T_m ,M_m)\} | X, Y \right )\), which is proportional to the likelihood times the prior. The prior factorizes as: \[p\left (\sigma, \{(T_1 ,M_1), (T_2 ,M_2), ..., (T_m ,M_m)\} \right )=\] \[= p(\sigma) \prod_{(tree)j=1}^m{\left ( \left ( \prod_{(node)k} {p(\mu_{k,j}|T_j)} \right )p(T_j) \right )} .\] In this factorization, model regularization is achieved by four criteria:

  • Enforcing reasonable marginal probabilities, \(p(T_j)\), to ensure that the probability that a node at depth \(d\) in the tree \(T_j\) has children decreases as \(d\) increases. That is, the probability that a current bottom node, at depth \(d\), is split into left and right child nodes is \(\frac{\alpha}{(1+d)^{\beta}}\), where the base (\(\alpha\)) and power (\(\beta\)) parameters are selected to optimize the fit (including regularization);
  • For each interior (non-terminal) node, the distribution on the splitting variable assignment is uniform over the set of available predictor variables;
  • For each interior (non-terminal) node, the distribution on the splitting rule assignment, conditional on the splitting variable, is uniform over the discrete set of splitting values;
  • All other priors are chosen as \(p(\mu_{k,j} |T_j) = N(\mu_{k,j} |\mu_{\mu}, \sigma_{\mu}^2)\) and \(p(\sigma)\), where \(\sigma^2\) follows a (scaled) inverse chi-square distribution. To facilitate the calculations, this introduces a conjugate prior structure with the corresponding hyper-parameters (e.g., \(\mu_{\mu}\) and \(\sigma_{\mu}\), which are distinct from the error standard deviation \(\sigma\)) estimated using the observed data.
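As a quick numerical illustration of the first criterion, the sketch below evaluates the split probability at several depths (the depth is written as d here). The helper p_split() is hypothetical; the values \(\alpha=0.95\) and \(\beta=2\) are the ones reported in the wbart() prior output later in this section.

```r
# Prior probability that a node at depth d is split (i.e., is non-terminal)
p_split <- function(d, alpha = 0.95, beta = 2) alpha / (1 + d)^beta

depth <- 0:4
data.frame(depth = depth, p_split = p_split(depth))
```

The root (depth 0) is split with probability 0.95, but a node at depth 1 is split with probability of only 0.2375, so deep trees are strongly discouraged and each tree remains a weak learner.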

Using the BART model to forecast a response corresponding to newly observed data \(x\) is achieved by using one individual (or ensembling/averaging multiple) prediction models near algorithmic convergence.

The BART algorithm involves three steps:

  • Initialize a prior on the model parameters \((f,\sigma)\), where \(f=\{f_j=g( . | T_j, M_j)\}_j\),
  • Run a Markov chain with state \((f,\sigma)\) where the stationary distribution is the posterior \(p \left ((f,\sigma)|Data=\{(x_i,y_i)\}_{i=1}^n\right )\),
  • Examine the draws as a representation of the full posterior. Even though \(f\) is complex and changes its dimensional structure, for a given \(x\), we can explore the marginals of \(\sigma\) and \(f(x)\) by selecting a set of data \(\{x_j\}_{j}\) and computing \(f(x_j)\). If \(f_l\) represents the \(l^{th}\) MCMC draw, then the homologous Bayesian tree structures at every draw will yield results of the same dimensions \(\left ( f_l(x_1), f_l(x_2), ...\right )\).
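The third step amounts to summarizing a draws-by-points matrix column by column. The sketch below uses simulated values as a stand-in for real MCMC output (the matrix draws, the grid x_grid, and the noise level are assumptions for illustration), and it shows how pointwise posterior means and 95% bands may be computed.

```r
set.seed(1234)
ndraw <- 1000
x_grid <- c(-1, 0, 1)
# Stand-in for MCMC output: row l holds (f_l(x_1), f_l(x_2), f_l(x_3));
# here the "draws" are simply simulated around x^3*sin(x) for illustration
draws <- sapply(x_grid, function(x) rnorm(ndraw, mean = x^3 * sin(x), sd = 0.1))
dim(draws)                      # ndraw rows, one column per x value

post_mean <- colMeans(draws)    # pointwise posterior mean of f(x_j)
post_band <- apply(draws, 2, quantile, probs = c(0.025, 0.975))  # 2 x length(x_grid)
rbind(x = x_grid, mean = post_mean, post_band)
```

The same column-wise summaries apply to any matrix of posterior draws with one row per draw and one column per evaluation point.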

5.1 1D Simulation

This example illustrates a simple 1D BART simulation, based on \(h(x)=x^3\sin(x)\), where we can easily plot the performance of the BART regression model.

# simulate training data
sig = 0.2  # sigma
func = function(x) {
  return (sin(x) * (x^(3)))
}

set.seed(1234)
n = 300
x = sort(2*runif(n)-1)        # define the input
y = func(x) + sig * rnorm(n)  # define the output

# xtest: values we want to estimate func(x) at; this is also our prior prediction for y.
xtest = seq(-pi, pi, by=0.2)

# plot simulated data
# plot(x, y, cex=1.0, col="gray")
# points(xtest, rep(0, length(xtest)), col="blue", pch=1, cex=1.5)
# legend("top", legend=c("Data", "(Conjugate) Flat Prior"),
#    col=c("gray", "blue"), lwd=c(2,2), lty=c(3,3), bty="n", cex=0.9, seg.len=3)

plot_ly(x=x, y=y, type="scatter", mode="markers", name="data") %>%
  add_trace(x=xtest, y=rep(0, length(xtest)), mode="markers", name="prior") %>%
  layout(title='(Conjugate) Flat Prior',
           xaxis = list(title="X", range = c(-1.3, 1.3)),
           yaxis = list(title="Y", range = c(-0.6, 1.6)),
           legend = list(orientation = 'h'))
# run the weighted BART (BART::wbart) on the simulated data  
# install.packages("BART")
library(BART)
set.seed(1234) # set seed for reproducibility of MCMC
# nskip=Number of burn-in MCMC iterations; ndpost=number of posterior draws to return
model_bart <- wbart(x.train=as.data.frame(x), y.train=y, x.test=as.data.frame(xtest), 
                    nskip=300, ndpost=1000, printevery=1000) 
## *****Into main of wbart
## *****Data:
## data:n,p,np: 300, 1, 32
## y1,yn: 0.572918, 1.073126
## x1,x[n*p]: -0.993435, 0.997482
## xp1,xp[np*p]: -3.141593, 3.058407
## *****Number of Trees: 200
## *****Number of Cut Points: 100 ... 100
## *****burn and ndpost: 300, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,0.032889,3.000000,0.018662
## *****sigma: 0.309524
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,1,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1300)
## done 1000 (out of 1300)
## time: 3s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
# result is a list containing the BART run 

# explore the BART model fit
names(model_bart)
##  [1] "sigma"           "yhat.train.mean" "yhat.train"      "yhat.test.mean" 
##  [5] "yhat.test"       "varcount"        "varprob"         "treedraws"      
##  [9] "proc.time"       "mu"              "varcount.mean"   "varprob.mean"   
## [13] "rm.const"
dim(model_bart$yhat.test)
## [1] 1000   32
# yhat.test is a matrix with ndpost (1,000) rows and nrow(x.test) (32) columns.
# Each row is one posterior draw f* of f, and each column corresponds to one row of x.test:
# the (l,j) element is the l^{th} kept draw of f evaluated at the j^{th} value of x.test.
# Similarly, yhat.train has ndpost rows and nrow(x.train) columns.

# # plot the data, the BART Fit, and the uncertainty
# plot(x, y, cex=1.2, cex.axis=0.8, cex.lab=0.8, mgp=c(1.3,.3,0), tcl=-0.2, col="gray")
# lines(xtest, func(xtest), col="blue", lty=1, lwd=2)
# lines(xtest, apply(model_bart$yhat.test, 2, mean), col="green", lwd=2, lty=2) # show the mean of f(x_j)
# quant_marg = apply(model_bart$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # plot the 2.5% and 97.5% quantiles
# lines(xtest, quant_marg[1,], col="red", lty=1, lwd=2)
# lines(xtest, quant_marg[2,], col="red", lty=1,lwd=2)
# legend("top", legend=c("Data", "True Signal","Posterior Mean","95% CI"),
#    col=c("black", "blue","green","red"), lwd=c(2, 2,2,2), lty=c(3, 1,1,1), bty="n", cex=0.9, seg.len=3)

quant_marg = apply(model_bart$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # plot the 2.5% and 97.5% quantiles
plot_ly(x=x, y=y, type="scatter", mode="markers", name="Data") %>%
  add_trace(x=xtest, y=func(xtest), mode="markers", name="True Signal") %>%
  add_trace(x=xtest, y=apply(model_bart$yhat.test, 2, mean), mode="lines", name="Posterior Mean") %>%
  add_trace(x=xtest, y=apply(model_bart$yhat.test, 2, quantile, probs=c(0.025, 0.975)),
            mode="markers", name="95% CI") %>%
  add_trace(x=xtest, y=quant_marg[1,], mode="lines", name="Lower Band") %>%
  add_trace(x=xtest, y=quant_marg[2,], mode="lines", name="Upper Band") %>%
  layout(title='BART Model  (n=300)',
           xaxis = list(title="X", range = c(-1.3, 1.3)),
           yaxis = list(title="Y", range = c(-0.6, 1.6)),
           legend = list(orientation = 'h'))
dim(model_bart$yhat.train)
## [1] 1000  300
summary(model_bart$yhat.train.mean-apply(model_bart$yhat.train, 2, mean))
##       Min.    1st Qu.     Median       Mean    3rd Qu.       Max. 
## -4.441e-16 -8.327e-17  6.939e-18 -8.130e-18  9.714e-17  5.551e-16
summary(model_bart$yhat.test.mean-apply(model_bart$yhat.test, 2, mean))
##       Min.    1st Qu.     Median       Mean    3rd Qu.       Max. 
## -2.776e-16  1.535e-16  3.331e-16  2.661e-16  4.441e-16  4.441e-16
# yhat.train(test).mean: Average the draws to get the estimate of the posterior mean of func(x)

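If we increase the sample size, n, the computation time increases (about 15s for n=3,000 compared to about 3s for n=300 in the runs reported here) and the BART posterior bands should get tighter over the range of the training inputs.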

n = 3000 # 300 --> 3,000
set.seed(1234)
x = sort(2*runif(n)-1)
y = func(x) + sig*rnorm(n)

model_bart_2 <- wbart(x.train=as.data.frame(x), y.train=y, x.test=as.data.frame(xtest), 
                      nskip=300, ndpost=1000, printevery=1000)
## *****Into main of wbart
## *****Data:
## data:n,p,np: 3000, 1, 32
## y1,yn: 1.021931, 0.626872
## x1,x[n*p]: -0.999316, 0.998606
## xp1,xp[np*p]: -3.141593, 3.058407
## *****Number of Trees: 200
## *****Number of Cut Points: 100 ... 100
## *****burn and ndpost: 300, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,0.035536,3.000000,0.017828
## *****sigma: 0.302529
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,1,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1300)
## done 1000 (out of 1300)
## time: 15s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
# plot(x, y, cex=1.2, cex.axis=0.8, cex.lab=0.8, mgp=c(1.3,.3,0), tcl=-0.2, col="gray")
# lines(xtest, func(xtest), col="blue", lty=1, lwd=2)
# lines(xtest, apply(model_bart_2$yhat.test, 2, mean), col="green", lwd=2, lty=2) # show the mean of f(x_j)
# quant_marg = apply(model_bart_2$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # plot the 2.5% and 97.5% CI
# lines(xtest, quant_marg[1,], col="red", lty=1, lwd=2)
# lines(xtest, quant_marg[2,], col="red", lty=1,lwd=2)
# legend("top",legend=c("Data (n=3,000)", "True Signal", "Posterior Mean","95% CI"),
#    col=c("gray", "blue","green","red"), lwd=c(2, 2,2,2), lty=c(3, 1,1,1), bty="n", cex=0.9, seg.len=3)

quant_marg2 = apply(model_bart_2$yhat.test, 2, quantile, probs=c(0.025, 0.975)) # plot the 2.5% and 97.5% CI
plot_ly(x=x, y=y, type="scatter", mode="markers", name="Data") %>%
  add_trace(x=xtest, y=func(xtest), mode="markers", name="True Signal") %>%
  add_trace(x=xtest, y=apply(model_bart_2$yhat.test, 2, mean), mode="lines", name="Posterior Mean") %>%
  add_trace(x=xtest, y=apply(model_bart_2$yhat.test, 2, quantile, probs=c(0.025, 0.975)),
            mode="markers", name="95% CI") %>%
  add_trace(x=xtest, y=quant_marg2[1,], mode="lines", name="Lower Band") %>%
  add_trace(x=xtest, y=quant_marg2[2,], mode="lines", name="Upper Band") %>%
  layout(title='BART Model (n=3,000)',
           xaxis = list(title="X", range = c(-1.3, 1.3)),
           yaxis = list(title="Y", range = c(-0.6, 1.6)),
           legend = list(orientation = 'h'))

5.2 Higher-Dimensional Simulation

In this second example, we will simulate \(n=5,000\) cases and \(p=20\) features.
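Reading off the simulation code that follows, the data-generating model is linear with independent standard normal features and unit-variance noise: \[y_i = 10 + \sum_{k=1}^{p}{\beta_k x_{i,k}} + \epsilon_i,\quad \beta_k=\frac{3k}{p},\quad x_{i,k}\sim N(0,1),\quad \epsilon_i\sim N(0,\sigma^2),\ \sigma=1,\ \forall 1\leq i\leq n.\] Because the true signal is linear, an ordinary linear model is correctly specified for these data, which provides a useful reference for judging the BART predictions.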

# simulate data
set.seed(1234)
n=5000; p=20
beta = 3*(1:p)/p
sig=1.0
X = matrix(rnorm(n*p), ncol=p) # design matrix
y = 10 + X %*% matrix(beta, ncol=1) + sig*rnorm(n) # outcome
y=as.double(y)

np=100000
Xp = matrix(rnorm(np*p), ncol=p)
set.seed(1234)
t1 <- system.time(model_bart_MD <- wbart(x.train=as.data.frame(X), y.train=y, x.test=as.data.frame(Xp),
                  nkeeptrain=200, nkeeptest=100, nkeeptestmean=500, nkeeptreedraws=100, printevery=1000)
      )
## *****Into main of wbart
## *****Data:
## data:n,p,np: 5000, 20, 100000
## y1,yn: -6.790191, 5.425931
## x1,x[n*p]: -1.207066, -3.452091
## xp1,xp[np*p]: -0.396587, -0.012973
## *****Number of Trees: 200
## *****Number of Cut Points: 100 ... 100
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,1.049775,3.000000,0.189323
## *****sigma: 0.985864
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,20,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 200,100,500,100
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 5,10,2,10
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 163s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 200,100,500,100
dim(model_bart_MD$yhat.train)
## [1]  200 5000
dim(model_bart_MD$yhat.test)
## [1]    100 100000
names(model_bart_MD$treedraws)
## [1] "cutpoints" "trees"
# str(model_bart_MD$treedraws$trees)

# The trees are stored in a long character string and there are 100 draws each consisting of 200 trees.

# To predict using the Multi-Dimensional BART model (MD)
t2 <- system.time({pred_model_bart_MD2 <- predict(model_bart_MD, as.data.frame(Xp), mc.cores=6)})
## *****In main of C++ for bart prediction
## tc (threadcount): 6
## number of bart draws: 100
## number of trees in bart sum: 200
## number of x columns: 20
## from x,np,p: 20, 100000
## ***using serial code
dim(pred_model_bart_MD2)
## [1]    100 100000
t1
##    user  system elapsed 
##  163.16    0.06  163.61
t2
##    user  system elapsed 
##   44.20    0.08   44.35
# pred_model_bart_MD2 has row dimension equal to the number of kept tree draws (100) and
# column dimension equal to the number of rows in Xp (100,000).  

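# Compare the test-set posterior means returned by wbart (yhat.test.mean) with the means over the 100 kept
# posterior draws (each draw is a sum of 200 trees); the results are very similar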
# plot(model_bart_MD$yhat.test.mean, apply(pred_model_bart_MD2, 2, mean),
#       xlab="BART Prediction using 1,000 Trees", ylab="BART Prediction using 100 Kept Trees")
# abline(0,1, col="red", lwd=2)

plot_ly() %>%
  add_trace(x = c(-30,50), y = c(-30,50), type="scatter", mode="lines",
        line = list(width = 4), name="Consistent BART Prediction (1,000 vs. 100 Trees)") %>%
  add_markers(x=model_bart_MD$yhat.test.mean, y=apply(pred_model_bart_MD2, 2, mean), 
              name="BART Prediction Mean Estimates", type="scatter", mode="markers") %>%
  layout(title='Scatter of BART Predictions (1,000 vs. 100 Trees)',
           xaxis = list(title="BART Prediction (1,000 Trees)"),
           yaxis = list(title="BART Prediction (100 Trees)"),
           legend = list(orientation = 'h'))
# Compare BART Prediction to a linear fit  
lm_func = lm(y ~ ., data.frame(X,y))
pred_lm = predict(lm_func, data.frame(Xp))

# plot(pred_lm, model_bart_MD$yhat.test.mean, xlab="Linear Model Predictions", ylab="BART Predictions",
#         cex=0.5, cex.axis=1.0, cex.lab=0.8, mgp=c(1.3,.3,0), tcl=-0.2)
# abline(0,1, col="red", lwd=2)

plot_ly() %>%
  add_markers(x=pred_lm, y=model_bart_MD$yhat.test.mean, 
              name="LM vs. BART Prediction", type="scatter", mode="markers") %>%
  add_trace(x = c(-30,50), y = c(-30,50), type="scatter", mode="lines",
        line = list(width = 4), name="Consistent LM/BART Prediction") %>%
  layout(title='Scatter of Linear Model vs. BART Predictions',
           xaxis = list(title="Linear Model Prediction"),
           yaxis = list(title="BART Prediction"),
           legend = list(orientation = 'h'))

5.3 Heart Attack Hospitalization Case-Study

Let’s use BART to model the heart attack dataset (CaseStudy12_AdultsHeartAttack_Data.csv). The data includes 150 observations and 8 variables, including hospital charges (CHARGES), which will be used as the response variable.

heart_attack<-read.csv("https://umich.instructure.com/files/1644953/download?download_frd=1", 
                       stringsAsFactors = F)
str(heart_attack)
## 'data.frame':    150 obs. of  8 variables:
##  $ Patient  : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ DIAGNOSIS: int  41041 41041 41091 41081 41091 41091 41091 41091 41041 41041 ...
##  $ SEX      : chr  "F" "F" "F" "F" ...
##  $ DRG      : int  122 122 122 122 122 121 121 121 121 123 ...
##  $ DIED     : int  0 0 0 0 0 0 0 0 0 1 ...
##  $ CHARGES  : chr  "4752" "3941" "3657" "1481" ...
##  $ LOS      : int  10 6 5 2 1 9 15 15 2 1 ...
##  $ AGE      : int  79 34 76 80 55 84 84 70 76 65 ...
# convert CHARGES (the dependent variable) to numerical form. 
# NA's are created, so let's retain only the complete cases 
heart_attack$CHARGES <- as.numeric(heart_attack$CHARGES)
heart_attack <- heart_attack[complete.cases(heart_attack), ]
heart_attack$gender <- ifelse(heart_attack$SEX=="F", 1, 0)
heart_attack <- heart_attack[, -c(1,2,3)]
dim(heart_attack); colnames(heart_attack)
## [1] 148   6
## [1] "DRG"     "DIED"    "CHARGES" "LOS"     "AGE"     "gender"
x.train <- as.matrix(heart_attack[ , -3]) # x training, excluding the charges (output)
y.train = heart_attack$CHARGES # y=output for modeling (BART, lm, lasso, etc.)

# Data should be standardized for all model-based predictions (e.g., lm, lasso/glmnet), but
# this is not critical for BART

# We'll just do some random train/test splits and report the out of sample performance of BART and lasso
RMSE <- function(y, yhat) {
  return(sqrt(mean((y-yhat)^2)))
}

nd <- 10 # number of random train/test splits (similar to cross-validation)
n <- length(y.train)
ntrain <- floor(0.8*n) # 80:20 train:test split each time

RMSE_BART <- rep(0, nd) # initialize BART and LASSO RMSE vectors 
RMSE_LASSO <- rep(0, nd) 

pred_BART <- matrix(0.0, n-ntrain,nd) # Initialize the BART and LASSO out-of-sample predictions
pred_LASSO <- matrix(0.0, n-ntrain,nd) 

In Chapter 17, we will learn more about LASSO regularized linear modeling. Now, let’s use the glmnet::glmnet() method to fit a LASSO model and compare it to BART using the Heart Attack hospitalization case-study.

library(glmnet)
for(i in 1:nd) {
   set.seed(1234*i)  
   # train/test split index
   train_ind <- sample(1:n, ntrain)
   
   # Outcome (CHARGES)
   yTrain <- y.train[train_ind]; yTest <- y.train[-train_ind]
   
   # Features for BART
   xBTrain <- x.train[train_ind, ]; xBTest <- x.train[-train_ind, ]
   
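   # Features for LASSO (scale the test features using the training means and SDs)
   xLTrain <- scale(x.train[train_ind, ])
   xLTest <- scale(x.train[-train_ind, ], center=attr(xLTrain, "scaled:center"),
                   scale=attr(xLTrain, "scaled:scale"))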

   # BART: mc.wbart is the parallel version of wbart and takes the same arguments
   # model_BART <- mc.wbart(xBTrain, yTrain, xBTest, mc.cores=6, keeptrainfits=FALSE)
   model_BART <- wbart(xBTrain, yTrain, xBTest, printevery=1000)
   
   # LASSO
   cv_LASSO <- cv.glmnet(xLTrain, yTrain, family="gaussian", standardize=TRUE)
   best_lambda <- cv_LASSO$lambda.min
   model_LASSO <- glmnet(xLTrain, yTrain, family="gaussian", lambda=c(best_lambda), standardize=TRUE)

   #get predictions on testing data
   pred_BART1 <- model_BART$yhat.test.mean
   pred_LASSO1 <- predict(model_LASSO, xLTest, s=best_lambda, type="response")[, 1]
  
    #store results
   RMSE_BART[i] <- RMSE(yTest, pred_BART1); pred_BART[, i] <- pred_BART1
   RMSE_LASSO[i] <- RMSE(yTest, pred_LASSO1);   pred_LASSO[, i] <- pred_LASSO1; 
}
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 2105.135593, -2641.864407
## x1,x[n*p]: 121.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2228477.133405
## *****sigma: 3382.354604
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: -3870.449153, 5404.550847
## x1,x[n*p]: 122.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2261599.844262
## *****sigma: 3407.398506
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 2659.855932, -1749.144068
## x1,x[n*p]: 121.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2182151.820220
## *****sigma: 3347.013986
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 714.000000, -4103.000000
## x1,x[n*p]: 121.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2310883.132530
## *****sigma: 3444.324312
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 3229.305085, 1884.305085
## x1,x[n*p]: 122.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2169347.152131
## *****sigma: 3337.179552
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 2115.474576, 1998.474576
## x1,x[n*p]: 121.000000, 1.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2209447.645085
## *****sigma: 3367.882283
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: -1255.525424, -3624.525424
## x1,x[n*p]: 123.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 0.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2177441.668674
## *****sigma: 3343.399788
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 3795.457627, -388.542373
## x1,x[n*p]: 121.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,288.534922,3.000000,1972180.926129
## *****sigma: 3181.913893
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 3s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: -3943.779661, -524.779661
## x1,x[n*p]: 123.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2184631.918632
## *****sigma: 3348.915451
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000
## *****Into main of wbart
## *****Data:
## data:n,p,np: 118, 5, 30
## y1,yn: 5245.610169, 4149.610169
## x1,x[n*p]: 121.000000, 0.000000
## xp1,xp[np*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 100, 1000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,289.984491,3.000000,2281761.566623
## *****sigma: 3422.552953
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
## *****printevery: 1000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 1100)
## done 1000 (out of 1100)
## time: 2s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 1000,1000,1000,1000

Compare the out-of-sample RMSE values and the predictions of BART and LASSO.

# compare the out of sample RMSE measures
# qqplot(RMSE_BART, RMSE_LASSO)
# abline(0, 1, col="red", lwd=2)

plot_ly() %>%
  add_markers(x=RMSE_BART, y=RMSE_LASSO, 
              name="", type="scatter", mode="markers") %>%
  add_trace(x = c(2800,4000), y = c(2800,4000), type="scatter", mode="lines",
        line = list(width = 4), name="") %>%
  layout(title='Scatter of Linear Model vs. BART RMSE',
           xaxis = list(title="RMSE (BART)"),
           yaxis = list(title="RMSE (Linear Model)"),
           legend = list(orientation = 'h'))
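# Next compare the out of sample predictions
model_lm <- lm(B ~ L, data.frame(B=as.double(pred_BART), L=as.double(pred_LASSO)))
x1 <- c(2800,9000)
# use the fitted coefficients rather than hard-coded values (earlier run: 2373.0734628, 0.5745497)
y1 <- unname(coef(model_lm)[1] + coef(model_lm)[2]*x1)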

# plot(as.double(pred_BART), as.double(pred_LASSO),xlab="BART Predictions",ylab="LASSO Predictions", col="gray")
# abline(0, 1, col="red", lwd=2)
# abline(model_lm$coef, col="blue", lty=2, lwd=3)
# legend("topleft",legend=c("Scatterplot Predictions (BART vs. LASSO)", "Ideal Agreement", "LM (BART~LASSO)"),
#    col=c("gray", "red","blue"), lwd=c(2,2,2), lty=c(3,1,1), bty="n", cex=0.9, seg.len=3)

plot_ly() %>%
  add_markers(x=as.double(pred_BART), y=as.double(pred_LASSO), 
              name="BART vs. LASSO Predictions", type="scatter", mode="markers") %>%
  add_trace(x = c(2800,9000), y = c(2800,9000), type="scatter", mode="lines",
        line = list(width = 4), name="Ideal Agreement") %>%
  add_trace(x = x1, y = y1, type="scatter", mode="lines",
        line = list(width = 4), name="LM (BART ~ LASSO)") %>%
  layout(title='Scatterplot Predictions (BART vs. LASSO)',
           xaxis = list(title="BART Predictions"),
           yaxis = list(title="LASSO Predictions"),
           legend = list(orientation = 'h'))

If the default prior on the error variance (\(\sigma^2\)), which is an inverse chi-squared distribution (i.e., the standard conditionally conjugate prior) calibrated using the rough estimate sigest, yields reasonable results, we can try longer BART runs (ndpost=5000). Note the stable distribution of the sampled \(\sigma\) values (y-axis) with respect to the number of posterior draws (x-axis).

model_BART_long <- wbart(x.train, y.train, nskip=1000, ndpost=5000, printevery=5000) 
## *****Into main of wbart
## *****Data:
## data:n,p,np: 148, 5, 0
## y1,yn: -712.837838, 2893.162162
## x1,x[n*p]: 122.000000, 1.000000
## *****Number of Trees: 200
## *****Number of Cut Points: 2 ... 1
## *****burn and ndpost: 1000, 5000
## *****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,290.550176,3.000000,2178098.906176
## *****sigma: 3343.904335
## *****w (weights): 1.000000 ... 1.000000
## *****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,5,0
## *****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 5000,5000,5000,5000
## *****printevery: 5000
## *****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1
## 
## MCMC
## done 0 (out of 6000)
## done 5000 (out of 6000)
## time: 11s
## check counts
## trcnt,tecnt,temecnt,treedrawscnt: 5000,0,0,5000
# plot(model_BART_long$sigma, xlab="Number of Posterior Draws Returned")

plot_ly() %>%
  add_markers(x=c(1:length(model_BART_long$sigma)), y=model_BART_long$sigma, 
              name="Posterior Draws of Sigma", type="scatter", mode="markers") %>%
  layout(title='Scatterplot BART Sigma (post burn in draws of sigma)',
           xaxis = list(title="Number of Posterior Draws Returned"),
           yaxis = list(title="model_BART_long$sigma"),
           legend = list(orientation = 'h'))
# plot(model_BART_long$yhat.train.mean, y.train, xlab="BART Predicted Charges", ylab="Observed Charges",
#      main=sprintf("Correlation (Observed,Predicted)=%f", 
#                          round(cor(model_BART_long$yhat.train.mean, y.train), digits=2)))
# abline(0, 1, col="red", lty=2)
# legend("topleft",legend=c("BART Predictions", "LM (BART~LASSO)"),
#    col=c("gray", "red","blue"), lwd=c(2,2,2), lty=c(3,1,1), bty="n", cex=0.9, seg.len=3)

plot_ly() %>%
  add_markers(x=model_BART_long$yhat.train.mean, y=y.train, 
              name="BART Predicted vs. Observed", type="scatter", mode="markers") %>%
  add_trace(x = c(2800,9000), y = c(2800,9000), type="scatter", mode="lines",
        line = list(width = 4), name="Ideal Agreement") %>%
  layout(title=sprintf("Observed vs. BART-Predicted Hospital Charges: Cor(BART,Observed)=%.2f", 
                         cor(model_BART_long$yhat.train.mean, y.train)),
           xaxis = list(title="BART Predictions"),
           yaxis = list(title="Observed Values"),
           legend = list(orientation = 'h'))
ind <- order(model_BART_long$yhat.train.mean)

# boxplot(model_BART_long$yhat.train[ , ind], ylim=range(y.train), xlab="case", 
#             ylab="BART Hospitalization Charge Prediction Range")

caseIDs <- paste0("Case",rownames(heart_attack))
rowIDs <- paste0("", c(1:dim(model_BART_long$yhat.train)[1]))
colnames(model_BART_long$yhat.train) <- caseIDs
rownames(model_BART_long$yhat.train) <- rowIDs

df1 <- as.data.frame(model_BART_long$yhat.train[ , ind])
df2_wide <- as.data.frame(cbind(index=c(1:dim(model_BART_long$yhat.train)[1]), df1))
# colnames(df2_wide); dim(df2_wide)

df_long <- tidyr::gather(df2_wide, case, measurement, Case138:Case8)
# str(df_long )
# 'data.frame': 74000 obs. of  3 variables:
#  $ index      : int  1 2 3 4 5 6 7 8 9 10 ...
#  $ case       : chr  "Case138" "Case138" "Case138" "Case138" ...
#  $ measurement: num  5013 3958 4604 2602 2987 ...

# data.frame() keeps value numeric; cbind() would coerce both columns to character
actualCharges <- data.frame(cases=caseIDs, value=y.train)

plot_ly() %>%
  add_trace(data=df_long, y = ~measurement, color = ~case, type = "box") %>%
  add_trace(x=~actualCharges$cases, y=~actualCharges$value, type="scatter", mode="markers",
            name="Observed Charge", marker=list(size=20, color='green', line=list(color='yellow', width=2))) %>%
  layout(title="Box-and-whisker Plots across all 148 Cases (Highlighted True Charges)",
           xaxis = list(title="Cases"),
           yaxis = list(title="BART Hospitalization Charge Prediction Range"),
           showlegend=F)

The BART model indicates there is quite a bit of uncertainty in predicting the outcome (CHARGES) for each of the 148 cases using the other covariate features in the heart attack hospitalization data (DRG, DIED, LOS, AGE, gender).
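
One way to quantify this per-case uncertainty is to summarize each column of the posterior-draw matrix by a credible interval. The sketch below is illustrative only: it uses a simulated matrix (draws) and a hypothetical helper (interval_width) as stand-ins for model_BART_long$yhat.train, whose rows are posterior draws and whose columns are cases.

```r
# Simulated stand-in for a posterior-draw matrix (rows = draws, columns = cases).
# With a fitted BART model one would pass model_BART_long$yhat.train instead.
set.seed(1234)
n_draws <- 500
case_means <- c(3000, 4500, 6000, 7500, 9000)   # assumed values, for illustration
draws <- sapply(case_means, function(m) rnorm(n_draws, mean = m, sd = 800))
colnames(draws) <- paste0("Case", seq_along(case_means))

# Hypothetical helper: width of the central posterior interval for each case
interval_width <- function(draw_matrix, level = 0.95) {
  alpha <- (1 - level) / 2
  q <- apply(draw_matrix, 2, quantile, probs = c(alpha, 1 - alpha))
  q[2, ] - q[1, ]   # upper quantile minus lower quantile, per column
}

widths <- interval_width(draws)
print(round(widths))
```

Wide intervals relative to the size of the predicted charges correspond to the tall boxes in the plot above.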

6 Case study 2: Baseball Players (Take 2)

6.1 Step 2 - exploring and preparing the data

We continue to use the mlb dataset in this section. This dataset has 1034 observations. Let’s first separate them into training and test datasets.

set.seed(1234)
train_index <- sample(seq_len(nrow(mlb)), size = 0.75*nrow(mlb))
mlb_train<-mlb[train_index, ]
mlb_test<-mlb[-train_index, ]

Here we use a randomized split (75%-25%) to divide the training and test datasets.

6.2 Step 3 - training a model on the data

In R, the rpart::rpart() function provides an implementation for prediction using regression-tree modeling.

m<-rpart(dv~iv, data=mydata)

  • dv: dependent variable
  • iv: independent variable
  • mydata: training data containing dv and iv

We use two numerical features from the mlb data (“01a_data.txt”), Age and Height, as predictors of Weight.

#install.packages("rpart")
library(rpart)
mlb.rpart<-rpart(Weight~Height+Age, data=mlb_train)
mlb.rpart
## n= 775 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 775 341952.400 201.8581  
##    2) Height< 74.5 500 164838.800 195.1200  
##      4) Height< 72.5 239  73317.920 189.7573  
##        8) Height< 70.5 63  15833.750 182.0635 *
##        9) Height>=70.5 176  52419.980 192.5114 *
##      5) Height>=72.5 261  78353.750 200.0307  
##       10) Age< 27.28 106  30455.070 194.2736 *
##       11) Age>=27.28 155  41982.840 203.9677 *
##    3) Height>=74.5 275 113138.700 214.1091  
##      6) Age< 30.015 201  86834.910 211.1592  
##       12) Height< 75.5 91  35635.850 205.8462 *
##       13) Height>=75.5 110  46505.170 215.5545  
##         26) Age< 24.795 23   7365.217 204.3478 *
##         27) Age>=24.795 87  35487.720 218.5172 *
##      7) Age>=30.015 74  19803.910 222.1216 *

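The output contains rich information. split shows the splitting rule (variable and threshold) that defines each node; n is the number of observations that fall in the node; deviance is the sum of squared deviations of the outcome from the node mean; and yval is the predicted value (the node mean) assigned to any test case that falls into the specific segment (tree node decision cluster). Terminal nodes are marked with an asterisk.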
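
To make the deviance and yval columns concrete, the following sketch recomputes them by hand for the root node of a small tree. It assumes a simulated data frame (sim, a hypothetical stand-in for mlb_train). For the default regression ("anova") method, the node prediction is the mean of the outcome in the node and the deviance is the sum of squared deviations from that mean.

```r
library(rpart)

# Simulated data (hypothetical stand-in for mlb_train)
set.seed(1234)
sim <- data.frame(Height = round(rnorm(300, mean = 74, sd = 2.3)),
                  Age    = round(runif(300, min = 21, max = 40), 2))
sim$Weight <- 50 + 2 * sim$Height + 0.5 * sim$Age + rnorm(300, sd = 15)

fit <- rpart(Weight ~ Height + Age, data = sim)

# Root-node values stored by rpart (first row of the node table)
root <- fit$frame[1, c("n", "dev", "yval")]

# The same quantities computed directly from the data
manual_yval <- mean(sim$Weight)
manual_dev  <- sum((sim$Weight - manual_yval)^2)

print(root)
print(c(yval = manual_yval, dev = manual_dev))
```

The two printed summaries should agree, which confirms how the root line of the printed tree is obtained; every other node follows the same rule applied to its own subset of observations.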

6.3 Visualizing decision trees

A convenient way to draw the rpart decision tree is the rpart.plot() function in the rpart.plot package.

# install.packages("rpart.plot")
library(rpart.plot)
rpart.plot(mlb.rpart, digits=3)

A more detailed graph can be obtained by specifying additional options in the function call.

rpart.plot(mlb.rpart, digits = 4, fallen.leaves = T, type=3, extra=101)

Alternatively, you can use the fancier tree plot from the rattle package, which shows the order and rules of the splits.

library(rattle)
fancyRpartPlot(mlb.rpart, cex = 0.8)

6.4 Step 4 - evaluating model performance

Let’s make predictions with the regression tree using the predict() function.

mlb.p<-predict(mlb.rpart, mlb_test)
summary(mlb.p)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   182.1   192.5   204.0   202.1   205.8   222.1
summary(mlb_test$Weight)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   150.0   186.5   200.0   201.3   219.0   260.0

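Comparing the summary statistics of the predicted and observed Weight, we see that the model cannot reproduce extreme cases: the predictions span only 182.1 to 222.1, whereas the observed weights range from 150 to 260. This is expected, since a regression tree can only output the mean values of its terminal nodes. The centers of the two distributions, however, agree fairly well (means 202.1 vs. 201.3, medians 204.0 vs. 200.0).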

Correlation could be used to measure the correspondence of two equal length numeric variables. Let’s use cor() to examine the prediction accuracy.

cor(mlb.p, mlb_test$Weight)
## [1] 0.5508078

The predicted weights are moderately correlated (about 0.55) with their observed counterparts. Chapter 13 provides additional strategies for model quality assessment.

6.5 Measuring performance with mean absolute error

To measure the distance between the predicted and the true values, we can use the mean absolute error (MAE), defined as \[MAE=\frac{1}{n}\sum_{i=1}^{n}|pred_i-obs_i|,\] where \(pred_i\) is the \(i^{th}\) predicted value and \(obs_i\) is the \(i^{th}\) observed value. Let’s write a corresponding MAE function in R and evaluate our model performance.

MAE<-function(obs, pred){
  mean(abs(obs-pred))
}
MAE(mlb_test$Weight, mlb.p)
## [1] 13.91322

This implies that, on average, the difference between the predicted value and the observed value is about \(13.9\). Considering that the Weight variable in our test dataset ranges from 150 to 260, this error is moderate.

For comparison, suppose we used the most primitive method for prediction - the sample mean. How much larger would the MAE be?

mean(mlb_test$Weight)
## [1] 201.2934
MAE(mlb_test$Weight, mean(mlb_test$Weight))   # baseline: predict the test-set mean for every case
## [1] 16.80094

This shows that the predictive decision tree (MAE of about 13.9) is better than using the overall mean (MAE of about 16.8) to predict every observation in the test dataset. However, it is not dramatically better. There might be room for improvement.
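
The size of the improvement can be made explicit with a little arithmetic on the two MAE values reported above. This sketch simply reuses the printed numbers (13.91322 for the tree, 16.80094 for the mean baseline) and the observed test-set weight range (150 to 260); the variable names are illustrative.

```r
mae_tree     <- 13.91322   # MAE of the rpart regression tree (from the output above)
mae_baseline <- 16.80094   # MAE when predicting the mean for every case
weight_range <- 260 - 150  # observed range of Weight in the test data

# Relative reduction in MAE achieved by the tree over the mean baseline
rel_reduction <- (mae_baseline - mae_tree) / mae_baseline

# MAE expressed as a fraction of the outcome range
mae_vs_range <- mae_tree / weight_range

round(100 * c(rel_reduction = rel_reduction, mae_vs_range = mae_vs_range), 1)
```

The tree reduces the MAE by roughly 17% relative to the baseline, and its MAE corresponds to about 13% of the observed weight range.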

6.6 Step 5 - improving model performance

To improve the performance of our regression-tree forecasting, we are going to use a model tree instead of a regression tree. The RWeka::M5P() function implements the M5 algorithm and uses syntax similar to rpart::rpart().

m<-M5P(dv~iv, data=mydata)

#install.packages("RWeka")

# Sometimes RWeka installations may be off a bit, see:
# http://stackoverflow.com/questions/41878226/using-rweka-m5p-in-rstudio-yields-java-lang-noclassdeffounderror-no-uib-cipr-ma

Sys.getenv("WEKA_HOME") # where does it point to? Maybe some obscure path? 
## [1] ""
# if yes, correct the variable:
Sys.setenv(WEKA_HOME="C:\\MY\\PATH\\WEKA_WPM")
library(RWeka)
# WPM("list-packages", "installed")

mlb.m5 <- M5P(Weight~Height+Age, data=mlb_train)
mlb.m5
## M5 pruned model tree:
## (using smoothed linear models)
## 
## Height <= 74.5 : 
## |   Height <= 72.5 : 
## |   |   Height <= 70.5 : 
## |   |   |   Age <= 29.68 : LM1 (41/81.063%)
## |   |   |   Age >  29.68 : LM2 (22/53.448%)
## |   |   Height >  70.5 : LM3 (176/82.16%)
## |   Height >  72.5 : 
## |   |   Age <= 27.28 : LM4 (106/80.695%)
## |   |   Age >  27.28 : 
## |   |   |   Age <= 31.415 : 
## |   |   |   |   Age <= 28.16 : LM5 (19/86.526%)
## |   |   |   |   Age >  28.16 : 
## |   |   |   |   |   Height <= 73.5 : LM6 (32/76.778%)
## |   |   |   |   |   Height >  73.5 : LM7 (38/73.92%)
## |   |   |   Age >  31.415 : LM8 (66/68.179%)
## Height >  74.5 : 
## |   Age <= 30.015 : 
## |   |   Height <= 75.5 : LM9 (91/94.209%)
## |   |   Height >  75.5 : 
## |   |   |   Age <= 24.795 : LM10 (23/85.192%)
## |   |   |   Age >  24.795 : LM11 (87/96.15%)
## |   Age >  30.015 : 
## |   |   Age <= 34.155 : 
## |   |   |   Age <= 32.24 : 
## |   |   |   |   Height <= 75.5 : LM12 (12/83.74%)
## |   |   |   |   Height >  75.5 : 
## |   |   |   |   |   Age <= 30.66 : LM13 (8/41.91%)
## |   |   |   |   |   Age >  30.66 : 
## |   |   |   |   |   |   Height <= 76.5 : LM14 (7/61.675%)
## |   |   |   |   |   |   Height >  76.5 : LM15 (7/30.928%)
## |   |   |   Age >  32.24 : LM16 (17/50.539%)
## |   |   Age >  34.155 : LM17 (23/70.498%)
## 
## LM num: 1
## Weight = 
##  1.1414 * Age 
##  + 91.6524
## 
## LM num: 2
## Weight = 
##  1.1414 * Age 
##  + 93.1647
## 
## LM num: 3
## Weight = 
##  0.6908 * Age 
##  + 140.1106
## 
## LM num: 4
## Weight = 
##  0.8939 * Age 
##  + 124.5246
## 
## LM num: 5
## Weight = 
##  3.6636 * Age 
##  - 68.1662
## 
## LM num: 6
## Weight = 
##  3.775 * Age 
##  - 82.96
## 
## LM num: 7
## Weight = 
##  5.7056 * Age 
##  - 129.4565
## 
## LM num: 8
## Weight = 
##  1.5624 * Age 
##  + 86.1193
## 
## LM num: 9
## Weight = 
##  0.9766 * Age 
##  + 125.9837
## 
## LM num: 10
## Weight = 
##  1.713 * Age 
##  + 58.6815
## 
## LM num: 11
## Weight = 
##  1.2048 * Age 
##  + 113.3309
## 
## LM num: 12
## Weight = 
##  3.4328 * Age 
##  - 37.4776
## 
## LM num: 13
## Weight = 
##  3.725 * Age 
##  - 108.393
## 
## LM num: 14
## Weight = 
##  3.9341 * Age 
##  - 112.8784
## 
## LM num: 15
## Weight = 
##  3.9341 * Age 
##  - 112.1999
## 
## LM num: 16
## Weight = 
##  2.7666 * Age 
##  + 20.46
## 
## LM num: 17
## Weight = 
##  1.9226 * Age 
##  + 59.2876
## 
## Number of Rules : 17

Instead of using segment averages to predict, the M5 model fits a linear regression at each terminal node (LM1 through LM17 above). In this example, each of the 17 leaves has its own linear model in Age; in datasets with more variables, the leaf models may involve several predictors.

Much like general regression trees, M5 builds tree-based models. The difference is that regression trees produce a single constant forecast (value) at each terminal node, whereas M5 model trees fit a (possibly multivariate) linear model at each terminal node. These model-based forecasts represent piecewise linear functions that can be used to numerically estimate outcomes at every node, even for very high-dimensional data (rich feature spaces).
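
The contrast between constant leaves and linear-model leaves can be illustrated with a minimal sketch in base R. This is not the M5 algorithm itself: it assumes simulated data and a single hand-picked split at x = 5 (both hypothetical), and simply compares predicting the leaf mean with fitting lm() inside each leaf.

```r
set.seed(1234)
x <- runif(200, min = 0, max = 10)
# Piecewise-linear signal with a change of slope at x = 5, plus noise
y <- ifelse(x < 5, 2 * x, 10 + 0.5 * (x - 5)) + rnorm(200, sd = 0.5)
dat <- data.frame(x = x, y = y, leaf = ifelse(x < 5, "left", "right"))

# Regression-tree style: one constant (the leaf mean) per leaf
pred_const <- ave(dat$y, dat$leaf, FUN = mean)

# Model-tree style: a separate linear model per leaf
pred_linear <- numeric(nrow(dat))
for (lf in unique(dat$leaf)) {
  idx <- dat$leaf == lf
  pred_linear[idx] <- fitted(lm(y ~ x, data = dat[idx, ]))
}

MAE <- function(obs, pred) mean(abs(obs - pred))
c(constant_leaves = MAE(dat$y, pred_const), linear_leaves = MAE(dat$y, pred_linear))
```

When the outcome varies linearly within a leaf, the per-leaf linear models track it closely, while the constant leaves can only approximate it by a step function. Note that both errors here are computed in-sample, so they illustrate fit rather than out-of-sample accuracy.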

summary(mlb.m5)
## 
## === Summary ===
## 
## Correlation coefficient                 -0.1406
## Mean absolute error                     74.606 
## Root mean squared error                 90.2806
## Relative absolute error                445.8477 %
## Root relative squared error            429.796  %
## Total Number of Instances              775
mlb.p.m5<-predict(mlb.m5, mlb_test)
summary(mlb.p.m5)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   3.917 129.375 146.584 132.213 155.216 168.107
cor(mlb.p.m5, mlb_test$Weight)
## [1] -0.2677442
MAE(mlb_test$Weight, mlb.p.m5)
## [1] 69.12329

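We can use summary(mlb.m5) to report some rough diagnostic statistics of the model. Notice that, in the run shown here, the correlation and MAE for the M5 model are worse than the results of the previous rpart() model: the test-set correlation is negative (-0.27 versus 0.55) and the MAE is much larger (69.1 versus 13.9). The predicted weights average 132.2, against an observed mean of 201.3, which points to a systematic bias. This M5P fit should therefore be re-examined (for example, by checking the RWeka installation noted above) before drawing conclusions about model trees in general.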

7 Practice Problem: Heart Attack Data

Let’s use the heart attack dataset for practice.

heart_attack<-read.csv("https://umich.instructure.com/files/1644953/download?download_frd=1", stringsAsFactors = F)
str(heart_attack)
## 'data.frame':    150 obs. of  8 variables:
##  $ Patient  : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ DIAGNOSIS: int  41041 41041 41091 41081 41091 41091 41091 41091 41041 41041 ...
##  $ SEX      : chr  "F" "F" "F" "F" ...
##  $ DRG      : int  122 122 122 122 122 121 121 121 121 123 ...
##  $ DIED     : int  0 0 0 0 0 0 0 0 0 1 ...
##  $ CHARGES  : chr  "4752" "3941" "3657" "1481" ...
##  $ LOS      : int  10 6 5 2 1 9 15 15 2 1 ...
##  $ AGE      : int  79 34 76 80 55 84 84 70 76 65 ...

To begin with, we need to convert CHARGES (the dependent variable) to numerical form. The conversion creates NA’s, so let’s retain only the complete cases, as mentioned in the beginning of this chapter. Also, let’s create a gender variable as an indicator for female patients using ifelse() and delete the previous SEX column.

heart_attack$CHARGES<-as.numeric(heart_attack$CHARGES)
heart_attack<-heart_attack[complete.cases(heart_attack), ]
heart_attack$gender<-ifelse(heart_attack$SEX=="F", 1, 0)
heart_attack<-heart_attack[, -3]

Now we can build a model tree using M5P() with all the features in the model. As usual, we need to separate the heart_attack data into training and test datasets (use a 75%-25% split).
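
One possible starting point for this exercise is sketched below. It is a hedged outline rather than the code used to produce the numbers that follow: the helper split_train_test() and the seeds are hypothetical, the demonstration runs on a simulated data frame (demo, with a subset of the heart_attack column names) so that it is self-contained, and the M5P() call is only attempted if RWeka is installed.

```r
# Hypothetical helper: random split of a data frame into training and test parts
split_train_test <- function(df, frac = 0.75, seed = 1234) {
  set.seed(seed)
  idx <- sample(seq_len(nrow(df)), size = floor(frac * nrow(df)))
  list(train = df[idx, ], test = df[-idx, ])
}

# Simulated stand-in for the cleaned heart_attack data (values are made up)
set.seed(1)
demo <- data.frame(DRG    = sample(121:123, 148, replace = TRUE),
                   DIED   = rbinom(148, 1, 0.1),
                   LOS    = rpois(148, 7),
                   AGE    = round(runif(148, 30, 90)),
                   gender = rbinom(148, 1, 0.5))
demo$CHARGES <- 1500 + 600 * demo$LOS + rnorm(148, sd = 1500)

parts    <- split_train_test(demo)
ha_train <- parts$train
ha_test  <- parts$test

# Fit the model tree only when RWeka (and its Java dependency) is available
if (requireNamespace("RWeka", quietly = TRUE)) {
  ha.m5   <- RWeka::M5P(CHARGES ~ ., data = ha_train)
  ha.pred <- predict(ha.m5, ha_test)
  print(cor(ha.pred, ha_test$CHARGES))          # correlation on the test set
  print(mean(abs(ha.pred - ha_test$CHARGES)))   # mean absolute error
}
```

To work on the actual exercise, pass the cleaned heart_attack data frame to split_train_test() in place of demo. The resulting numbers will depend on the random split.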

After using the model to predict CHARGES in the test dataset we can obtain the following correlation and MAE.

## [1] -0.1087857
## [1] 3297.111

We can see that, in this run, the predicted and observed values are only weakly (and negatively) correlated, so the model has little predictive value here. The MAE may also seem very large at first glance.

range(ha_test$CHARGES)
## [1]  1354 12403
# 12403 - 1354 = 11049
# 3297.111/11049 = 0.298

The test data itself has a wide range (11049), and the MAE is about 30% of that range. With only 148 observations, accurate prediction of CHARGES is difficult, and these results leave considerable room for improvement. Can you reproduce or perhaps improve these results?

Try to replicate these results with other data from the list of our Case-Studies.
