In CART, the value of minbucket can affect the model's out-of-sample accuracy. As we discussed earlier in the lecture, if minbucket is too small, over-fitting might occur. But if minbucket is too large, the model might be too simple. So how should we set this parameter value?

We could select the value that gives the best testing set accuracy, but this isn't right. The idea of the testing set is to measure model performance on data the model has never seen before. If we picked the value of minbucket that gives the best test set performance, the testing set would implicitly be used to generate the model.

Instead, we'll use a method called k-fold cross validation, which is one way to properly select the parameter value. This method works by going through the following steps. First, we split the training set into k equally sized subsets, or folds. In this example, k equals 5. Then we select k - 1 of the folds, four in this example, to estimate the model, and compute predictions on the remaining fold, which is often referred to as the validation set. We build a model and make predictions for each possible parameter value we're considering.

Then we repeat this for each of the other folds, or pieces of our training set. So we would build a model using folds 1, 2, 3, and 5 to make predictions on fold 4; then we would build a model using folds 1, 2, 4, and 5 to make predictions on fold 3; and so on. So ultimately, cross validation builds many models, one for each fold and each possible parameter value.

Then, for each candidate parameter value, and for each fold, we can compute the accuracy of the model. The plot on the slide shows the possible parameter values on the x-axis and the accuracy of the model on the y-axis; one line shows the accuracy of our model with fold 1 as the validation set. We can also compute the accuracy of the model using each of the other folds as the validation set. We then average the accuracy over the k folds to determine the final parameter value that we want to use.
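To make these steps concrete, here is a minimal hand-rolled sketch of 5-fold cross validation in R. The Train data frame, the Reverse outcome (assumed here to be a factor), and the candidate minbucket values are assumptions carried over from this lecture's running example; in a moment we'll let the caret package do this work for us instead.

    # A minimal sketch of k-fold cross validation by hand (assumes the
    # Stevens Train data frame from this lecture, with Reverse a factor).
    library(rpart)
    k <- 5
    set.seed(100)                                  # illustrative seed
    folds <- sample(rep(1:k, length.out = nrow(Train)))   # a fold label per row
    minbucketValues <- seq(5, 50, by = 5)          # candidate values (illustrative)
    avgAccuracy <- sapply(minbucketValues, function(mb) {
      mean(sapply(1:k, function(i) {
        trainFolds <- Train[folds != i, ]          # estimate on k - 1 folds
        validation <- Train[folds == i, ]          # predict on the held-out fold
        tree <- rpart(Reverse ~ Circuit + Issue + Petitioner + Respondent +
                        LowerCourt + Unconst,
                      data = trainFolds, method = "class", minbucket = mb)
        pred <- predict(tree, newdata = validation, type = "class")
        mean(pred == validation$Reverse)           # accuracy on this fold
      }))                                          # average over the k folds
    })
    minbucketValues[which.max(avgAccuracy)]        # value with best average accuracy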
Typically, the behavior looks like this: if the parameter value is too small, then the accuracy is lower, because the model is probably over-fit to the training set. But if the parameter value is too large, then the accuracy is also lower, because the model is too simple. In this case, we would pick a parameter value around six, because it leads to the maximum average accuracy over all parameter values.

So far, we've used the parameter minbucket to limit our tree in R. When we use cross validation in R, we'll use a parameter called cp instead. This is the complexity parameter. It's like Adjusted R-squared for linear regression, and AIC for logistic regression, in that it measures the trade-off between model complexity and accuracy on the training set. A smaller cp value leads to a bigger tree, so a smaller cp value might over-fit the model to the training set. But a cp value that's too large might build a model that's too simple.

Let's go to R, and use cross validation to select the value of cp for our CART tree. In our R console, let's try cross validation for our CART model. To do this, we need to install and load two new packages. First, we'll install the package "caret". You should see some lines run in your R console, and then when you're back to the blinking cursor, load the package with library(caret). Now, let's install the package "e1071", with install.packages("e1071"). Again, you should see some lines run in your R console, and when you're back to the cursor, load the package with library(e1071).

Now, we'll define our cross validation experiment. First, we need to define how many folds we want. We can do this using the trainControl function. So we'll say numFolds = trainControl(method = "cv", number = 10), where method = "cv" stands for cross validation, and number = 10 gives us 10 folds. Then we need to pick the possible values for our cp parameter, using the expand.grid function.
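Collected in one place, the setup commands described so far look like the following (the install.packages calls only need to run once per machine):

    install.packages("caret")    # run once
    library(caret)
    install.packages("e1071")    # run once
    library(e1071)

    # Define the cross validation experiment: 10 folds.
    numFolds <- trainControl(method = "cv", number = 10)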
So we'll call it cpGrid, and use expand.grid, where the only argument is .cp = seq(0.01, 0.5, 0.01). This defines the cp parameters to test as the numbers from 0.01 to 0.5, in increments of 0.01.

Now, we're ready to perform cross validation. We'll do this using the train function, where the first argument is similar to the one we use when building models. It's the dependent variable, Reverse, followed by a tilde symbol, and then the independent variables separated by plus signs: Circuit + Issue + Petitioner + Respondent + LowerCourt + Unconst. Our data set here is Train, with a capital T, and then we need to add the arguments method = "rpart", since we want to cross validate a CART model; trControl = numFolds, the output of our trainControl function; and tuneGrid = cpGrid, the output of the expand.grid function.

If you hit Enter, it might take a little while, but after a few seconds you should get a table describing the cross validation accuracy for different cp parameters. The first column gives the cp parameter that was tested, and the second column gives the cross validation accuracy for that cp value. The accuracy starts lower, then increases, and then starts decreasing again, as we saw in the slides. At the bottom of the output, it says, "Accuracy was used to select the optimal model using the largest value. The final value used for the model was cp = 0.18." This is the cp value we want to use in our CART model.

So now let's create a new CART model with this value of cp, instead of the minbucket parameter. We'll call this model StevensTreeCV, and we'll use the rpart function, like we did earlier, to predict Reverse using all of our independent variables: Circuit, Issue, Petitioner, Respondent, LowerCourt, and Unconst. Our data set here is Train, and then we want method = "class", since we're building a classification tree, and cp = 0.18.
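Here are those commands together. The value cp = 0.18 is the one train reported on this data; with a different random split into folds, the selected value could differ slightly.

    # Candidate cp values: 0.01 to 0.5 in steps of 0.01.
    cpGrid <- expand.grid(.cp = seq(0.01, 0.5, 0.01))

    # Cross validate a CART model over the cp grid.
    train(Reverse ~ Circuit + Issue + Petitioner + Respondent +
            LowerCourt + Unconst,
          data = Train, method = "rpart",
          trControl = numFolds, tuneGrid = cpGrid)

    # Rebuild the tree with the selected complexity parameter.
    library(rpart)
    StevensTreeCV <- rpart(Reverse ~ Circuit + Issue + Petitioner +
                             Respondent + LowerCourt + Unconst,
                           data = Train, method = "class", cp = 0.18)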
Now, let's make predictions on our test set using this model. We'll call our predictions PredictCV, and we'll use the predict function to make predictions using the model StevensTreeCV, the newdata set Test, and type = "class", so that we get class predictions.

Now let's create our confusion matrix, using the table function, where we first give the true outcome, Test$Reverse, and then our predictions, PredictCV. The accuracy of this model is 59 + 64, divided by the total number in this table, 59 + 18 + 29 + 64, the total number of observations in our test set. So the accuracy of this model is 0.724. Remember that the accuracy of our previous CART model was 0.659.

Cross validation helps us make sure we're selecting a good parameter value, and often this will significantly increase the accuracy. If we had already happened to select a good parameter value, then the accuracy might not have increased that much. But by using cross validation, we can be sure that we're selecting a smart parameter value.
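For reference, the prediction and evaluation commands from this segment:

    # Class predictions on the test set.
    PredictCV <- predict(StevensTreeCV, newdata = Test, type = "class")

    # Confusion matrix: true outcome in rows, predictions in columns.
    table(Test$Reverse, PredictCV)

    # Accuracy: correct predictions over all test-set observations.
    (59 + 64) / (59 + 18 + 29 + 64)    # = 0.724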