Extracting Decision Rules from Scikit-learn Decision Trees: A Comprehensive Guide

Dec 04, 2025 · Programming

Keywords: Scikit-learn | Decision Tree | Rule Extraction

Abstract: This article provides an in-depth exploration of methods for extracting human-readable decision rules from Scikit-learn decision tree models. Focusing on the best-practice approach, it details the technical implementation using the tree.tree_ internal data structure with recursive traversal, while comparing the advantages and disadvantages of alternative methods. Complete Python code examples are included, explaining how to avoid common pitfalls such as incorrect leaf node identification and handling feature indices of -2. The official export_text method introduced in Scikit-learn 0.21 is also briefly discussed as a supplementary reference.

Introduction

Decision tree models are highly valued in machine learning for their interpretability. However, before version 0.21 the Scikit-learn library provided no built-in way to extract decision rules in a textual format, and the built-in export_text function added in that release produces a fixed layout. This article, based on high-scoring answers from Stack Overflow, examines how to extract decision paths from trained decision trees and present them in a readable text form.

Internal Data Structure of Decision Trees

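Scikit-learn's decision tree implementation is written in Cython, and the fitted structure is exposed through the tree.tree_ object as a set of parallel arrays indexed by node ID. Key attributes include:

  1. children_left and children_right: The IDs of each node's left and right child, with -1 marking a leaf.
  2. feature: The index of the feature each node splits on, with -2 (_tree.TREE_UNDEFINED) marking a leaf.
  3. threshold: The split threshold at each node; leaves store -2 as a placeholder.
  4. value: The values held at each node, which determine the prediction at a leaf.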
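As a minimal sketch of how these arrays line up, the snippet below fits a small regression tree on made-up data (the data and the variable names are illustrative assumptions, not part of the original answers) and prints one row per node from children_left, children_right, feature, threshold, and value.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy data chosen only for illustration: one feature, ten samples.
X = np.arange(10, dtype=float).reshape(-1, 1)
y = np.arange(10, dtype=float)

model = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)
tree_ = model.tree_

# Every attribute is an array with one entry per node; node 0 is the root.
for node in range(tree_.node_count):
    print(
        node,
        tree_.children_left[node],   # -1 for a leaf
        tree_.children_right[node],  # -1 for a leaf
        tree_.feature[node],         # -2 for a leaf
        tree_.threshold[node],       # -2.0 for a leaf
        tree_.value[node].ravel(),   # node value(s)
    )

leaf_mask = tree_.children_left == -1
```

Reading the rows side by side shows that leaves are marked consistently in children_left, children_right, and feature, which is what the extraction functions in this article rely on.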

Core Extraction Method

The following function implements the conversion from a decision tree to Python code, representing the most robust approach currently available:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    # Map each node to its feature name; leaves carry TREE_UNDEFINED (-2).
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: emit the split and descend into both children.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: emit the stored value.
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)

Method Analysis

The key strengths of this function are:

  1. Correct Leaf Node Identification: By checking tree_.feature[node] != _tree.TREE_UNDEFINED instead of whether the threshold is -2, it avoids misclassification when actual thresholds equal -2.
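  2. Safe Feature Name Mapping: The list comprehension maps leaf nodes, whose feature index is -2, to a placeholder. Indexing feature_names with -2 directly would raise an IndexError for a single-feature model or silently pick the second-to-last feature name otherwise.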
  3. Clear Code Structure: Generates well-formatted Python functions with complete if-else logic.
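The first point can be illustrated with a small sketch. It assumes a two-sample toy dataset (invented here for illustration) whose only possible split falls at the midpoint between -3 and -1, so the root node should end up with a genuine threshold of -2.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, _tree

# Two samples, one feature; the split point between them is -2.0.
X = np.array([[-3.0], [-1.0]])
y = np.array([0, 1])

model = DecisionTreeClassifier(random_state=0).fit(X, y)
tree_ = model.tree_

root = 0
root_threshold = float(tree_.threshold[root])

# Fragile test: treats any node whose threshold is -2 as a leaf.
leaf_by_threshold = root_threshold == -2
# Robust test: a leaf is a node with no split feature.
leaf_by_feature = bool(tree_.feature[root] == _tree.TREE_UNDEFINED)

print("root threshold:", root_threshold)
print("leaf according to threshold check:", leaf_by_threshold)
print("leaf according to feature check:", leaf_by_feature)
```

On this data the threshold check labels the root as a leaf even though it has two children, while the feature check does not.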

Example Output

For a simple regression tree, the function might output:

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]
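One way to check that generated rules match the model is to build the source as a string, execute it, and compare against predict. The sketch below is a variant of the function above under stated assumptions: tree_to_source is a hypothetical name, it handles only single-output regression trees, and the toy data is not necessarily the data that produced the output shown here.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree

def tree_to_source(model, feature_names, func_name="tree"):
    """Return the rules of a fitted single-output regression tree as Python source."""
    tree_ = model.tree_
    lines = ["def {}({}):".format(func_name, ", ".join(feature_names))]

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            threshold = float(tree_.threshold[node])
            lines.append("{}if {} <= {!r}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:".format(indent))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Regression leaf: value has shape (1, 1) for a single output.
            leaf_value = float(tree_.value[node][0][0])
            lines.append("{}return {!r}".format(indent, leaf_value))

    recurse(0, 1)
    return "\n".join(lines)

# Toy data for illustration only.
X = np.arange(10, dtype=float).reshape(-1, 1)
y = np.arange(10, dtype=float)
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

source = tree_to_source(model, ["f0"])
print(source)

# Execute the generated source and compare it with the fitted model.
namespace = {}
exec(source, namespace)
predict_from_rules = namespace["tree"]

generated = [predict_from_rules(row[0]) for row in X]
expected = model.predict(X).tolist()
```

Returning a string instead of printing makes the rules easy to test or write to a file. Note that exec should only be used on source generated from a model you trust.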

Comparison with Alternative Methods

Method Two: Node Lineage Tracking

This method traces paths from leaf nodes back to the root, generating decision rule tuples:

import numpy as np

def get_lineage(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # Leaves have feature index -2; map them to None instead of indexing.
    features = [feature_names[i] if i >= 0 else None for i in tree.tree_.feature]
    # Leaf nodes are those without a left child.
    idx = np.argwhere(left == -1)[:, 0]

    def recurse(left, right, child, lineage=None):
        if lineage is None:
            lineage = [child]
        # Find the parent of the current node and the branch taken from it.
        if child in left:
            parent = np.where(left == child)[0].item()
            split = 'l'
        else:
            parent = np.where(right == child)[0].item()
            split = 'r'
        lineage.append((parent, split, threshold[parent], features[parent]))
        if parent == 0:
            # Reached the root: return the path in root-to-leaf order.
            lineage.reverse()
            return lineage
        else:
            return recurse(left, right, parent, lineage)

    for child in idx:
        for node in recurse(left, right, child):
            print(node)

Each path is printed as a sequence of (node ID, branch direction, threshold, feature name) tuples from the root down, followed by the ID of the leaf it reaches. This form is suitable for conversion to other rule formats.
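As one possible conversion, the sketch below turns a single lineage into an IF/THEN string. The helper name lineage_to_rule and the sample path values are hypothetical; the path only mimics the shape produced by get_lineage (tuples followed by the leaf ID).

```python
def lineage_to_rule(lineage):
    """Format one root-to-leaf lineage as a readable rule string."""
    # The last element is the leaf ID; everything before it is a split tuple.
    *steps, leaf_id = lineage
    conditions = [
        "{} {} {}".format(feature, "<=" if split == "l" else ">", threshold)
        for _parent, split, threshold, feature in steps
    ]
    return "IF " + " AND ".join(conditions) + " THEN leaf {}".format(leaf_id)

# Made-up path: go left at node 0, then right at node 1, ending in leaf 4.
path = [(0, "l", 6.0, "f0"), (1, "r", 1.5, "f0"), 4]
rule = lineage_to_rule(path)
print(rule)
```

A left branch ('l') corresponds to the condition feature <= threshold and a right branch ('r') to feature > threshold, matching the split convention used throughout this article.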

Method Three: Official export_text Method

Scikit-learn version 0.21 introduced the export_text function:

from sklearn.tree import export_text
tree_rules = export_text(model, feature_names=list(X_train.columns))

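This outputs an indented, tree-shaped text representation. Parameters such as max_depth, spacing, decimals, and show_weights adjust the formatting, but the overall layout is fixed, so it offers less flexibility than a custom traversal.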
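A self-contained sketch of export_text on the bundled iris dataset follows. The dataset, the tree depth, and the chosen option values are assumptions made for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
model = DecisionTreeClassifier(max_depth=2, random_state=0)
model.fit(iris.data, iris.target)

# decimals controls threshold precision; show_weights adds per-class weights at leaves.
rules = export_text(
    model,
    feature_names=list(iris.feature_names),
    decimals=1,
    show_weights=True,
)
print(rules)
```

The result is a plain string, so it can be logged or written to a file, but it is meant for reading rather than for execution.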

Method Four: Simplified Recursion

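This is a simplified version of Method One that emits C-style pseudocode. It has two weaknesses: it identifies leaves by checking whether the threshold equals -2, which misfires when a real split threshold is -2, and it indexes feature_names with the raw feature array, including the -2 entries stored for leaves: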

def get_code(tree, feature_names):
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    # Caution: leaves have feature index -2, so this indexes feature_names[-2].
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node):
        # Caution: a genuine split threshold of -2 is mistaken for a leaf here.
        if threshold[node] != -2:
            print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node])
            print("} else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node])
            print("}")
        else:
            print("return " + str(value[node]))

    recurse(left, right, threshold, features, 0)

Practical Recommendations

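1. Prefer Method One: It is the most robust, correctly handles edge cases, and generates output shaped like a Python function. The leaf values are printed as NumPy arrays, so they may need reformatting before the output can be executed as-is.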

2. Feature Name Handling: Ensure the feature_names list matches the order of features in the training data.

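3. Multi-class Problems: For classification trees, tree_.value[node] holds one weight per class (sample counts in older releases, class proportions in newer ones) rather than a single prediction. Take the argmax and index the model's classes_ attribute to obtain the predicted class.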

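4. Random Forests: A fitted forest exposes its individual trees through the estimators_ attribute. Extract rules from each tree separately, and remember that the forest's prediction aggregates the per-tree outputs by voting or averaging.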
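The last two recommendations can be sketched as follows, assuming the bundled iris dataset and small illustrative models. The dictionary leaf_classes and the other variable names are introduced here for illustration only.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, _tree

iris = load_iris()
X, y = iris.data, iris.target

clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
tree_ = clf.tree_

# Leaves are the nodes without a split feature.
leaf_ids = np.flatnonzero(tree_.feature == _tree.TREE_UNDEFINED)

# value[node][0] holds one weight per class; argmax selects the predicted class.
leaf_classes = {
    int(node): clf.classes_[np.argmax(tree_.value[node][0])]
    for node in leaf_ids
}

# Cross-check: apply() gives the leaf each sample lands in.
from_leaves = np.array([leaf_classes[int(node)] for node in clf.apply(X)])

# A forest is a list of fitted trees, each with its own tree_ structure.
forest = RandomForestClassifier(n_estimators=5, max_depth=2, random_state=0)
forest.fit(X, y)
node_counts = [estimator.tree_.node_count for estimator in forest.estimators_]
print(leaf_classes)
print(node_counts)
```

Because the argmax is the same whether the leaf stores counts or proportions, this mapping does not depend on which representation the installed Scikit-learn version uses.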

Conclusion

Extracting decision tree rules is crucial for model interpretation. Method One, detailed in this article, leverages Scikit-learn's internal data structure through recursive traversal to generate readable Python code while avoiding common pitfalls. Although Scikit-learn provides official methods like export_text, custom functions offer greater flexibility and control. Understanding these technical details enhances the ability to interpret and deploy decision tree models effectively.
