Decision Trees¶
Overview¶
For each step, read the explanation, then run the code cell(s) right below it.
You will practice how to:
- Load and inspect data for a classification problem
- Visualize features to build intuition about possible splits
- Compute Gini and Entropy (impurity) for a candidate split
- Train and interpret a shallow decision tree vs. a fully grown tree
Import libraries¶
import os
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from dmba import plotDecisionTree
%matplotlib inline
# Set random seed variable for code reproducibility
SEED = 0
# Import local libraries
root_dir = Path.cwd().resolve().parents[0]
sys.path.append(str(root_dir))
# Visualization functions
from src.utils.helpers import *
# Load the "autoreload" extension so that code changes in src are picked up
%load_ext autoreload
# Always reload modules so that as you change code in src, it gets reloaded
%autoreload 2
Riding Mowers Example¶
Create a dataframe for the RidingMowers.csv data
In the next cell, we load the dataset from a .csv file into a pandas DataFrame so we can explore it and model it.
This is the data for the Riding Mowers Example in the lecture.
mowers_df = pd.read_csv(os.path.join('..','data','RidingMowers.csv'))
mowers_df.head()
| | Income | Lot_Size | Ownership |
|---|---|---|---|
| 0 | 60.0 | 18.4 | Owner |
| 1 | 85.5 | 16.8 | Owner |
| 2 | 64.8 | 21.6 | Owner |
| 3 | 61.5 | 20.8 | Owner |
| 4 | 87.0 | 23.6 | Owner |
Create a scatterplot for Income and Lot Size with Ownership as the Color
Next, we visualize the relationship between Income and Lot Size and use color to show the class label (Ownership). This helps us see whether a simple split might separate the classes.
This is the scatter plot used in lecture to visualize the splits.
sns.scatterplot(x='Income', y='Lot_Size', hue='Ownership', data=mowers_df)
plt.show()
Calculate Gini Index for First Split Condition
Now we try a candidate split (a threshold on Income). We separate the data into the left/right child nodes and compute impurity for each side.
This is the code used for the first split using Gini index.
split_value = 59.7
split_condition = mowers_df['Income'] <= split_value
split_true = list(mowers_df['Ownership'][split_condition])
split_false = list(mowers_df['Ownership'][~split_condition])
print(f"Left Split: Income <= {split_value}, Gini Index = {gini_index(split_true)[1]:.3f}")
print(f"Right Split: Income > {split_value}, Gini Index = {gini_index(split_false)[1]:.3f}")
Left Split: Income <= 59.7, Gini Index = 0.219
Right Split: Income > 59.7, Gini Index = 0.430
Calculate Entropy for First Split Condition
Here we compute entropy for the left and right child nodes created by the Income split, and print the results.
This is the code used for the first split using Entropy measure.
print(f"Left Split: Income <= {split_value}, Entropy = {entropy_loss(split_true)[1]:.3f}")
print(f"Right Split: Income > {split_value}, Entropy = {entropy_loss(split_false)[1]:.3f}")
Left Split: Income <= 59.7, Entropy = 0.544
Right Split: Income > 59.7, Entropy = 0.896
Calculate Overall Gini and Entropy for First Split Condition
Finally, we compute the overall (weighted) impurity for this split by weighting each child node’s impurity by its share of samples.
This is the code used to calculate the combined impurity of the two nodes for both Gini and Entropy.
print(f"Combined Gini for Income Split on {split_value} is {weighted_impurity(split_true, split_false, mowers_df['Income'], 'gini'):.3f}")
print(f"Combined Entropy for Income Split on {split_value} is {weighted_impurity(split_true, split_false, mowers_df['Income'], 'entropy'):.3f}")
Combined Gini for Income Split on 59.7 is 0.359
Combined Entropy for Income Split on 59.7 is 0.779
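The helper functions gini_index(), entropy_loss(), and weighted_impurity() come from src.utils.helpers and are not shown here. A minimal, self-contained sketch of the underlying formulas (a hypothetical re-implementation, not the actual helpers) could look like this:

```python
from collections import Counter
from math import log2

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def entropy(labels):
    """Entropy: -sum of p * log2(p) over the class proportions."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def combined_impurity(left, right, measure):
    """Impurity of a split: each child's impurity weighted by its share of samples."""
    n = len(left) + len(right)
    return (len(left) / n) * measure(left) + (len(right) / n) * measure(right)

# A 50/50 node has maximal impurity for two classes:
print(gini(['Owner', 'Owner', 'Nonowner', 'Nonowner']))     # 0.5
print(entropy(['Owner', 'Owner', 'Nonowner', 'Nonowner']))  # 1.0
# A split that separates the classes perfectly has zero combined impurity:
print(combined_impurity(['Owner'] * 3, ['Nonowner'] * 5, gini))  # 0.0
```

Note that a pure node contributes 0 under either measure, and the two children are weighted by sample count, which is exactly what the combined values printed above reflect.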
Create Decision Tree with a Depth of 2
Use Gini as the splitting criterion and display impurity values on the nodes.
Next, we fit a small decision tree with max_depth=2 (a shallow, interpretable tree) using Gini as the split criterion, then visualize it.
This is the code used to create the tree after 3 splits.
mowers_X = mowers_df.drop(columns=['Ownership'])
mowers_y = mowers_df['Ownership']
dt = DecisionTreeClassifier(max_depth=2, criterion='gini', random_state=SEED)
dt.fit(mowers_X, mowers_y)
fig, ax = plt.subplots(figsize=(7, 5))
plot_tree(dt,
          feature_names=mowers_X.columns,
          class_names=['Nonowner', 'Owner'],
          filled=True,
          impurity=True,
          ax=ax)
plt.tight_layout()
plt.show()
Create Full Decision Tree
Use Gini as the splitting criterion and display impurity values on the nodes.
Lastly, we fit a fully grown decision tree (no max depth) to see how the model continues splitting when not constrained.
This is the code used to create the tree after all splits.
dt2 = DecisionTreeClassifier(max_depth=None, criterion='gini', random_state=SEED)
dt2.fit(mowers_X, mowers_y)
fig, ax = plt.subplots(figsize=(7, 5))
plot_tree(dt2,
          feature_names=mowers_X.columns,
          class_names=['Nonowner', 'Owner'],
          filled=True,
          impurity=True,
          ax=ax)
plt.tight_layout()
plt.show()
Key DecisionTreeClassifier Parameters¶
Below we will discuss some of the key parameters for the DecisionTreeClassifier that control how the tree grows, how complex it becomes, and how splits are chosen. Check out the API Reference for the full list of parameters.
criterion
Determines how the quality of a split is measured.
- "gini" → Uses Gini impurity (default)
- "entropy" → Uses information gain
- "log_loss" → Similar to entropy but based on logistic loss
This controls how "purity" is defined.
splitter
Strategy used to choose splits.
- "best" → Chooses the best split (default)
- "random" → Chooses the best random split

"random" can add variability and sometimes reduce overfitting.
max_depth
Maximum depth of the tree.
- None → Grow until pure or minimum samples reached
- Integer (e.g., max_depth=5)
Limits tree complexity: smaller depth = simpler model.
min_samples_split
Minimum samples required to split a node.
- Integer → exact number (e.g., 10)
- Float → fraction of dataset (e.g., 0.05)
Larger value = fewer splits and can reduce overfitting.
min_samples_leaf
Minimum samples required in a leaf node.
Prevents leaves with very few observations.
max_leaf_nodes
Limits total number of leaf nodes.
Controls tree size directly.
max_features
Number of features considered at each split.
- None → All features
- "sqrt" → √(#features)
- "log2" → log₂(#features)
- Integer / Float
Adds randomness and is useful in ensembles such as Random Forest.
random_state
Ensures reproducibility.
Very important to always fix this value for consistent results.
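To see how these parameters shape the fitted model, the sketch below compares an unconstrained tree against one limited by max_depth and min_samples_leaf. It uses a synthetic dataset from make_classification (an assumption for illustration; any tabular classification data would do):

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data stands in for a real dataset here
X, y = make_classification(n_samples=500, n_features=6, random_state=0)

# Unconstrained: grows until every leaf is pure (or cannot be split further)
full = DecisionTreeClassifier(random_state=0).fit(X, y)

# Constrained: depth and leaf-size limits yield a much smaller tree
small = DecisionTreeClassifier(max_depth=3, min_samples_leaf=20,
                               random_state=0).fit(X, y)

print("full tree: ", full.get_depth(), "levels,", full.tree_.node_count, "nodes")
print("small tree:", small.get_depth(), "levels,", small.tree_.node_count, "nodes")
```

The constrained tree is guaranteed to have at most 3 levels, while the unconstrained tree keeps splitting on noise, which is the usual path to overfitting.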
Personal Loan Example¶
Create a dataframe for the UniversalBank.csv data
In the next cell, we load the dataset from a .csv file into a pandas DataFrame so we can explore it and model it.
This is the data for the Personal Loan Example in lecture.
bank_df = pd.read_csv(os.path.join('..','data','UniversalBank.csv'))
bank_df.head()
| | ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 25 | 1 | 49 | 91107 | 4 | 1.6 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
| 1 | 2 | 45 | 19 | 34 | 90089 | 3 | 1.5 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
| 2 | 3 | 39 | 15 | 11 | 94720 | 1 | 1.0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 4 | 35 | 9 | 100 | 94112 | 1 | 2.7 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 5 | 35 | 8 | 45 | 91330 | 4 | 1.0 | 2 | 0 | 0 | 0 | 0 | 0 | 1 |
Create crosstab summaries of the Securities Account, CD Account, CreditCard variables.
We start by building contingency tables (crosstabs) between each binary predictor and the target (Personal Loan). This lets us see how informative each split might be.
This is the code used to create the tables for the calculations for the first split.
pd.crosstab(bank_df['Securities Account'], bank_df['Personal Loan'])
| Personal Loan | 0 | 1 |
|---|---|---|
| Securities Account | ||
| 0 | 4058 | 420 |
| 1 | 462 | 60 |
pd.crosstab(bank_df['CD Account'], bank_df['Personal Loan'])
| Personal Loan | 0 | 1 |
|---|---|---|
| CD Account | ||
| 0 | 4358 | 340 |
| 1 | 162 | 140 |
pd.crosstab(bank_df['CreditCard'], bank_df['Personal Loan'])
| Personal Loan | 0 | 1 |
|---|---|---|
| CreditCard | ||
| 0 | 3193 | 337 |
| 1 | 1327 | 143 |
Calculate overall impurity for Securities Account, CD Account, CreditCard variables.
Now we try a candidate split for each of the 3 binary variables. We separate the data into the left/right child nodes and compute impurity for each side.
This calculates the combined impurity for each of the 3 variables to evaluate which variable to use for the first split.
vs = ['Securities Account', 'CD Account', 'CreditCard']
for var in vs:
    split_condition = bank_df[var] == 0
    split_true = list(bank_df['Personal Loan'][split_condition])
    split_false = list(bank_df['Personal Loan'][~split_condition])
    sa = weighted_impurity(split_true, split_false, bank_df[var], 'gini')
    print(f"Combined Gini for {var} Split is {sa:.4f}")
Combined Gini for Securities Account Split is 0.1735
Combined Gini for CD Account Split is 0.1562
Combined Gini for CreditCard Split is 0.1736
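As a sanity check, these combined Gini values can be re-derived by hand straight from the crosstab counts shown earlier, weighting each child node's Gini by its share of the 5,000 customers:

```python
def weighted_gini(counts_left, counts_right):
    """Weighted Gini from the (loan=0, loan=1) counts of the two child nodes."""
    def gini(counts):
        n = sum(counts)
        return 1.0 - sum((c / n) ** 2 for c in counts)
    n_left, n_right = sum(counts_left), sum(counts_right)
    n = n_left + n_right
    return (n_left / n) * gini(counts_left) + (n_right / n) * gini(counts_right)

# Counts taken from the crosstabs above: (loan=0, loan=1) for var=0 and var=1
print(f"Securities Account: {weighted_gini((4058, 420), (462, 60)):.4f}")   # 0.1735
print(f"CD Account:         {weighted_gini((4358, 340), (162, 140)):.4f}")  # 0.1562
print(f"CreditCard:         {weighted_gini((3193, 337), (1327, 143)):.4f}") # 0.1736
```

CD Account wins because its var=1 node (162 vs. 140) isolates a small, loan-rich group, pulling the weighted impurity down the most. This matches the tree's choice of CD Account for the first split.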
Create Decision Tree with a Depth of 2 with only the Securities Account, CD Account, CreditCard variables.
Use Gini as the splitting criterion and display impurity values on the nodes.
Next, we fit a small decision tree with max_depth=2 (a shallow, interpretable tree) using Gini as the split criterion, then visualize it.
This is the code used to create the final tree for the Personal Loan Example.
bank_X = bank_df[vs]
bank_y = bank_df['Personal Loan']
dt3 = DecisionTreeClassifier(max_depth=2, criterion='gini', random_state=SEED)
dt3.fit(bank_X, bank_y)
fig, ax = plt.subplots(figsize=(7, 5))
plot_tree(dt3,
          feature_names=bank_X.columns,
          class_names=['Declined', 'Accepted'],
          filled=True,
          impurity=True,
          ax=ax)
plt.tight_layout()
plt.show()
Create a Text Version of Tree¶
The export_text() function converts the trained tree into a readable rule-based format.
Instead of a visual diagram, this shows:
- The sequence of splits
- The decision rules
- The predicted class at each leaf
- (Optional) Sample counts / weights
Typically in a Jupyter Notebook you don't need the print() function, but this is a good example where it preserves the formatting of the text output.
print(export_text(dt3, feature_names=bank_X.columns, show_weights=True))
|--- CD Account <= 0.50
|   |--- CreditCard <= 0.50
|   |   |--- weights: [3178.00, 290.00] class: 0
|   |--- CreditCard >  0.50
|   |   |--- weights: [1180.00, 50.00] class: 0
|--- CD Account >  0.50
|   |--- CreditCard <= 0.50
|   |   |--- weights: [15.00, 47.00] class: 1
|   |--- CreditCard >  0.50
|   |   |--- weights: [147.00, 93.00] class: 0
Create a fully developed Decision Tree with all variables except for ID and Zip Code.
Use Gini as the splitting criterion and use the dmba plotDecisionTree() function to display the tree.
Finally, we build a fuller model using all available predictors (excluding identifier fields) and plot the resulting tree.
This example was not shown in the lecture but is included as a way to effectively visualize a very complex tree.
bank_X = bank_df.drop(columns=['ID', 'ZIP Code', 'Personal Loan'])
bank_y = bank_df['Personal Loan']
fullClassTree = DecisionTreeClassifier(random_state=SEED).fit(bank_X, bank_y)
plotDecisionTree(fullClassTree, feature_names=bank_X.columns)
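A fully grown tree like fullClassTree typically memorizes its training data, which is why depth limits or pruning are used in practice. A minimal sketch of the overfitting pattern, on synthetic data since UniversalBank.csv isn't bundled here (the dataset shape is an assumption for illustration):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

full = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_tr, y_tr)

# The unconstrained tree scores perfectly on training data; the gap to its
# held-out score is the signature of overfitting
print(f"full:    train={full.score(X_tr, y_tr):.3f}, test={full.score(X_te, y_te):.3f}")
print(f"shallow: train={shallow.score(X_tr, y_tr):.3f}, test={shallow.score(X_te, y_te):.3f}")
```

Comparing the train/test gap of the two models is a quick way to judge how much of the full tree's apparent accuracy is memorization rather than signal.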