Keywords: Scikit-learn | Decision Tree | Rule Extraction
Abstract: This article provides an in-depth exploration of methods for extracting human-readable decision rules from Scikit-learn decision tree models. Focusing on the best-practice approach, it details the technical implementation using the tree.tree_ internal data structure with recursive traversal, while comparing the advantages and disadvantages of alternative methods. Complete Python code examples are included, explaining how to avoid common pitfalls such as incorrect leaf node identification and mishandling the -2 sentinel values in the feature array. The official export_text method introduced in Scikit-learn 0.21 is also briefly discussed as a supplementary reference.
Introduction
Decision tree models are highly valued in machine learning for their interpretability. However, the Scikit-learn library does not natively provide functionality to extract decision rules in a textual format. This article, based on high-scoring answers from Stack Overflow, thoroughly examines how to extract decision paths from trained decision trees and present them in a readable text form.
Internal Data Structure of Decision Trees
Scikit-learn's decision tree implementation uses Cython optimization, with its internal structure exposed through the tree.tree_ object. Key attributes include:
- feature: split feature index for each node; leaf nodes are marked with _tree.TREE_UNDEFINED (typically -2)
- threshold: split threshold for each node
- children_left and children_right: indices of the left and right child nodes (-1 at leaves)
- value: node value (class counts or probabilities for classification, predicted values for regression)
Core Extraction Method
The following function implements the conversion from a decision tree to Python code, representing the most robust approach currently available:
```python
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    # Map feature indices to names; leaves carry the TREE_UNDEFINED sentinel (-2)
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: emit the split condition and recurse into both children
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: emit the stored value
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)
```
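To verify that the generated code really reproduces the tree, a variant of the function can return the source as a string instead of printing it, so it can be exec'd and compared against predict. The sketch below trains a small regression tree on made-up 1-D data for that purpose:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree

# Illustrative data: a small regression tree over one feature
X = np.arange(10).reshape(-1, 1)
y = X.ravel().astype(float)
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

def tree_to_code_str(tree, feature_names):
    """Same recursion as tree_to_code, but returns the source as a string."""
    tree_ = tree.tree_
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Return a scalar so the generated function matches model.predict
            lines.append("{}return {}".format(indent, float(tree_.value[node][0][0])))

    recurse(0, 1)
    return "\n".join(lines)

src = tree_to_code_str(model, ["f0"])
namespace = {}
exec(src, namespace)

# The generated function should agree with the fitted tree's predictions
print(namespace["tree"](2.0), model.predict([[2.0]])[0])
```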
Method Analysis
The key strengths of this function are:
- Correct Leaf Node Identification: checking tree_.feature[node] != _tree.TREE_UNDEFINED, rather than testing whether the threshold equals -2, avoids misclassifying internal nodes whose actual split threshold happens to be -2.
- Safe Feature Name Mapping: the list comprehension substitutes a placeholder for the -2 sentinel in the feature array, preventing feature_names[-2] from silently returning the wrong name (negative indices wrap around in Python).
- Clear Code Structure: Generates well-formatted Python functions with complete if-else logic.
Example Output
For a simple regression tree, the function might output:
```python
def tree(f0):
    if f0 <= 6.0:
        if f0 <= 1.5:
            return [[ 0.]]
        else:  # if f0 > 1.5
            if f0 <= 4.5:
                if f0 <= 3.5:
                    return [[ 3.]]
                else:  # if f0 > 3.5
                    return [[ 4.]]
            else:  # if f0 > 4.5
                return [[ 5.]]
    else:  # if f0 > 6.0
        if f0 <= 8.5:
            if f0 <= 7.5:
                return [[ 7.]]
            else:  # if f0 > 7.5
                return [[ 8.]]
        else:  # if f0 > 8.5
            return [[ 9.]]
```
Comparison with Alternative Methods
Method Two: Node Lineage Tracking
This method traces paths from leaf nodes back to the root, generating decision rule tuples:
```python
import numpy as np

def get_lineage(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # Note: leaf entries index feature_names with -2, so those names are meaningless
    features = [feature_names[i] for i in tree.tree_.feature]

    # Leaf nodes are marked with -1 in children_left
    idx = np.argwhere(left == -1)[:, 0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'
        lineage.append((parent, split, threshold[parent], features[parent]))
        if parent == 0:
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    for child in idx:
        for node in recurse(left, right, child):
            print(node)
```
The output format is a sequence of (node ID, branch direction, threshold, feature name), suitable for conversion to other rule formats.
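Equivalent rule strings can also be produced by a forward (root-to-leaf) traversal, which avoids the backward parent search entirely. The sketch below is a compact self-contained variant of this idea, trained on a made-up two-class dataset:

```python
from sklearn.tree import DecisionTreeClassifier

# Illustrative data: a single split at x = 1.5
X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 0, 1, 1]
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

def leaf_rules(tree_, feature_names):
    """Map each leaf node id to a human-readable conjunction of conditions."""
    left, right, thr = tree_.children_left, tree_.children_right, tree_.threshold
    rules = {}

    def recurse(node, conds):
        if left[node] == -1:  # leaf: record the accumulated path conditions
            rules[node] = " AND ".join(conds) or "(root)"
            return
        name = feature_names[tree_.feature[node]]
        recurse(left[node], conds + ["{} <= {:.3f}".format(name, thr[node])])
        recurse(right[node], conds + ["{} > {:.3f}".format(name, thr[node])])

    recurse(0, [])
    return rules

rules = leaf_rules(clf.tree_, ["f0"])
for leaf, rule in sorted(rules.items()):
    print("leaf {}: {}".format(leaf, rule))
```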
Method Three: Official export_text Method
Scikit-learn version 0.21 introduced the export_text function:
```python
from sklearn.tree import export_text

tree_rules = export_text(model, feature_names=list(X_train.columns))
```
This outputs a tree-shaped text representation but offers less customization flexibility.
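A self-contained example using the bundled iris dataset (the model and X_train names above are assumed to come from the reader's own pipeline) shows the "|---" branch-marker format that export_text produces:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text renders each branch with "|---" markers and each leaf's class
rules = export_text(model, feature_names=list(iris.feature_names))
print(rules)
```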
Method Four: Simplified Recursion
This is a simplified version of Method One but has potential issues with threshold checking:
```python
def get_code(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node):
        # Fragile: an internal node whose genuine threshold is -2
        # would be mistaken for a leaf here
        if threshold[node] != -2:
            print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node])
            print("} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node])
            print("}")
        else:
            print("return " + str(value[node]))

    recurse(left, right, threshold, features, 0)
```
Practical Recommendations
1. Prefer Method One: It is the most robust, correctly handles edge cases, and generates executable code.
2. Feature Name Handling: Ensure the feature_names list matches the order of features in the training data.
3. Multi-class Problems: For classification trees, tree_.value[node] holds per-class sample counts (class fractions in recent Scikit-learn versions), requiring an argmax to obtain the predicted class.
4. Random Forests: For random forests, iterate through each tree and combine rules via voting or averaging.
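Points 3 and 4 above can be sketched together. The snippet below recovers per-leaf predicted classes with an argmax over tree_.value (this works whether value stores counts or fractions), then loops over the individual trees of a small forest via estimators_:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
forest = RandomForestClassifier(n_estimators=3, random_state=0).fit(
    iris.data, iris.target
)

# For a single classification tree, argmax over the last axis of `value`
# yields the predicted class at each node; restrict to leaves here
first_tree = forest.estimators_[0].tree_
leaf_mask = first_tree.children_left == -1
leaf_classes = np.argmax(first_tree.value[leaf_mask], axis=-1).ravel()
print(leaf_classes)

# For a forest, repeat the per-tree extraction and aggregate rules by voting
for i, est in enumerate(forest.estimators_):
    print("tree", i, "has", est.tree_.node_count, "nodes")
```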
Conclusion
Extracting decision tree rules is crucial for model interpretation. Method One, detailed in this article, leverages Scikit-learn's internal data structure through recursive traversal to generate readable Python code while avoiding common pitfalls. Although Scikit-learn provides official methods like export_text, custom functions offer greater flexibility and control. Understanding these technical details enhances the ability to interpret and deploy decision tree models effectively.