# notebooks/ccn2019-correct.R
#==================================================#
# model the "correct" column

library(here)
library(tidyverse)
library(caret)
library(inspectdf)
library(pls)

#devtools::install_github("sachsmc/plotROC")
library(plotROC)


# Load the n-back sequence data. The object `seqs` used below is expected to
# be restored into the workspace by this load() call.
load(here("notebooks/data/nback_seqs.Rd"))

# Fix the RNG so the train/test split and the resampling below are
# reproducible.
set.seed(42)

# Keep only rows where both the outcome (`correct`) and `rt` are observed, and
# recode the outcome as a two-level factor named INCORRECT / CORRECT.
# NOTE(review): despite the name, nothing is imputed here; incomplete rows are
# dropped.
seqs.imputed <- seqs %>%
  filter(!is.na(correct) & !is.na(rt)) %>%
  mutate(
    correct = factor(correct, labels = c("INCORRECT", "CORRECT"))
  )

# Quick summaries of the categorical and numeric columns of the cleaned data.
inspect_cat(seqs.imputed)
inspect_num(seqs.imputed)

# One-hot encoded copy of the cleaned data (the outcome column included).
# NOTE(review): `seqs.dummy` is not referenced anywhere else in this script.
seqs.dummy <- predict(
  dummyVars(~ ., data = seqs.imputed),
  newdata = seqs.imputed
)


# Split the cleaned data 70/30 into training and test sets, partitioning on the
# outcome. With `list = FALSE`, createDataPartition() returns the training-row
# indices as a one-column matrix instead of a list.
train_indexes <- createDataPartition(
  y = seqs.imputed$correct,
  p = 0.7,
  list = FALSE,
  times = 1
)

train_data <- seqs.imputed[train_indexes, ]
test_data <- seqs.imputed[-train_indexes, ]

# Resampling scheme shared by every model trained below:
#   * 5-fold cross-validation repeated twice,
#   * class probabilities kept so ROC-style metrics are available,
#   * the majority class down-sampled inside each resample,
#   * the simplest tuning value within one SE of the best one ("oneSE").
#
# The fold indices are generated once, here, and handed to caret via `index`.
# Every train() call that uses this `control` is therefore evaluated on
# exactly the same resamples. Without it, each train() call draws its own
# folds, and comparing models with resamples() would pair results from
# different data splits.
control <- trainControl(
  method = "repeatedcv",
  number = 5,
  repeats = 2,
  index = createMultiFolds(train_data$correct, k = 5, times = 2),
  classProbs = TRUE,
  verboseIter = TRUE,
  savePredictions = TRUE,
  sampling = "down",
  selectionFunction = "oneSE"
)

# "New" PLS classifier for `correct`: all columns except a, al, dp, cr and rt
# are used as predictors. Zero-variance predictors are dropped, the rest are
# centered and scaled, and up to 20 candidate tuning values are tried.
# NOTE(review): the reason these columns are excluded is not stated in this
# script; confirm (e.g. whether any are derived from the outcome).
pls.new_model <- train(
  correct ~ . - a - al - dp - cr - rt,
  data = train_data,
  method = "pls",
  metric = "Accuracy",
  tuneLength = 20,
  preProcess = c("zv", "center", "scale"),
  trControl = control
)

# Tuning profile and variable importance. The title wording matches the
# common model's plot.
plot(pls.new_model)
plot(
  varImp(pls.new_model),
  main = "Variable Importance for Correctness (New Model)"
)

# "Common" PLS classifier for `correct`: the same formula exclusions as the
# new model (a, al, dp, cr, rt) plus tl, ul, sl, s, ll, vl and l. Training
# settings are identical (same control, preprocessing and tuning length), so
# only the predictor set differs.
pls.common_model <- train(
  correct ~ . - a - al - dp - cr - rt - tl - ul - sl - s - ll - vl - l,
  trControl = control,
  preProcess = c("zv", "center", "scale"),
  tuneLength = 20,
  metric = "Accuracy",
  method = "pls",
  data = train_data
)

# Tuning profile and variable importance for the common model.
plot(pls.common_model)
plot(
  varImp(pls.common_model),
  main = "Variable Importance for Correctness (Common Model)"
)


# Commented-out diagnostics: density plots of each model's resampled
# performance.
#trellis.par.set(caretTheme())
#densityplot(pls.new_model, pch = "|")
#densityplot(pls.common_model, pch = "|")

# Collect the resampled performance of both models and compare them.
# resamples() lines the results up by resample name (e.g. Fold1.Rep1). The
# comparison is only truly paired if both models were trained on the same
# folds.
pls.models <- resamples(list(new = pls.new_model, common = pls.common_model))
#DEBUG summary(pls.models)
#DEBUG dotplot(pls.models)
#DEBUG diffValues <- diff(pls.models)
bwplot(pls.models, metric = "Accuracy", layout = c(1, 1),
       main = "Correctness Model Performance")


# Held-out predictions from both models: hard class labels (type = "raw") and
# per-class probabilities (type = "prob"; the CORRECT column is used by the
# ROC curves).
pls.new_predicted <- predict(pls.new_model, test_data, type = "raw")
pls.new_predicted_prob <- predict(pls.new_model, test_data, type = "prob")
pls.common_predicted <- predict(pls.common_model, test_data, type = "raw")
pls.common_predicted_prob <- predict(pls.common_model, test_data,
                                     type = "prob")

# Test-set confusion matrices. caret's confusionMatrix() would otherwise treat
# the first factor level (INCORRECT) as the positive class. Setting
# `positive = "CORRECT"` makes sensitivity, specificity, PPV, etc. refer to
# the CORRECT class, which is the class the ROC curves treat as the case.
confusionMatrix(pls.new_predicted, test_data$correct, positive = "CORRECT")
confusionMatrix(pls.common_predicted, test_data$correct, positive = "CORRECT")

# ROC curves for both models on the held-out test set, overlaid on one square
# plot.
library(pROC)
old_par <- par(pty = "s")

# `levels` and `direction` are set explicitly: INCORRECT is the control level,
# CORRECT is the case, and a higher predicted P(CORRECT) should indicate a
# case. With the default direction = "auto", pROC infers the direction from
# the data (which group has the higher median). That can flip the curve for a
# poor model and overstate its test-set AUC.
roc(test_data$correct,
    pls.new_predicted_prob$CORRECT,
    levels = c("INCORRECT", "CORRECT"),
    direction = "<",
    plot = TRUE,
    legacy.axes = TRUE,
    lwd = 4,
    col = "black",
    print.auc.y = 45,
    percent = TRUE,
    print.auc = TRUE)

plot.roc(test_data$correct,
         pls.common_predicted_prob$CORRECT,
         levels = c("INCORRECT", "CORRECT"),
         direction = "<",
         legacy.axes = TRUE,
         lwd = 4,
         col = "darkgray",
         print.auc = TRUE,
         percent = TRUE,
         print.auc.y = 40,
         lty = 3,
         add = TRUE)

legend(100, 100, legend = c("New Model", "Common Model"),
       col = c("black", "darkgray"), lty = c(1, 3), lwd = 3, cex = 0.8)

# Restore the plotting parameters changed above so later plots are unaffected.
par(old_par)

# requires plotROC package
#DEBUG ggplot(pls.common_model, aes(d = pred$obs, m = pred$CORRECT)) +
#DEBUG   geom_roc()