#==================================================#
# Model the "correct" column.
#
# Fits two PLS classifiers for `correct` on the data loaded from
# notebooks/data/nback_seqs.Rd -- a "new" model and a "common" model that
# excludes additional predictors -- then compares their resampled accuracy
# and their confusion matrices / ROC curves on a held-out test split.

library(here)
library(tidyverse)
library(caret)
library(inspectdf)
library(pls)
#devtools::install_github("sachsmc/plotROC")
library(plotROC)
# Attached up front (still last, so masking order is unchanged) so that a
# missing package fails before the models are trained rather than after.
library(pROC)

# Presumably defines `seqs`, which is used immediately below -- TODO confirm.
load(here("notebooks/data/nback_seqs.Rd"))
set.seed(42)

# Despite the name, nothing is imputed: rows with a missing `correct` or `rt`
# are dropped. `labels` are assigned to the sorted unique values of `correct`,
# so this assumes the lower value means "incorrect" -- TODO confirm.
seqs.imputed <- seqs %>%
  filter(!is.na(correct), !is.na(rt)) %>%
  mutate(correct = factor(correct, labels = c("INCORRECT", "CORRECT")))

# Explicit print() so the summaries also show up when the script is source()d.
print(inspect_cat(seqs.imputed))
print(inspect_num(seqs.imputed))

# NOTE(review): not used again in the visible part of the script; kept in case
# other code relies on it.
seqs.dummy <- predict(dummyVars(~ ., data = seqs.imputed), seqs.imputed)

# 70/30 split, stratified on the outcome.
train_indexes <- createDataPartition(
  seqs.imputed$correct,
  times = 1,
  p = 0.7,
  list = FALSE
)
train_data <- seqs.imputed[train_indexes, ]
test_data <- seqs.imputed[-train_indexes, ]

# 5-fold CV repeated twice, down-sampling within each resample, and the
# one-standard-error rule for choosing the number of components.
control <- trainControl(
  method = "repeatedcv",
  number = 5,
  repeats = 2,
  classProbs = TRUE,
  verboseIter = TRUE,
  savePredictions = TRUE,
  sampling = "down",
  selectionFunction = "oneSE"
)

# "New" model: every column except a, al, dp, cr and rt.
pls.new_model <- train(
  correct ~ . - a - al - dp - cr - rt,
  data = train_data,
  method = "pls",
  metric = "Accuracy",
  tuneLength = 20,
  preProcess = c("zv", "center", "scale"),
  trControl = control
)

# caret's plot methods return lattice objects, which are only drawn when
# printed; print() explicitly so they also render under source().
print(plot(pls.new_model))
print(plot(
  varImp(pls.new_model),
  main = "Variable Importance for Correctness (New Model)"
))

# "Common" model: additionally excludes tl, ul, sl, s, ll, vl and l.
pls.common_model <- train(
  correct ~ . - a - al - dp - cr - rt - tl - ul - sl - s - ll - vl - l,
  data = train_data,
  method = "pls",
  metric = "Accuracy",
  tuneLength = 20,
  preProcess = c("zv", "center", "scale"),
  trControl = control
)

print(plot(pls.common_model))
print(plot(
  varImp(pls.common_model),
  main = "Variable Importance for Correctness (Common Model)"
))

#trellis.par.set(caretTheme())
#densityplot(pls.new_model, pch = "|")
#densityplot(pls.common_model, pch = "|")

# Compile models and compare performance
pls.models <- resamples(list(new = pls.new_model, common = pls.common_model))
#DEBUG summary(pls.models)
#DEBUG dotplot(pls.models)
#DEBUG diffValues <- diff(resamps)
print(bwplot(
  pls.models,
  metric = "Accuracy",
  layout = c(1, 1),
  main = "Correctness Model Performance"
))

# Class predictions and class probabilities on the held-out test split.
pls.new_predicted <- predict(pls.new_model, test_data, type = "raw")
pls.new_predicted_prob <- predict(pls.new_model, test_data, type = "prob")
pls.common_predicted <- predict(pls.common_model, test_data, type = "raw")
pls.common_predicted_prob <- predict(pls.common_model, test_data, type = "prob")

# Note: confusionMatrix() treats the first factor level ("INCORRECT") as the
# positive class unless `positive` is supplied.
print(confusionMatrix(pls.new_predicted, test_data$correct))
print(confusionMatrix(pls.common_predicted, test_data$correct))

# ROC curves on the test split. `levels` and `direction` are pinned
# (controls = INCORRECT, cases = CORRECT, cases expected to score higher) so
# pROC cannot silently flip the comparison and report an optimistic AUC for a
# model that ranks the classes the wrong way round.
par(pty = "s")  # square plotting region
print(roc(
  test_data$correct, pls.new_predicted_prob$CORRECT,
  levels = c("INCORRECT", "CORRECT"), direction = "<",
  plot = TRUE, legacy.axes = TRUE, lwd = 4, col = "black",
  print.auc = TRUE, print.auc.y = 45, percent = TRUE
))
plot.roc(
  test_data$correct, pls.common_predicted_prob$CORRECT,
  levels = c("INCORRECT", "CORRECT"), direction = "<",
  legacy.axes = TRUE, lwd = 4, col = "darkgray", lty = 3,
  print.auc = TRUE, print.auc.y = 40, percent = TRUE, add = TRUE
)
legend(
  100, 100,
  legend = c("New Model", "Common Model"),
  col = c("black", "darkgray"),
  lty = c(1, 3), lwd = 3, cex = 0.8
)

# requires plotROC package
#DEBUG ggplot(pls.common_model, aes(d = pred$obs, m = pred$CORRECT)) +
#DEBUG   geom_roc()