How to extract the decision rules from scikit-learn decision-tree?

  • Can I extract the underlying decision rules (or 'decision paths') from a trained decision tree as a textual list?

    Something like:

    if A>0.4 then if B<0.2 then if C>0.8 then class='X'
      September 14, 2020 3:21 PM IST
    0
  • Scikit-learn introduced a new function, export_text, in version 0.21 to extract the rules from a tree (see the export_text documentation). It's no longer necessary to create a custom function.

    Once you've fit your model, you just need two lines of code. First, import export_text:

    from sklearn.tree import export_text

    Second, create an object that will contain your rules. To make the rules more readable, use the feature_names argument and pass a list of your feature names. For example, if your model is called model and your features are the columns of a DataFrame called X_train, you could create an object called tree_rules:

    tree_rules = export_text(model, feature_names=list(X_train))

    Then just print or save tree_rules. Your output will look like this (a complete end-to-end sketch follows the sample output below):

    |--- Age <= 0.63
    |   |--- EstimatedSalary <= 0.61
    |   |   |--- Age <= -0.16
    |   |   |   |--- class: 0
    |   |   |--- Age >  -0.16
    |   |   |   |--- EstimatedSalary <= -0.06
    |   |   |   |   |--- class: 0
    |   |   |   |--- EstimatedSalary >  -0.06
    |   |   |   |   |--- EstimatedSalary <= 0.40
    |   |   |   |   |   |--- EstimatedSalary <= 0.03
    |   |   |   |   |   |   |--- class: 1
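
    For reference, here is a minimal end-to-end sketch; the iris data, the DataFrame wrapper and the max_depth setting are illustrative assumptions, not part of the original answer:

    import pandas as pd
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier, export_text

    # illustrative data: any feature DataFrame and target vector will do
    iris = load_iris()
    X_train = pd.DataFrame(iris.data, columns=iris.feature_names)
    y_train = iris.target

    model = DecisionTreeClassifier(max_depth=3, random_state=0)
    model.fit(X_train, y_train)

    # list(X_train) yields the DataFrame's column names
    tree_rules = export_text(model, feature_names=list(X_train))
    print(tree_rules)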
     
      September 14, 2020 5:22 PM IST
    1
  • There is a new DecisionTreeClassifier method, decision_path, in the 0.18.0 release. The developers provide an extensive, well-documented walkthrough; the code below modifies the loop from that example, as noted in the inline comments.

    sample_id = 0
    # node_indicator, leave_id, feature, threshold and X_test come from the walkthrough
    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                        node_indicator.indptr[sample_id + 1]]

    print('Rules used to predict sample %s: ' % sample_id)
    for node_id in node_index:

        if leave_id[sample_id] == node_id:  # <-- changed != to ==
            # continue  # <-- commented out
            print("leaf node {} reached, no decision here".format(leave_id[sample_id]))  # <-- added

        else:  # <-- added else to iterate through the decision nodes
            if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
                threshold_sign = "<="
            else:
                threshold_sign = ">"

            print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
                  % (node_id,
                     sample_id,
                     feature[node_id],
                     X_test[sample_id, feature[node_id]],  # <-- changed i to sample_id
                     threshold_sign,
                     threshold[node_id]))
    
    Rules used to predict sample 0: 
    decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
    decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
    leaf node 4 reached, no decision here

    Change sample_id to see the decision paths for other samples. I haven't asked the developers about these changes; they just seemed more intuitive when working through the example.
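
    The variables used above (node_indicator, leave_id, feature, threshold, X_test) are defined in that walkthrough. For completeness, here is a minimal setup sketch; the iris data, the train/test split and the model settings are assumptions chosen to mirror the walkthrough, not part of the original answer:

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier

    # minimal setup assumed by the snippet above
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)

    estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
    estimator.fit(X_train, y_train)

    feature = estimator.tree_.feature                  # feature index tested at each node
    threshold = estimator.tree_.threshold              # threshold value at each node
    leave_id = estimator.apply(X_test)                 # leaf id reached by each test sample
    node_indicator = estimator.decision_path(X_test)   # sparse matrix of nodes on each sample's path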
      September 14, 2020 4:26 PM IST
    0
  • def get_code(tree, feature_names):
        left = tree.tree_.children_left
        right = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
            if threshold[node] != -2:  # a threshold of -2 marks a leaf node
                print("if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                    recurse(left, right, threshold, features, left[node])
                print("} else {")
                if right[node] != -1:
                    recurse(left, right, threshold, features, right[node])
                print("}")
            else:
                print("return " + str(value[node]))

        recurse(left, right, threshold, features, 0)

    If you call get_code(dt, df.columns) on the same example, you will obtain:

    if ( col1 <= 0.5 ) {
    return [[ 1.  0.]]
    } else {
    if ( col2 <= 4.5 ) {
    return [[ 0.  1.]]
    } else {
    if ( col1 <= 2.5 ) {
    return [[ 1.  0.]]
    } else {
    return [[ 0.  1.]]
    }
    }
    }
      September 14, 2020 4:28 PM IST
    0
    • Shivakumar Kota
      Shivakumar Kota @Pranav B, Can you tell me what exactly [[ 1. 0.]] in the return statement means in the output above? Some more details would make it easier for me to follow.
      September 14, 2020
    • Pranav B
      Pranav B @Shivakumar Kota, It means that one training sample at that node belongs to class '0' and zero samples belong to class '1' (see the short sketch below for how to read these values).
      September 14, 2020
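      To make that concrete: in the scikit-learn versions current at the time of this thread, tree_.value[node] holds the (weighted) number of training samples of each class at that node, and the predicted class is the argmax. A minimal sketch, assuming dt is the fitted DecisionTreeClassifier from the answer above:

      import numpy as np

      values = dt.tree_.value                # shape: (n_nodes, n_outputs, n_classes)
      for node_id, counts in enumerate(values):
          counts = counts[0]                 # single-output case, e.g. [ 1.  0.]
          print("node {}: class counts = {}, predicted class = {}".format(
              node_id, counts, dt.classes_[np.argmax(counts)]))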
  • I have modified the top-voted code above so that it indents correctly and runs in a Python 3 Jupyter notebook:

    import numpy as np
    from sklearn.tree import _tree
    
    def tree_to_code(tree, feature_names):
        tree_ = tree.tree_
        feature_name = [feature_names[i]
                        if i != _tree.TREE_UNDEFINED else "undefined!" 
                        for i in tree_.feature]
        print("def tree({}):".format(", ".join(feature_names)))
    
        def recurse(node, depth):
            indent = "    " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                print("{}if {} <= {}:".format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                print("{}else:  # if {} > {}".format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                print("{}return {}".format(indent, np.argmax(tree_.value[node])))
    
        recurse(0, 1)
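
    A quick usage sketch (the iris data, the max_depth value and the cleaned-up feature names are illustrative assumptions):

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    iris = load_iris()
    clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

    # strip spaces and units so the printed pseudo-code reads like valid Python
    feature_names = [name.replace(" (cm)", "").replace(" ", "_") for name in iris.feature_names]
    tree_to_code(clf, feature_names)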
     
      September 14, 2020 5:40 PM IST
    0