Lecture 5: Supervised Methods



Yanfei Kang
yanfeikang@buaa.edu.cn

School of Economics and Management
Beihang University

Packages required in this lecture

caret and corrplot (loading caret also attaches ggplot2 and lattice, which are used for the plots below).

Unsupervised methods

Clustering: working without known targets

Supervised methods

Classification: deciding how to assign (known) labels to an object.

Examples of classification problems

Some common classification methods

Logistic regression


K nearest neighbours (KNN)


Naive Bayes

Naive Bayes classifiers are especially useful for problems with many input variables, for categorical inputs that take a very large number of possible values, and for text classification.
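As a quick illustration (separate from the Titanic case study below), here is a minimal sketch of a Naive Bayes classifier on a tiny, made-up categorical dataset, using `naiveBayes()` from the e1071 package. (caret's "nb" method wraps a similar implementation from the klaR package.)

```r
# Illustrative sketch only: Naive Bayes on toy categorical data (e1071 package).
library(e1071)
play <- data.frame(
  Outlook = factor(c("sunny", "sunny", "overcast", "rainy", "rainy", "overcast")),
  Windy   = factor(c("no", "yes", "no", "no", "yes", "yes")),
  Play    = factor(c("no", "no", "yes", "yes", "no", "yes"))
)
fit <- naiveBayes(Play ~ ., data = play)
# predict the class of a new observation; factor levels must match the training data
newobs <- data.frame(Outlook = factor("sunny", levels = levels(play$Outlook)),
                     Windy   = factor("no",    levels = levels(play$Windy)))
predict(fit, newobs)
```

Because each input is treated as conditionally independent given the class, the model only needs per-variable frequency tables, which is why it scales well to many categorical inputs.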


Decision trees


Support vector machines

Case study: Titanic disaster

Our data

Each record in the dataset describes a passenger. The attributes are defined as follows:

  1. PassengerId: Unique passenger identification number

  2. Survived: Did a passenger survive or not (0 = died, 1 = survived)

  3. Pclass: Passenger Class (1 = 1st; 2 = 2nd; 3 = 3rd)

  4. Name: Name of Passenger

  5. Sex: Sex of Passenger

  6. Age: Passenger Age

  7. SibSp: Number of Siblings/Spouses Aboard

  8. Parch: Number of Parents/Children Aboard

  9. Ticket: Ticket Number

  10. Fare: Passenger Fare

  11. Cabin: Cabin

  12. Embarked: Port of Embarkation (C = Cherbourg; Q = Queenstown; S = Southampton)

Data preparation

library(caret)  # also attaches ggplot2 and lattice, used for plotting below
# load the CSV files from the course website
datasetTrain <- read.csv("https://yanfei.site/docs/dpsa/train.csv", header = TRUE, sep = ",")
datasetTest <- read.csv("https://yanfei.site/docs/dpsa/test.csv", header = TRUE, sep = ",")

Our problem to solve is to predict the survival of the passengers in the test dataset.

Data summarization

# dimension
dim(datasetTrain)
## [1] 891  12
# attribute class type
sapply(datasetTrain, class)
##        ID  Survived    Pclass      Name       Sex       Age     SibSp 
## "integer" "integer" "integer"  "factor"  "factor" "numeric" "integer" 
##     Parch    Ticket      Fare     Cabin  Embarked 
## "integer"  "factor" "numeric"  "factor"  "factor"
head(datasetTrain)
##   ID Survived Pclass                                                Name
## 1  1        0      3                             Braund, Mr. Owen Harris
## 2  2        1      1 Cumings, Mrs. John Bradley (Florence Briggs Thayer)
## 3  3        1      3                              Heikkinen, Miss. Laina
## 4  4        1      1        Futrelle, Mrs. Jacques Heath (Lily May Peel)
## 5  5        0      3                            Allen, Mr. William Henry
## 6  6        0      3                                    Moran, Mr. James
##      Sex Age SibSp Parch           Ticket    Fare Cabin Embarked
## 1   male  22     1     0        A/5 21171  7.2500              S
## 2 female  38     1     0         PC 17599 71.2833   C85        C
## 3 female  26     0     0 STON/O2. 3101282  7.9250              S
## 4 female  35     1     0           113803 53.1000  C123        S
## 5   male  35     0     0           373450  8.0500              S
## 6   male  NA     0     0           330877  8.4583              Q
# drop variables not used for modelling:
# PassengerId, Name, Ticket, Cabin and Embarked
datasetTrain <- datasetTrain[, c(-1, -4, -9, -11, -12)]

# summary
summary(datasetTrain)
##     Survived          Pclass          Sex           Age       
##  Min.   :0.0000   Min.   :1.000   female:314   Min.   : 0.42  
##  1st Qu.:0.0000   1st Qu.:2.000   male  :577   1st Qu.:20.12  
##  Median :0.0000   Median :3.000                Median :28.00  
##  Mean   :0.3838   Mean   :2.309                Mean   :29.70  
##  3rd Qu.:1.0000   3rd Qu.:3.000                3rd Qu.:38.00  
##  Max.   :1.0000   Max.   :3.000                Max.   :80.00  
##                                                NA's   :177    
##      SibSp           Parch             Fare       
##  Min.   :0.000   Min.   :0.0000   Min.   :  0.00  
##  1st Qu.:0.000   1st Qu.:0.0000   1st Qu.:  7.91  
##  Median :0.000   Median :0.0000   Median : 14.45  
##  Mean   :0.523   Mean   :0.3816   Mean   : 32.20  
##  3rd Qu.:1.000   3rd Qu.:0.0000   3rd Qu.: 31.00  
##  Max.   :8.000   Max.   :6.0000   Max.   :512.33  
## 
# Percentage of survived
table(datasetTrain$Survived)
## 
##   0   1 
## 549 342
prop.table(table(datasetTrain$Survived))
## 
##         0         1 
## 0.6161616 0.3838384
# convert Survived to a factor so caret treats this as classification
datasetTrain[, 1] <- as.factor(datasetTrain[, 1])
# make a numeric copy of Sex for the correlation plot below
data <- datasetTrain
data[, 3] <- as.numeric(data[, 3])
complete_cases <- complete.cases(data)

Data visualization

# survival bar chart
ggplot(datasetTrain, aes(x = Survived)) + geom_bar() + ggtitle("Survived Bar Chart")

# Pclass distribution
ggplot(datasetTrain, aes(x = Pclass)) + geom_bar() + ggtitle("Pclass Bar Chart")

# barplot of males and females in each class
ggplot(datasetTrain, aes(x = factor(Pclass), fill = factor(Sex))) +
geom_bar(position = "dodge")

# barplot of males and females who survived in each class
ggplot(datasetTrain, aes(x = factor(Pclass), fill = factor(Sex))) +
geom_bar(position = "dodge") +
facet_grid(". ~ Survived")

library(corrplot)
corrplot(cor(data[complete_cases, 2:5]))

Build models

# Model evaluation methods
# 10-fold cross validation with 3 repeats
trainControl <- 
  trainControl(method = "repeatedcv", 
               number = 10, 
               repeats = 3)
metric <- "Accuracy"
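To make the resampling scheme concrete, here is a small sketch (not part of the lecture code) of what "repeatedcv" generates under the hood: caret's `createMultiFolds()` builds 10 folds x 3 repeats = 30 training index sets, and `train()` fits the model on each one and averages the held-out Accuracy.

```r
# Sketch: the resampling indices behind 10-fold CV with 3 repeats.
library(caret)
set.seed(7)  # for reproducible fold assignment
folds <- createMultiFolds(1:100, k = 10, times = 3)
length(folds)       # 30 resamples in total
length(folds[[1]])  # roughly 90 training indices per resample (9/10 of the data)
```

This is why `summary(results)` later reports "Number of resamples: 30" for each model.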

# logistic regression
fit.glm <- train(Survived~.,
                 data=datasetTrain,
                 method="glm", 
                 metric=metric, 
                 na.action = na.exclude,
                 trControl=trainControl)

# KNN
fit.knn <- train(Survived~.,
                 data=datasetTrain,
                 method="knn", 
                 metric=metric, 
                 na.action = na.exclude,
                 trControl=trainControl)

# Naive Bayes
fit.nb <- train(Survived~.,
                data=datasetTrain,
                method="nb", 
                metric=metric, 
                na.action = na.exclude,
                trControl=trainControl)

# Decision tree
fit.cart <- train(Survived~.,
                  data=datasetTrain,
                  method="rpart",
                  metric=metric, 
                  na.action = na.exclude,
                  trControl=trainControl)

# SVM
fit.svm <- train(Survived~.,
                 data=datasetTrain,
                 method="svmRadial",
                 metric=metric, 
                 na.action = na.exclude,
                 trControl=trainControl)

Cross validation

Algorithm comparison

# Compare algorithms
results <- resamples(list(LG=fit.glm,
                          KNN=fit.knn,
                          CART=fit.cart,
                          NB=fit.nb,
                          SVM=fit.svm))
summary(results)
## 
## Call:
## summary.resamples(object = results)
## 
## Models: LG, KNN, CART, NB, SVM 
## Number of resamples: 30 
## 
## Accuracy 
##        Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
## LG   0.6806  0.7666 0.8028 0.8002  0.8415 0.9028    0
## KNN  0.6056  0.6620 0.7042 0.6975  0.7222 0.7887    0
## CART 0.7042  0.7474 0.7762 0.7801  0.8056 0.8873    0
## NB   0.6806  0.7526 0.7887 0.7811  0.8141 0.8732    0
## SVM  0.7222  0.8056 0.8310 0.8269  0.8451 0.9028    0
## 
## Kappa 
##        Min. 1st Qu. Median   Mean 3rd Qu.   Max. NA's
## LG   0.3323  0.5107 0.5893 0.5817  0.6673 0.7968    0
## KNN  0.1569  0.2929 0.3846 0.3621  0.4151 0.5557    0
## CART 0.3551  0.4558 0.5106 0.5245  0.5819 0.7668    0
## NB   0.3168  0.4861 0.5508 0.5422  0.6138 0.7275    0
## SVM  0.4161  0.5839 0.6399 0.6305  0.6750 0.7921    0
dotplot(results)
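Beyond eyeballing the dotplot, caret can also test whether the accuracy differences between models are statistically meaningful, using paired comparisons across the 30 common resamples:

```r
# Paired differences in Accuracy/Kappa between models over the same resamples
diffs <- diff(results)
summary(diffs)  # estimates of the pairwise differences and p-values
```

Because every model was evaluated on the same folds, the comparison is paired, which gives a sharper test than comparing the averages alone.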

Finalize model

model <- train(Survived~., 
               data=datasetTrain,
               method="svmRadial",
               metric=metric, 
               na.action = na.exclude, 
               trControl=trainControl)
# keep the predictors used in training; note the test file has no Survived
# column, so the column positions differ, and Name (column 3) is dropped too
testData <- datasetTest[, c(-1, -3, -8, -10, -11)]
# crude imputation: replace the few missing Age and Fare values with 0
testData$Age[is.na(testData$Age)] <- 0
testData$Fare[is.na(testData$Fare)] <- 0
predictions <- predict(model, testData)
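Assuming the goal is a Kaggle-style submission (an assumption; the file name below is illustrative), the predictions can be paired with the original PassengerId column and written to a CSV:

```r
# Sketch: write the test-set predictions in PassengerId/Survived format
submission <- data.frame(PassengerId = datasetTest$PassengerId,
                         Survived = predictions)
write.csv(submission, "submission.csv", row.names = FALSE)
```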

References

  1. Nina Zumel and John Mount (2014). Practical Data Science with R. Manning.
  2. The caret package in R.