# Patch summary for dummy-vars-playground.R (d956522 -> 5a696db), as shown by the hunks below:
#   * adds a shared formula object `f` ("correct ~ stimulus + stimulus_type + n") and passes it
#     to both model.matrix() calls in place of the repeated inline formula;
#   * replaces the ROSE() call with a plain assignment `seqs.train <- seqs.train.balanced`
#     and keeps the old ROSE() line as a comment;
#   * adds `sampling = "up"` and `savePredictions = T` to the trainControl() call;
#   * removes `ncomp=1:10` from the glmnet expand.grid() and changes the pls grid from
#     `ncomp=1:10` to `ncomp=1:6`;
#   * adds a type="prob" predict() call, `library(pROC)`, and a roc() call on
#     `seqs.test.y_prob$YES` with plotting / CI arguments;
#   * deletes the two commented-out dummyVars() lines at the end of the file.
# NOTE(review): the patch text below appears to have lost its original line breaks (all hunks
#   sit on two physical lines); presumably it will not apply in this form -- confirm against
#   the original commit before using it with `git apply` / `patch`.
# NOTE(review): the added lines use `T` instead of `TRUE` (e.g. `savePredictions = T`,
#   `plot = T`, `ci = T`); `T` is an ordinary reassignable variable in R -- consider `TRUE`.
diff --git a/dummy-vars-playground.R b/dummy-vars-playground.R index d956522..5a696db 100644 --- a/dummy-vars-playground.R +++ b/dummy-vars-playground.R @@ -8,6 +8,8 @@ rm(seqs) load(here("notebooks/data/nback_seqs.Rd")) +f <- as.formula("correct ~ stimulus + stimulus_type + n") + set.seed(42) # 1. dummy vars @@ -27,13 +29,14 @@ seqs.train.balanced <- seqs[train.indices,] -seqs.train <- ROSE(correct ~ ., data = seqs.train.balanced)$data +seqs.train <- seqs.train.balanced +# seqs.train <- ROSE(correct ~ ., data = seqs.train.balanced)$data -seqs.train.x <- model.matrix(correct ~ stimulus + stimulus_type + n, seqs.train)[,-1] +seqs.train.x <- model.matrix(f, seqs.train)[,-1] seqs.train.y <- seqs.train$correct seqs.test <- seqs[-train.indices,] -seqs.test.x <- model.matrix(correct ~ stimulus + stimulus_type + n, seqs.test)[,-1] +seqs.test.x <- model.matrix(f, seqs.test)[,-1] seqs.test.observed_y <- seqs.test$correct # model <- cv.glmnet(seqs.train.x, @@ -48,13 +51,15 @@ ctrl <- trainControl(method="cv", number=5, classProbs=T, + sampling = "up", + savePredictions = T, summaryFunction=twoClassSummary) # glmnet tune -tune <- expand.grid(alpha = 0:1, lambda = seq(0, 0.01, length = 100),ncomp=1:10) +tune <- expand.grid(alpha = 0:1, lambda = seq(0, 0.01, length = 100)) # pls tune -tune <- expand.grid(ncomp=1:10) +tune <- expand.grid(ncomp=1:6) model <- train(seqs.train.x, seqs.train.y, @@ -69,15 +74,31 @@ plot(model) seqs.test.y <- model %>% predict(seqs.test.x) +seqs.test.y_prob <- model %>% predict(seqs.test.x, type="prob") confusionMatrix(seqs.test.y, seqs.test.observed_y) +library(pROC) + +roc(seqs.test.observed_y, + seqs.test.y_prob$YES, + legacy.axes=T, + plot = T, + lwd=2, + col="black", + print.auc=T, + percent = T, + print.auc.y = 40, + print.auc.x = 55, + lty = 1, + of = "se", + boot.n = 100, + ci = T) + + # RT # data.frame( # RMSE = RMSE(y.test, seqs.test$correct), # Rsquare = R2(y.test, seqs.test$correct) # ) - -#dmy <- 
dummyVars(~.-stimulus-stimulus_type,seqs,fullRank = T) -#dmy.rt <- dummyVars(~correct+stimulus_type,seqs)