1 00:00:04,500 --> 00:00:06,720 In this video, we'll build a CART model 2 00:00:06,720 --> 00:00:09,010 to predict healthcare cost. 3 00:00:09,010 --> 00:00:14,040 First, let's make sure the packages rpart and rpart.plot 4 00:00:14,040 --> 00:00:16,340 are loaded with the library function. 5 00:00:19,440 --> 00:00:21,370 You should have already installed them 6 00:00:21,370 --> 00:00:24,380 in the previous lecture on predicting Supreme Court 7 00:00:24,380 --> 00:00:24,880 decisions. 8 00:00:29,020 --> 00:00:31,240 Now, let's build our CART model. 9 00:00:31,240 --> 00:00:35,200 We'll call it ClaimsTree. 10 00:00:35,200 --> 00:00:41,580 And we'll use the rpart function to predict bucket2009, 11 00:00:41,580 --> 00:00:54,550 using as independent variables: age, arthritis, alzheimers, 12 00:00:54,550 --> 00:01:13,920 cancer, copd, depression, diabetes, heart.failure, ihd, 13 00:01:13,920 --> 00:01:23,650 kidney, osteoporosis, and stroke. 14 00:01:23,650 --> 00:01:33,690 We'll also use bucket2008 and reimbursement2008. 15 00:01:33,690 --> 00:01:37,920 The data set we'll use to build our model is ClaimsTrain. 16 00:01:43,440 --> 00:01:49,400 And then we'll add the arguments, method = "class", 17 00:01:49,400 --> 00:01:59,370 since we have a classification problem here, and cp = 0.00005. 18 00:01:59,370 --> 00:02:02,380 Note that even though we have a multi-class classification 19 00:02:02,380 --> 00:02:06,270 problem here, we build our tree in the same way 20 00:02:06,270 --> 00:02:10,009 as a binary classification problem. 21 00:02:10,009 --> 00:02:12,320 So go ahead and hit Enter. 22 00:02:12,320 --> 00:02:15,490 The cp value we're using here was 23 00:02:15,490 --> 00:02:17,500 selected through cross-validation 24 00:02:17,500 --> 00:02:19,440 on the training set. 
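A rough sketch of the call just described, typed out in full. It assumes ClaimsTrain is a data frame already loaded with the column names read out in the video. The ClaimsFormula helper name and the exists() guard are additions for illustration only, so that the snippet runs even when the data set is not in the workspace.

```r
library(rpart)  # rpart.plot is only needed later, for prp()

# Formula built from the variables listed in the video.
# ClaimsFormula is a hypothetical helper name, not from the lecture.
ClaimsFormula <- bucket2009 ~ age + arthritis + alzheimers + cancer + copd +
  depression + diabetes + heart.failure + ihd + kidney + osteoporosis +
  stroke + bucket2008 + reimbursement2008

# Assumption: ClaimsTrain was read in and split as in the earlier videos.
# The guard skips the fit when the data set is not available.
if (exists("ClaimsTrain")) {
  ClaimsTree <- rpart(ClaimsFormula, data = ClaimsTrain,
                      method = "class", cp = 0.00005)
}
```

Keeping the formula in its own object is just a convenience here; typing it directly inside rpart(), as in the video, does the same thing.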
25 00:02:19,440 --> 00:02:22,170 We won't perform the cross-validation here, 26 00:02:22,170 --> 00:02:25,020 because it takes a significant amount of time 27 00:02:25,020 --> 00:02:27,730 on a data set of this size. 28 00:02:27,730 --> 00:02:33,390 Remember that we have almost 275,000 observations 29 00:02:33,390 --> 00:02:35,510 in our training set. 30 00:02:35,510 --> 00:02:37,760 But keep in mind that the R commands 31 00:02:37,760 --> 00:02:42,070 needed for cross-validation here are the same as those used 32 00:02:42,070 --> 00:02:44,870 in the previous lecture on predicting Supreme Court 33 00:02:44,870 --> 00:02:46,780 decisions. 34 00:02:46,780 --> 00:02:48,850 So now that our model's done, let's 35 00:02:48,850 --> 00:02:51,860 take a look at our tree with the prp function. 36 00:02:57,510 --> 00:02:59,670 It might take a while to load, because we 37 00:02:59,670 --> 00:03:01,930 have a huge tree here. 38 00:03:01,930 --> 00:03:04,730 This makes sense for a few reasons. 39 00:03:04,730 --> 00:03:07,740 One is the large number of observations in our training 40 00:03:07,740 --> 00:03:08,980 set. 41 00:03:08,980 --> 00:03:11,970 Another is that we have a five-class classification 42 00:03:11,970 --> 00:03:14,340 problem, so the classification is 43 00:03:14,340 --> 00:03:17,840 more complex than a binary classification case, 44 00:03:17,840 --> 00:03:20,680 like the one we saw in the previous lecture. 45 00:03:20,680 --> 00:03:25,530 The trees used by D2Hawkeye were also very large CART trees. 46 00:03:25,530 --> 00:03:28,290 While this hurts the interpretability of the model, 47 00:03:28,290 --> 00:03:31,850 it's still possible to describe each of the buckets of the tree 48 00:03:31,850 --> 00:03:34,260 according to the splits. 49 00:03:34,260 --> 00:03:37,780 So now, let's make predictions on the test set. 
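To get a feel for tree size before plotting, one option is to count the leaves of the fitted model. The sketch below uses the built-in iris data as a stand-in, since the claims data is not assumed to be loaded. ToyTree and nLeaves are illustrative names, and the plot is only drawn if the rpart.plot package happens to be installed.

```r
library(rpart)

# Stand-in model on a built-in 3-class data set (not the claims data).
ToyTree <- rpart(Species ~ ., data = iris, method = "class")

# Each row of the frame is a node; terminal nodes are marked "<leaf>".
nLeaves <- sum(ToyTree$frame$var == "<leaf>")
print(nLeaves)

# Plot only when rpart.plot is available, as prp() lives in that package.
if (requireNamespace("rpart.plot", quietly = TRUE)) {
  rpart.plot::prp(ToyTree)
}
```

With the claims model in the workspace, the same leaf count on ClaimsTree would show how large the tree is before prp() spends time drawing it.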
50 00:03:37,780 --> 00:03:45,079 So go back to your R console, and we'll call our predictions 51 00:03:45,079 --> 00:03:52,440 PredictTest, where we'll use the predict function for our model 52 00:03:52,440 --> 00:03:57,860 ClaimsTree, and our newdata is ClaimsTest. 53 00:04:02,680 --> 00:04:05,310 And we want to add type = "class" 54 00:04:05,310 --> 00:04:06,750 to get class predictions. 55 00:04:09,270 --> 00:04:11,490 And we can make our classification matrix 56 00:04:11,490 --> 00:04:14,360 on the test set to compute the accuracy. 57 00:04:14,360 --> 00:04:18,620 So we'll use the table function, where the actual outcomes are 58 00:04:18,620 --> 00:04:25,810 ClaimsTest$bucket2009, and our predictions are PredictTest. 59 00:04:29,100 --> 00:04:31,060 So to compute the accuracy, we need 60 00:04:31,060 --> 00:04:33,350 to add up the numbers on the diagonal 61 00:04:33,350 --> 00:04:36,600 and divide by the total number of observations in our test 62 00:04:36,600 --> 00:04:37,790 set. 63 00:04:37,790 --> 00:04:54,409 So we have 114141 + 16102 + 118 + 201 + 0. 64 00:04:54,409 --> 00:04:58,360 And we'll divide by the number of rows in ClaimsTest. 65 00:05:01,220 --> 00:05:06,430 So the accuracy of our model is 0.713. 66 00:05:06,430 --> 00:05:09,420 For the penalty error, we can use our penalty matrix 67 00:05:09,420 --> 00:05:11,670 like we did in the previous video. 68 00:05:11,670 --> 00:05:15,440 So scroll up to the classification matrix command 69 00:05:15,440 --> 00:05:22,560 and surround the table function by the as.matrix function, 70 00:05:22,560 --> 00:05:25,370 and then we'll multiply by PenaltyMatrix. 71 00:05:30,090 --> 00:05:34,010 So remember that this takes each entry in our classification 72 00:05:34,010 --> 00:05:37,770 matrix and multiplies it by the corresponding number 73 00:05:37,770 --> 00:05:40,310 in the penalty matrix. 
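The accuracy and penalty-weighting steps above can be tried on a small made-up example. Everything in this sketch is illustrative: the ten toy observations, the three buckets, and the ToyPenalty values are assumptions chosen so the arithmetic is easy to follow, not the real claims data or the PenaltyMatrix from the previous video.

```r
# Toy actual and predicted buckets (made-up values, not the claims data).
actual    <- factor(c(1, 1, 1, 2, 2, 3, 1, 2, 3, 1), levels = 1:3)
predicted <- factor(c(1, 1, 2, 2, 1, 3, 1, 2, 2, 1), levels = 1:3)

# Rows are actual outcomes, columns are predictions.
ConfMat <- table(actual, predicted)

# Accuracy: correct predictions (the diagonal) over all observations.
accuracy <- sum(diag(ConfMat)) / length(actual)

# Hypothetical penalty matrix, rows = actual and columns = predicted,
# where predicting too low a bucket costs more than predicting too high.
ToyPenalty <- matrix(c(0, 1, 2,
                       2, 0, 1,
                       4, 2, 0), byrow = TRUE, nrow = 3)

# `*` multiplies entry by entry; it is not the matrix product (%*%).
Weighted <- as.matrix(ConfMat) * ToyPenalty

print(accuracy)
print(Weighted)
```

The rows and columns of the penalty matrix have to be in the same order as the table, otherwise the wrong penalties get attached to the wrong cells.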
74 00:05:40,310 --> 00:05:42,190 So now we just need to add up all 75 00:05:42,190 --> 00:05:47,800 of the numbers in this matrix by surrounding it by the sum 76 00:05:47,800 --> 00:05:51,840 function and then dividing by the total number 77 00:05:51,840 --> 00:05:54,930 of observations in our test set, or nrow(ClaimsTest). 78 00:05:59,700 --> 00:06:03,570 So our penalty error is 0.758. 79 00:06:03,570 --> 00:06:06,400 In the previous video, we saw that our baseline method 80 00:06:06,400 --> 00:06:12,330 had an accuracy of 68% and a penalty error of 0.74. 81 00:06:12,330 --> 00:06:14,510 So while we increased the accuracy, 82 00:06:14,510 --> 00:06:16,770 the penalty error also went up. 83 00:06:16,770 --> 00:06:18,100 Why? 84 00:06:18,100 --> 00:06:22,870 By default, rpart will try to maximize the overall accuracy, 85 00:06:22,870 --> 00:06:27,180 and every type of error is seen as having a penalty of one. 86 00:06:27,180 --> 00:06:31,360 Our CART model predicts 3, 4, and 5 so rarely 87 00:06:31,360 --> 00:06:35,450 because there are very few observations in these classes. 88 00:06:35,450 --> 00:06:37,530 So we don't really expect this model 89 00:06:37,530 --> 00:06:41,530 to do better on the penalty error than the baseline method. 90 00:06:41,530 --> 00:06:43,640 So how can we fix this? 91 00:06:43,640 --> 00:06:46,250 The rpart function allows us to specify 92 00:06:46,250 --> 00:06:48,500 a parameter called loss. 93 00:06:48,500 --> 00:06:50,330 This is the penalty matrix we want 94 00:06:50,330 --> 00:06:52,650 to use when building our model. 95 00:06:52,650 --> 00:06:56,930 So let's scroll back up to where we built our CART model. 96 00:06:56,930 --> 00:06:59,140 At the end of the rpart function, 97 00:06:59,140 --> 00:07:03,800 we'll add the argument parms = list(loss=PenaltyMatrix). 98 00:07:13,120 --> 00:07:16,470 This is the name of the penalty matrix we created. 99 00:07:16,470 --> 00:07:21,150 Close the parentheses and hit Enter. 
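Here is a small runnable sketch of fitting a tree with a loss matrix and then scoring it with the same penalties. The rpart argument that carries the loss matrix is spelled `parms`. The built-in iris data stands in for the claims data, and ToyLoss, ToyTree, ToyPred, ToyTable and ToyPenaltyError are illustrative names with made-up penalty values. Predictions here are made on the training data purely to keep the example short.

```r
library(rpart)

# Stand-in data: iris has 3 classes, so the loss matrix is 3x3.
# Rows are the true class, columns the predicted class.
# rpart requires zeros on the diagonal; these penalties are made up.
ToyLoss <- matrix(c(0, 1, 2,
                    2, 0, 1,
                    4, 2, 0), byrow = TRUE, nrow = 3)

# The loss matrix is passed inside a list through the `parms` argument.
ToyTree <- rpart(Species ~ ., data = iris, method = "class",
                 parms = list(loss = ToyLoss))

# Class predictions and the classification matrix (actual by predicted).
ToyPred  <- predict(ToyTree, newdata = iris, type = "class")
ToyTable <- table(iris$Species, ToyPred)

# Penalty error: weight each cell by its penalty, add up, divide by n.
ToyPenaltyError <- sum(as.matrix(ToyTable) * ToyLoss) / nrow(iris)
print(ToyPenaltyError)

# With the claims data loaded, the call from the video would be along the
# lines of (same formula as before, plus the extra argument):
# ClaimsTree <- rpart(<formula>, data = ClaimsTrain, method = "class",
#                     cp = 0.00005, parms = list(loss = PenaltyMatrix))
```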
100 00:07:21,150 --> 00:07:23,960 So while our model is being built, 101 00:07:23,960 --> 00:07:27,270 let's think about what we expect to happen. 102 00:07:27,270 --> 00:07:30,000 If the rpart function knows that we'll 103 00:07:30,000 --> 00:07:33,070 be giving a higher penalty to some types of errors 104 00:07:33,070 --> 00:07:36,490 over others, it might choose different splits 105 00:07:36,490 --> 00:07:38,780 when building the model to minimize 106 00:07:38,780 --> 00:07:41,270 the worst types of errors. 107 00:07:41,270 --> 00:07:44,159 We'll probably get a lower overall accuracy 108 00:07:44,159 --> 00:07:45,850 with this new model. 109 00:07:45,850 --> 00:07:49,060 But hopefully, the penalty error will be much lower too. 110 00:07:53,850 --> 00:07:57,380 So now that our model is done, let's regenerate our test 111 00:07:57,380 --> 00:08:02,090 set predictions by scrolling up to where we created PredictTest 112 00:08:02,090 --> 00:08:06,330 and hitting Enter, and then recreating our classification 113 00:08:06,330 --> 00:08:10,200 matrix by scrolling up to the table function 114 00:08:10,200 --> 00:08:13,910 and hitting Enter again. 115 00:08:13,910 --> 00:08:23,250 Now let's add up the numbers on the diagonal, 94310 + 18942 116 00:08:23,250 --> 00:08:35,070 + 4692 + 636 + 2, and divide by the number of rows 117 00:08:35,070 --> 00:08:35,830 in ClaimsTest. 118 00:08:38,870 --> 00:08:40,390 And hit Enter. 119 00:08:40,390 --> 00:08:44,620 So the accuracy of this model is 0.647. 120 00:08:44,620 --> 00:08:49,080 And we can scroll up and compute the penalty error here 121 00:08:49,080 --> 00:08:53,570 by going back to the sum command and hitting Enter. 122 00:08:53,570 --> 00:08:57,890 So the penalty error of our new model is 0.642. 123 00:08:57,890 --> 00:09:01,170 Our accuracy is now lower than the baseline method, 124 00:09:01,170 --> 00:09:04,700 but our penalty error is also much lower. 
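The figures quoted so far can be lined up side by side. This sketch only restates numbers read out in the videos; the Results data frame and the variable names are illustrative, and the baseline values are the rounded ones mentioned earlier.

```r
# Accuracy and penalty error as quoted in the videos.
Results <- data.frame(
  model        = c("baseline", "CART", "CART with loss"),
  accuracy     = c(0.68, 0.713, 0.647),
  penaltyError = c(0.74, 0.758, 0.642)
)

# Correct test-set predictions: the diagonals read out for each CART model.
correctDefault <- sum(c(114141, 16102, 118, 201, 0))
correctLoss    <- sum(c(94310, 18942, 4692, 636, 2))

# Which model has the lowest penalty error, and which the highest accuracy?
bestPenalty  <- Results$model[which.min(Results$penaltyError)]
bestAccuracy <- Results$model[which.max(Results$accuracy)]

print(Results)
print(c(default = correctDefault, loss = correctLoss))
```

The count of correct predictions drops from 130562 to 118582 once the loss matrix is used, which is the accuracy being traded away for a lower penalty error.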
125 00:09:04,700 --> 00:09:08,060 Note that we have significantly fewer independent variables 126 00:09:08,060 --> 00:09:09,970 than D2Hawkeye had. 127 00:09:09,970 --> 00:09:12,830 If we had the hundreds of codes and risk factors 128 00:09:12,830 --> 00:09:17,230 available to D2Hawkeye, we would hopefully do even better. 129 00:09:17,230 --> 00:09:20,550 In the next video, we'll discuss the accuracy of the models 130 00:09:20,550 --> 00:09:25,720 used by D2Hawkeye and how analytics can provide an edge.