Let’s Write a Decision Tree Classifier from Scratch – Machine Learning Recipes #8

JOSH GORDON: Hey, everyone. Welcome back. In this episode, we’ll write a decision tree classifier from scratch in pure Python. Here’s an outline of what we’ll cover. I’ll start by introducing the data set we’ll work with. Next, we’ll preview the completed tree, and then we’ll build it. Along the way, we’ll cover concepts like decision tree learning, Gini impurity, and information gain. You can find the code for this episode in the description. It’s available in two formats: a Jupyter notebook and a regular Python file.

OK, let’s get started. For this episode, I’ve
written a toy data set that includes both numeric and categorical attributes. Our goal here will be to predict the type of fruit, like an apple or a grape, based on features like color and size. At the end of the episode, I encourage you to swap out this data set for one of your own and build a tree for a problem you care about.

Let’s look at the format. I’ve redrawn it here for clarity. Each row is an example. The first two columns provide features, or attributes, that describe the data. The last column gives the label, or class, we want to predict. If you like, you can modify this data set by adding more features or more examples, and our program will work in exactly the same way.

Now, this data set is pretty straightforward, except for one thing: I’ve written it so it’s not perfectly separable. By that I mean there’s no way to tell apart the second and fifth examples. They have the same features but different labels, so we can see how our tree handles this case. Toward the end of the notebook, you’ll find testing data in the same format.

I’ve also written a few utility functions that make it easier to work with this data. Below each function, I’ve written a small demo to show how it works, and I’ve repeated this pattern for every block of code in the notebook.

Now to build the tree, we use
a decision tree learning algorithm called CART. As it happens, there’s a whole family of algorithms used to build trees from data. At their core, they give you a procedure to decide which questions to ask and when. CART stands for Classification and Regression Trees.

Here’s a preview of how it works. To begin, we’ll add a root node for the tree. All nodes receive a list of rows as input, and the root receives the entire training set. Each node asks a true/false question about one of the features. In response to this question, we split, or partition, the data into two subsets. These subsets then become the input to two child nodes we add to the tree.

The goal of the question is to unmix the labels as we proceed down; in other words, to produce the purest possible distribution of labels at each node. For example, the input to this node contains only a single type of label, so we’d say it’s perfectly unmixed: there’s no uncertainty about the type of label. On the other hand, the labels in this node are still mixed up, so we’d ask another question to narrow it down further.

The trick to building an effective tree is to understand which questions to ask and when. To do that, we need to quantify how much a question helps to unmix the labels. We can quantify the amount of uncertainty at a single node using a metric called Gini impurity, and we can quantify how much a question reduces that uncertainty using a concept called information gain. We’ll use these to select the best question to ask at each point. Given that question, we’ll recursively build the tree on each of the new nodes. We’ll continue dividing the data until there are no further questions to ask, at which point we’ll add a leaf.

To implement this, first
we need to understand what types of questions we can ask about the data. Second, we need to understand how to decide which question to ask when.

Each node takes a list of rows as input. To generate a list of questions, we’ll iterate over every value for every feature that appears in those rows. Each of these becomes a candidate threshold we can use to partition the data, and there will often be many possibilities. In code, we represent a question by storing a column number and a column value, which is the threshold we’ll use to partition the data. For example, here’s how we’d write a question to test whether the color is green. And here’s an example for a numeric attribute, testing whether the diameter is greater than or equal to 3.

In response to a question, we divide, or partition, the data into two subsets. The first contains all the rows for which the question is true, and the second contains everything else. In code, our partition function takes a question and a list of rows as input. For example, here’s how we partition the rows based on whether the color is red. Here, the true rows contain all the red examples, and the false rows contain everything else.

The best question is the one
that reduces our uncertainty the most. Gini impurity lets us quantify how much uncertainty there is at a node, and information gain lets us quantify how much a question reduces it.

Let’s work on impurity first. This is a metric that ranges between 0 and 1, where lower values indicate less uncertainty, or mixing, at a node. It quantifies our chance of being incorrect if we randomly assign a label from a set to an example in that set. Here’s an example to make that clear. Imagine we have two bowls: one contains the examples, and the other contains the labels. First, we’ll randomly draw an example from the first bowl. Then we’ll randomly draw a label from the second and classify the example as having that label. Gini impurity gives us our chance of being incorrect. In this example, we have only apples in each bowl, so there’s no way to make a mistake, and we say the impurity is zero. On the other hand, a bowl with five different types of fruit in equal proportion has an impurity of 0.8. That’s because we have a one-in-five chance of being right if we randomly assign a label to an example.

In code, this function calculates the impurity of a data set, and I’ve written a couple of examples below that demonstrate how it works. You can see the impurity for the first set is zero because there’s no mixing, and for the second set the impurity is 0.8.

Now information gain will let us
find the question that reduces our uncertainty the most. It’s just a number that describes how much a question helps to unmix the labels at a node. Here’s the idea. We begin by calculating the uncertainty of our starting set. Then, for each question we can ask, we’ll try partitioning the data and calculating the uncertainty of the child nodes that result. We’ll take an average of their uncertainty, weighted by how many rows each child receives, because we care more about a large set with low uncertainty than a small set with high uncertainty. Then we’ll subtract this weighted average from our starting uncertainty, and that’s our information gain. As we go, we’ll keep track of the question that produces the most gain, and that will be the best one to ask at this node.

Let’s see how this looks in code. Here, we iterate over every value for every feature. We generate a question for that feature and value, then partition the data on it. Notice that we discard any questions that fail to produce a split. Then we calculate our information gain. Inside this function, you can see we take a weighted average of the impurity of each set and see how much this reduces the uncertainty from our starting set. We keep track of the best value as we go. I’ve written a couple of demos below as well.

OK, with these concepts in hand,
we’re ready to build the tree. To put this all together, I think the most useful thing I can do is walk you through the algorithm as it builds a tree for our training data. This uses recursion, so seeing it in action can be helpful. You can find the code for this in the build_tree function.

When we call build_tree for the first time, it receives the entire training set as input, and as output it returns a reference to the root node of our tree. I’ll draw a placeholder for the root here in gray. Here are the rows we’re considering at this node; to start, that’s the entire training set. Now we find the best question to ask at this node. We do that by iterating over each of these values, splitting the data, and calculating the information gain for each one. As we go, we keep track of the question that produces the most gain. In this case, there’s a useful question to ask, so the gain will be greater than zero, and we split the data using that question.

Now we use recursion, calling build_tree again to add a node for the true branch. The rows we’re considering now are the ones on the true side of the split. Again, we find the best question to ask for this data. Once more we split and call build_tree to add a child node. For this data, there are no further questions to ask, so the information gain will be zero and this node becomes a leaf. It will predict that an example is either an apple or a lemon with 50% confidence, because that’s the ratio of the labels in the data. Next we build the false branch, and this also becomes a leaf: it predicts apple with 100% confidence. The previous call then returns, and this node becomes a decision node. In code, that just means it holds a reference to the question we asked and the two child nodes that result.

We’re nearly done. Now we return to the root node and build the false branch. There are no further questions to ask, so this becomes a leaf, and it predicts grape with 100% confidence. Finally, the root node also becomes a decision node, and our call to build_tree returns a reference to it.

If you scroll down in the code, you’ll see that I’ve added functions to classify data and print the tree. These start with a reference to the root node, so you can see how they work. OK, I hope that was helpful. You can check out the
code for more details. There’s a lot more I have to say about decision trees, but there’s only so much we can fit into a short time. Here are a couple of topics that are good to be aware of, and you can check out the books in the description to learn more.

As a next step, I’d recommend modifying the tree to work with your own data set. This can be a fun way to build a simple and interpretable classifier for use in your projects. Thanks for watching, everyone, and I’ll see you next time.
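
The walkthrough above can be condensed into one pure-Python sketch. This is a minimal illustration, not the notebook’s actual code. The five training rows are an assumed reconstruction that fits the description: two features (color and diameter), a label column, and second and fifth rows that share features but carry different labels. The names Question, partition, gini, info_gain, find_best_split, Leaf, DecisionNode, and classify are illustrative, and so is build_tree. Numeric values are tested with >= and all other values with ==, matching the "is the color green" and "is the diameter >= 3" examples.

```python
from collections import Counter

# Assumed toy data: [color, diameter, label]. Rows 2 and 5 share features
# but differ in label, so the data is not perfectly separable.
header = ["color", "diameter", "label"]
training_data = [
    ["Green", 3, "Apple"],
    ["Yellow", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
]


class Question:
    """True/false test: column >= value for numbers, column == value otherwise."""
    def __init__(self, column, value):
        self.column, self.value = column, value

    def match(self, row):
        val = row[self.column]
        if isinstance(val, (int, float)):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        op = ">=" if isinstance(self.value, (int, float)) else "=="
        return f"Is {header[self.column]} {op} {self.value}?"


def partition(rows, question):
    """Split rows into (rows where the question is true, everything else)."""
    true_rows = [row for row in rows if question.match(row)]
    false_rows = [row for row in rows if not question.match(row)]
    return true_rows, false_rows


def gini(rows):
    """Gini impurity of the labels: 1 minus the sum of squared label proportions."""
    counts = Counter(row[-1] for row in rows)
    return 1.0 - sum((n / len(rows)) ** 2 for n in counts.values())


def info_gain(true_rows, false_rows, current_uncertainty):
    """Starting impurity minus the size-weighted average impurity of the children."""
    total = len(true_rows) + len(false_rows)
    weighted = (len(true_rows) * gini(true_rows) + len(false_rows) * gini(false_rows)) / total
    return current_uncertainty - weighted


def find_best_split(rows):
    """Try every value of every feature as a question; return (best gain, question)."""
    best_gain, best_question = 0.0, None
    current = gini(rows)
    for col in range(len(rows[0]) - 1):  # the last column is the label
        for value in sorted({row[col] for row in rows}):  # assumes one type per column
            question = Question(col, value)
            true_rows, false_rows = partition(rows, question)
            if not true_rows or not false_rows:
                continue  # discard questions that fail to split the data
            gain = info_gain(true_rows, false_rows, current)
            if gain >= best_gain:  # on ties, the question examined later wins
                best_gain, best_question = gain, question
    return best_gain, best_question


class Leaf:
    """Stores label counts; their ratio is the prediction's confidence."""
    def __init__(self, rows):
        self.predictions = Counter(row[-1] for row in rows)


class DecisionNode:
    """Stores the question asked and the two child nodes that result."""
    def __init__(self, question, true_branch, false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch


def build_tree(rows):
    gain, question = find_best_split(rows)
    if gain == 0:
        return Leaf(rows)  # no further useful questions: add a leaf
    true_rows, false_rows = partition(rows, question)
    return DecisionNode(question, build_tree(true_rows), build_tree(false_rows))


def classify(row, node):
    """Follow the questions from the root to a leaf and return its label counts."""
    while isinstance(node, DecisionNode):
        node = node.true_branch if node.question.match(row) else node.false_branch
    return dict(node.predictions)


my_tree = build_tree(training_data)
print(my_tree.question)                  # Is diameter >= 3?
print(classify(["Yellow", 3], my_tree))  # {'Apple': 1, 'Lemon': 1}
```

On these assumed rows, the question "Is color == Red?" gives exactly the same split, and so the same gain, as "Is diameter >= 3?". Because find_best_split compares with >=, it keeps whichever tied question it examines last. The yellow, diameter-3 example ends up in the mixed leaf, which predicts apple or lemon with 50% confidence each, as described in the walkthrough.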

100 thoughts on “Let’s Write a Decision Tree Classifier from Scratch – Machine Learning Recipes #8”

  1. Why is impurity calculated one way at 5:33 and differently in the code? (1-(times the # of possible labels) vs 1-(# of possible labels)**2)?

  2. Took nearly one year for this..how abt next video Josh ?? Should we wait till 2018??. Very good contents..thank you very much..

  3. Typo in line 29 of decision_tree.ipynb:
    best_question = None # keep train of the feature / value that produced it
    train = track *

  4. Started to watch the series 2 days ago, you are explaining SO well. Many thanks!

    More videos on additional types of problems we can solve with Machine Learning would be very helpful. Few ideas: traveling salesman problem, generating photos while emulating analog artefacts or simple ranking of new dishes I would like to try based on my restaurants' order history. Even answering with the relevant links/terminology would be fantastic.

    Also, would be great to know what problems are still hard to solve or should not be solved via Machine Learning 🙂

  5. Hey, I am testing your classifier on my own dataset by changing header and setting training_data accordingly:
    header = ["location","w","final_margin","shot_number","period","game_clock","shot_clock","dribbles","touch_time",
    df = pd.read_csv('data/basketball.train.csv', header=None, names=header, na_values="?")


    while i<len(obj_df):

    training_data has the same format as in your code, but still the program keeps hanging up on the line, val = example[self.column] in the method match(self, example). Any idea why? (I also changed out test_data, but the program stops before starting the predictions).

  6. I have a doubt: at 6:20, how does the impurity become 0.64? Also, the impurity for the false condition gives 0.62. Please help me.

  7. Thanks a lot, Josh. To a very basic beginner, every sentence you say is a gem. It took me half an hour to get the full meaning of the first 4 mins of the video, as I was taking notes and repeating it to myself to grasp everything that was being said.

    The reason I wanted to showcase my slow pace is to say how important and understandable I felt in regard to every sentence.

    And, it wasn't boring at all.

    Great job, and please, keep em coming.

  8. Making classifiers is easy and fun. But what if I want a machine that learns how to play board games? What can we do?

  9. You have no idea how your videos helped me out on my journey on Machine Learning. thanks a lot Josh you are awesome.

  10. You have saved weeks of work. So short yet so deep. Guys, first try to understand the code, then watch the video.

  11. Awesome video, helped me lot…. Was struggling to understand these exact stuffs…..Looking forward to the continuing courses.

  12. I don't get one thing here. How do we determine the number for the question? Like, I understand that we try out different features to see which gives us the most info, but how do we choose the number and condition for it?

  13. Help! I can't print my_tree using the print_tree() function. It only shows this: Predict {'Good': 47, 'Bad': 150, 'Average': 89}. Please help…

  14. How do I choose the k value in kNN?
    Based on accuracy, or the square root of the length of the test data?
    Can anyone help me?

  15. I think you might be confusing information gain and the Gini index. Information gain is the reduction of entropy, not the reduction of Gini impurity. I almost made a mistake in my engineering paper because of this video, but luckily I noticed a different definition of information gain in another source. Maybe it's just a naming thing, but it can mislead people who are new to this subject :/

  16. This is the best single resource on decision trees that I've found, and it's a topic that isn't covered enough considering that random forests are a very powerful and easy tool to implement. If only they released more tutorials!

  17. I have a follow-up question. How did we come up with the questions? As in, how did we know we would like to ask if the diameter is > 3, why not ask if diameter > 2?

  18. This is the best tutorial on the net but this uses CART. I was really hoping to use C5.0 but unfortunately the package is only available in R. I used rpy2 to call the C50 function in Python. It would be great if there'd be a tutorial on that.

  19. Even though it took me more than 30 minutes to complete & understand the video, I can't tell you how amazing this explanation is!

    This is how we calculate the impurity !
    PS: G(k) = Σ P(i) * (1 – P(i))
    i = (Apple, Grape,Lemon)
    2/5 * (1- 2/5) + 2/5 * (1- 2/5) + 1/5 *(1-1/5)=
    0.4 * (0.6) + 0.4 * (0.6) + 0.2 * (0.8)=
    0.24 + 0.24 + 0.16 = 0.64

  20. amazing video!!! Thank you so much for the great lecture and showing the python code to make us understand the algorithm better!

  21. How come at 6:20 he calls it an average but doesn't divide it by 2? Also, in a Stack Overflow question about the same thing, it seems to be called entropy afterward. Is this correct?

  22. Question about calculating impurity. If we do probability, we first draw data which give us probability of 0.2 then we draw label which give us another 0.2. Shouldn't the impurity be 1 – 0.2*0.2=0.96?

  23. Could you make a similar video on fuzzy decision tree classifiers or share a good source for studying and implementing them?

  24. It is easy to find the best split if the data is categorical. How does the split happen in a time-optimized way if the variable is continuous, unlike color or the two values of diameter? Should I just run through the min to max values? Can the median be used here? Please suggest!!
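
Several comments above (1, 6, 19, 21, and 22) ask about the impurity arithmetic. The short check below is a sketch, not the notebook’s code. It assumes the label counts from comment 19 (two apples, two grapes, one lemon), and it uses a question about whether the color is green as an assumed example split on those labels. Three points follow from it. First, Σ p(1 − p) and 1 − Σ p² are the same quantity: expand p − p² and use Σ p = 1. Second, drawing an example and a label independently gives P(correct) = Σ p², which is 0.2 for five equally common fruits, not 0.2 × 0.2. Third, each child is weighted by its share of the rows, so the weights already sum to 1 and nothing is divided by 2. On naming, as comment 15 notes, many texts define information gain as the decrease in entropy; the video applies the same idea to Gini impurity instead.

```python
from collections import Counter


def gini_sum_p_times_1_minus_p(labels):
    """Gini impurity as sum_i p_i * (1 - p_i), the form written in comment 19."""
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())


def gini_one_minus_sum_p_squared(labels):
    """Gini impurity as 1 - sum_i p_i ** 2; expanding the form above gives this."""
    n = len(labels)
    return 1 - sum((c / n) ** 2 for c in Counter(labels).values())


# Labels assumed from comment 19: 2 apples, 2 grapes, 1 lemon.
labels = ["Apple", "Apple", "Grape", "Grape", "Lemon"]
print(gini_sum_p_times_1_minus_p(labels), gini_one_minus_sum_p_squared(labels))  # both about 0.64

# Drawing an example and a label independently: P(correct) = sum_i p_i * p_i.
# With five equally common fruits that is 5 * 0.2 * 0.2 = 0.2, so impurity is 0.8.
five_fruits = ["Apple", "Grape", "Lemon", "Banana", "Cherry"]  # hypothetical labels
print(gini_one_minus_sum_p_squared(five_fruits))  # about 0.8 (floating point)

# Weighted average for one assumed split (is the color green?): each child is
# weighted by its share of the rows, so the weights sum to 1 (no division by 2).
gini = gini_one_minus_sum_p_squared
true_side = ["Apple"]
false_side = ["Apple", "Grape", "Grape", "Lemon"]
n = len(labels)
w_true, w_false = len(true_side) / n, len(false_side) / n  # 0.2 and 0.8
weighted = w_true * gini(true_side) + w_false * gini(false_side)
gain = gini(labels) - weighted
print(round(weighted, 4), round(gain, 4))  # 0.5 0.14
```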

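
Comments 12 and 17 ask where a threshold such as 3 comes from. As the video describes, every distinct value observed in a column becomes a candidate threshold. The sketch below assumes five toy rows that fit the video’s description; they are not copied from the notebook. It uses a hypothetical helper, gain_for_threshold, that scores the question row[col] >= threshold by its size-weighted decrease in Gini impurity. On these rows a threshold like 2 never needs to be tried, because it splits the rows exactly as 3 does.

```python
from collections import Counter

# Assumed toy rows of [color, diameter, label]; only diameter (column 1) is used.
rows = [
    ["Green", 3, "Apple"],
    ["Yellow", 3, "Apple"],
    ["Red", 1, "Grape"],
    ["Red", 1, "Grape"],
    ["Yellow", 3, "Lemon"],
]


def gini(labels):
    """1 minus the sum of squared label proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def gain_for_threshold(rows, col, threshold):
    """Gini decrease for the question row[col] >= threshold, or None if no split."""
    true_labels = [row[-1] for row in rows if row[col] >= threshold]
    false_labels = [row[-1] for row in rows if row[col] < threshold]
    if not true_labels or not false_labels:
        return None  # the question fails to split the data, so it is discarded
    weighted = (len(true_labels) * gini(true_labels)
                + len(false_labels) * gini(false_labels)) / len(rows)
    return gini([row[-1] for row in rows]) - weighted


# Candidate thresholds are the distinct values that appear in the column.
for threshold in sorted({row[1] for row in rows}):
    print(f"diameter >= {threshold}: gain = {gain_for_threshold(rows, 1, threshold)}")

# An unobserved threshold like 2 gives the same partition, hence the same gain, as 3.
print(gain_for_threshold(rows, 1, 2) == gain_for_threshold(rows, 1, 3))  # True
```

The threshold 1 is discarded because every row has a diameter of at least 1. The threshold 3 separates the grapes from everything else.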