Building a Decision Tree from Scratch: A Python Odyssey
A decision tree is a flowchart-like structure where each internal node represents a "test" on an attribute (e.g., whether a coin flip comes up heads or tails), each branch represents the outcome of the test, and each leaf node represents a class label (decision taken after computing all attributes). The paths from root to leaf represent classification rules.
Here are the basic steps involved in the decision tree algorithm:
1. Selection of the attribute to split:
Choose an attribute from the dataset.
Calculate the significance of the attribute in splitting the data.
The attribute with the best score is selected to make the decision at the node, for example the one with the highest information gain or the lowest weighted Gini index.
2. Splitting:
Divide the dataset into subsets that correspond to different values of the selected attribute.
3. Recursion:
For each subset, repeat steps 1 and 2 until a stopping condition is met, such as reaching a maximum depth, falling below a minimum node size, or all samples in the subset sharing the same value of the target variable.
4. Terminal Node Assignment:
Once the stopping conditions are met, the subset of data at that node is assigned to a class label, turning it into a terminal/leaf node.
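Steps 2 and 3 are where most of the implementation work happens, as the sketch below illustrates. It is a minimal, self-contained look at step 2 only: splitting the toy dataset used later in this post on its first attribute with a boolean mask. Note that because NumPy stores the mixed numeric/string array as strings, the comparison is against '1' rather than 1.

import numpy as np

# Minimal illustration of step 2: split the rows on one attribute
data = np.array([
    [1, 1, 'Yes'],
    [1, 0, 'No'],
    [0, 1, 'No'],
    [0, 0, 'No']
])
mask = data[:, 0] == '1'            # NumPy casts the mixed array to strings
left, right = data[mask], data[~mask]
print(left)   # rows where Feature1 == 1
print(right)  # rows where Feature1 != 1

Step 3 then applies the same splitting logic recursively to left and right, which is exactly what the build_tree function later in this post does.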
The decision tree is a versatile algorithm used for both classification and regression tasks. In this blog, we will demystify the construction of decision trees, specifically using the Gini index, and illustrate it with simple Python code and a toy example.
Understanding the Terminology
Before diving into the code, let’s clarify some terminologies related to decision trees:
Nodes: Points where the tree splits based on a certain condition.
Root Node: The node at which the tree starts, encompassing all the data points.
Leaf Nodes (or Terminal Nodes): Nodes at which the tree stops growing, representing the final predictions.
Depth: The length of the longest path from the root node to a leaf node.
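To make these terms concrete, here is one way to picture a small tree of depth 2 as a nested Python dictionary; the feature names and values are purely illustrative, but the shape mirrors the dictionaries returned by the code later in this post.

tree = {
    'feature': 'Feature1', 'value': 0,   # root node (depth 0)
    'left': {'class': 'No'},             # leaf node (depth 1)
    'right': {                           # internal node (depth 1)
        'feature': 'Feature2', 'value': 0,
        'left': {'class': 'No'},         # leaf node (depth 2)
        'right': {'class': 'Yes'}        # leaf node (depth 2)
    }
}

Every dictionary with a 'feature' key is a node that tests an attribute, and every dictionary with only a 'class' key is a leaf holding the final prediction.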
Gini Index: The Splitting Criterion
Decision trees employ various metrics to decide the optimal split at each node. One common metric is the Gini index, which measures the impurity of the data. The formula for the Gini index of a dataset D is:

Gini(D) = 1 − Σᵢ pᵢ²

where pᵢ is the proportion of samples belonging to class i in the dataset D.
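As a quick worked example, the toy dataset used later in this post contains three 'No' samples and one 'Yes' sample, so its Gini index is 1 − ((3/4)² + (1/4)²) = 1 − 0.625 = 0.375; a perfectly pure node would score 0.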
The Gini index for a split is calculated as the weighted sum of the Gini index of each subset created by the split:

Gini_split(D) = Σᵥ (|Dᵥ| / |D|) · Gini(Dᵥ)

where |D| is the total number of samples, |Dᵥ| is the number of samples in subset v, and the sum runs over all subsets created by the split.
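Continuing the worked example, splitting the toy dataset on Feature1 produces one subset {[1, 1, 'Yes'], [1, 0, 'No']} with a Gini index of 0.5 and another {[0, 1, 'No'], [0, 0, 'No']} with a Gini index of 0, giving a weighted Gini of (2/4)·0.5 + (2/4)·0 = 0.25, an improvement over the unsplit value of 0.375.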
Implementing Decision Tree in Python
Now, let's write a simple Python code to build a decision tree. We will use a toy dataset to illustrate the process:
import numpy as np

# Toy dataset
data = np.array([
    [1, 1, 'Yes'],
    [1, 0, 'No'],
    [0, 1, 'No'],
    [0, 0, 'No']
])
features = ['Feature1', 'Feature2']
# Function to calculate Gini index
def gini_index(dataset):
    m = len(dataset)
    if m <= 1:
        return 0
    class_counts = np.unique(dataset[:, -1], return_counts=True)[1]
    p = class_counts / m
    gini = 1 - np.sum(p**2)
    return gini
# Recursive function to build the tree
def build_tree(data, features, depth=0, max_depth=3):
    # Base cases: pure node or maximum depth reached
    classes, counts = np.unique(data[:, -1], return_counts=True)
    if len(classes) == 1 or depth == max_depth:
        # Assign the majority class at this (leaf) node
        return {'class': classes[np.argmax(counts)]}

    # Finding the best split
    m, n = data.shape
    if m <= 1:
        return {'class': data[0, -1]}
    num_features = n - 1
    best_gini = np.inf  # Initialize to infinity
    best_feature_idx, best_value = None, None
    for f in range(num_features):
        for value in np.unique(data[:, f]):
            data_left = data[data[:, f] == value]
            data_right = data[data[:, f] != value]
            if len(data_left) == 0 or len(data_right) == 0:
                continue  # Skip splits that leave one side empty
            # Weighted Gini index of the two subsets
            gini_current_split = (len(data_left) * gini_index(data_left)
                                  + len(data_right) * gini_index(data_right)) / m
            if gini_current_split < best_gini:
                best_gini = gini_current_split
                best_feature_idx = f
                best_value = value

    # No useful split found: fall back to a leaf with the majority class
    if best_feature_idx is None:
        return {'class': classes[np.argmax(counts)]}

    # Build subtrees
    data_left = data[data[:, best_feature_idx] == best_value]
    data_right = data[data[:, best_feature_idx] != best_value]
    subtree = {
        'feature': features[best_feature_idx],
        'value': best_value,
        'left': build_tree(data_left, features, depth + 1, max_depth),
        'right': build_tree(data_right, features, depth + 1, max_depth)
    }
    return subtree

# Building the tree
tree = build_tree(data, features)
In the above code, we first define a function gini_index to calculate the Gini index of a given dataset. We then define a recursive function build_tree to build the decision tree. This function takes the data, the list of feature names, the current depth, and the maximum depth as arguments. It evaluates every candidate split with the Gini index, keeps the best one, and recursively builds the left and right subtrees; once a node is pure or the maximum depth is reached, it assigns the majority class and becomes a leaf.
Our toy dataset comprises four data points with two features. By running the build_tree function on this dataset, we will obtain a simple decision tree that demonstrates the essence of decision tree learning.
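To sanity-check the result, you can walk the nested dictionary and classify a sample. The predict helper below is not part of the listing above, just one simple way to traverse the structure returned by build_tree; recall that the toy array stores its values as strings, hence the str conversions.

def predict(tree, sample, features):
    # Descend until we reach a leaf, i.e. a dict holding only a class label
    node = tree
    while 'class' not in node:
        f_idx = features.index(node['feature'])
        if str(sample[f_idx]) == str(node['value']):
            node = node['left']
        else:
            node = node['right']
    return node['class']

print(predict(tree, [1, 1], features))  # expected: Yes
print(predict(tree, [0, 1], features))  # expected: No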
Wrapping Up
Constructing a decision tree from scratch is a fantastic way to grasp the inner workings of this fundamental machine learning algorithm. By understanding the core concepts and coding them out, you are well on your way to mastering decision tree-based models like Random Forests and Gradient Boosted Trees. Happy coding!