Giovanni Bricconi


Deep Neural Decision Trees (DNDT)


While following the “machine learning algorithms in 7 days” training, I came across a reference to this paper:

Deep Neural Decision Trees, Yongxin Yang, Irene Garcia Morillas, Timothy M. Hospedales, 2018 ICML Workshop on Human Interpretability in Machine Learning (WHI 2018), Stockholm, Sweden.

https://arxiv.org/abs/1806.06988

Decision Trees and Neural Networks are quite different models, so I was curious to understand how it was possible to build a neural network that behaves like a decision tree! With some tricks it is possible to do that, but why would you want to do such a thing? The authors propose these motivations:

  • Neural networks can be trained efficiently, using GPUs, but they are far from being interpretable. Making them work like a decision tree results in a set of rules on the explanatory variables that can be checked and understood by a human. In some domains this is important for getting the model accepted.
  • The neural network is capable of learning which explanatory variables are important. In the trained model, it is possible to identify useless variables and remove them. In a sense, the neural network training performs a decision tree pruning.
  • Implementing a DNDT model in a neural network framework requires only a few lines of code.

The DNDT model is implemented in this way: the tree is composed of splits, each corresponding to a single explanatory variable. This is a limitation, but it makes the results very easy to interpret! The splits here are not binary, but n-ary: for instance, if you have the variable “petal length” you can choose coefficients b_1,b_2,…,b_n that represent the decision boundaries. If the variable is less than b_1 you follow one path in the tree, between b_1 and b_2 the second path, and so on. These boundaries are decided by the neural network training algorithm, for instance SGD.
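
To make the n-ary split concrete, here is a minimal sketch of the hard (non-differentiable) version of the rule. The function name `branch_index` and the cut point values are my own assumptions for illustration; they do not come from the paper.

```python
from bisect import bisect_right

def branch_index(x, cut_points):
    """Return which branch a value follows, given sorted cut points b_1..b_n."""
    # bisect_right counts how many cut points are <= x, so values below b_1
    # go to branch 0, values in [b_1, b_2) go to branch 1, and so on.
    return bisect_right(cut_points, x)

# Hypothetical cut points for "petal length" (cm), chosen only for illustration.
petal_length_cuts = [2.5, 4.8]

for value in (1.4, 4.0, 5.9):
    print(value, "-> branch", branch_index(value, petal_length_cuts))
```

With n cut points there are n+1 possible branches; this hard rule is what the soft binning function approximates in a differentiable way.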

The number of coefficients is the same for all the explanatory variables: it is a model hyperparameter. This may seem a limitation, but actually it is not: the SGD algorithm is free to learn, for instance, that a variable is not important, and to set all its coefficients so that just one path is followed. This is the reason why the DNDT is able to “prune” the tree: if you see that, on all the training samples, only one path is followed, you know that the variable is useless.

To train the neural network with SGD, the split function must be differentiable, so the authors propose this soft binning function:
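
A rough sketch of how one could spot such a useless variable after training: assign every training sample to its bin and check whether more than one bin is ever reached. The helper names `bins_used` and `is_prunable`, the data and the cut points below are hypothetical, and the check uses hard bin assignment rather than the soft one.

```python
from bisect import bisect_right

def bins_used(values, cut_points):
    """Return the set of bin indices reached by at least one sample."""
    return {bisect_right(cut_points, v) for v in values}

def is_prunable(values, cut_points):
    """A variable is a pruning candidate if every sample follows the same path."""
    return len(bins_used(values, cut_points)) == 1

# Hypothetical feature column and two hypothetical sets of learned cut points.
sepal_width = [3.5, 3.0, 3.2, 2.8, 3.8]
cuts_informative = [3.1]   # separates the samples into two groups
cuts_collapsed = [9.0]     # pushed outside the data range during training

print(is_prunable(sepal_width, cuts_informative))  # more than one bin is used
print(is_prunable(sepal_width, cuts_collapsed))    # every sample lands in bin 0
```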

softmax((wx+b)/tau)

softmax function

Here w is a constant vector containing [1,2,…,nSplits+1], and b is a vector built from the split coefficients to be learned: in the paper it is [0, -b_1, -b_1-b_2, …], that is the running sums of the negated cut points b_1,…,b_n. Tau is a temperature parameter that can be tuned: the closer it is to zero, the more the output looks like a one-hot encoding. The following is a visual example, with tau=0.01 and 3 bins:

Part of figure 1 in the paper
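
Here is a small sketch of the soft binning function in plain Python, so that it runs without any framework. It assumes, as I understand the paper, that w is [1, 2, …, n+1] and that b is [0, -b_1, -b_1-b_2, …] for sorted cut points; the name `soft_bin` and the example cut points are my own choices. In a real model the cut points would be trainable parameters of the framework, not plain lists.

```python
import math

def soft_bin(x, cut_points, tau=0.1):
    """Softly assign scalar x to one of len(cut_points) + 1 bins."""
    n = len(cut_points)
    w = [i + 1 for i in range(n + 1)]      # constant vector [1, 2, ..., n+1]
    b = [0.0]
    for beta in sorted(cut_points):        # cut points must be in ascending order
        b.append(b[-1] - beta)             # [0, -b_1, -b_1-b_2, ...]
    logits = [(w_i * x + b_i) / tau for w_i, b_i in zip(w, b)]
    m = max(logits)                        # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Example cut points, used only for illustration.
cuts = [0.33, 0.66]
for x in (0.1, 0.5, 0.9):
    print(x, [round(p, 3) for p in soft_bin(x, cuts, tau=0.01)])
```

With a small tau the output is almost one-hot, and the largest entry is the bin that the hard rule would pick. With a larger tau the bins blend into each other, which is what gives SGD a useful gradient.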

So far we have a function that can be applied to a single scalar explanatory variable: each input variable is connected to a unit with the soft binning function, which has one output for each branch of the decision tree node. All these outputs are then connected to a second layer, so that you map all the output combinations: the authors refer to this as a Kronecker product. Suppose you have just 2 input variables a and b, each binned into 3 branches (that is, 2 cut points per variable); the two input units will have 3 outputs each: o_a1, o_a2, o_a3, o_b1, o_b2, o_b3. In the second layer the units will have these input pairs: [o_a1,o_b1],[o_a1,o_b2],[o_a1,o_b3],[o_a2,o_b1],[o_a2,o_b2],… Each unit multiplies its two inputs, so it is close to 1 only when both branches are active, and it corresponds to one leaf of the tree. To learn more about the Kronecker product you can have a look at Wikipedia.
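
The pairing can be sketched in a couple of lines. The vectors `o_a` and `o_b` below are made-up soft binning outputs, and `kron` is a hypothetical helper for the Kronecker product of two plain lists.

```python
def kron(u, v):
    """Kronecker product of two 1-D vectors, returned as a flat list."""
    return [a * b for a in u for b in v]

# Made-up soft binning outputs for variables a and b (3 branches each).
o_a = [0.02, 0.97, 0.01]   # a is almost surely in its second branch
o_b = [0.01, 0.01, 0.98]   # b is almost surely in its third branch

# 9 entries: o_a1*o_b1, o_a1*o_b2, o_a1*o_b3, o_a2*o_b1, ..., o_a3*o_b3
leaves = kron(o_a, o_b)

# The largest entry identifies the leaf reached by this sample.
best = max(range(len(leaves)), key=leaves.__getitem__)
print(len(leaves), best, divmod(best, len(o_b)))
```

Since each input vector sums to 1, the product vector also sums to 1, so it can be read as a soft indicator of which leaf the sample ends up in.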

The second layer outputs are then simply connected to a third layer that implements the softmax classifier. Reading the paper you will see a classifier example for the Iris data set: the input variables are petal length and petal width and the output categories are Setosa, Versicolor and Virginica.
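
Putting the three layers together, a forward pass could look like the sketch below. Everything numeric here is hypothetical: the cut points and the leaf scores are set by hand to get a readable example, whereas in the paper they are learned with SGD. The function names are my own, and b is assumed to be the running sum of the negated cut points.

```python
import math

def softmax(z):
    m = max(z)                                  # subtract the max for stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def soft_bin(x, cut_points, tau):
    w = range(1, len(cut_points) + 2)           # [1, 2, ..., n+1]
    b = [0.0]
    for beta in cut_points:                     # cut points sorted ascending
        b.append(b[-1] - beta)
    return softmax([(wi * x + bi) / tau for wi, bi in zip(w, b)])

def dndt_forward(features, cuts_per_feature, leaf_scores, tau=0.1):
    leaves = [1.0]
    for x, cuts in zip(features, cuts_per_feature):
        bins = soft_bin(x, cuts, tau)           # layer 1: soft binning
        leaves = [l * p for l in leaves for p in bins]   # layer 2: Kronecker product
    n_classes = len(leaf_scores[0])
    logits = [sum(l * row[c] for l, row in zip(leaves, leaf_scores))
              for c in range(n_classes)]        # layer 3: linear scores per class
    return softmax(logits)

classes = ["Setosa", "Versicolor", "Virginica"]
# Hand-picked cut points for petal length and petal width (not learned values).
cuts_per_feature = [[2.5, 4.8], [0.8, 1.7]]
# Hand-picked leaf scores: the class depends only on the petal length bin.
leaf_scores = [[5.0 if c == i else 0.0 for c in range(3)]
               for i in range(3) for _ in range(3)]

for sample in ([1.4, 0.2], [4.0, 1.3], [5.9, 2.1]):
    probs = dndt_forward(sample, cuts_per_feature, leaf_scores)
    print(sample, "->", classes[probs.index(max(probs))])
```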

Clearly all these all-to-all connections keep the model simple, but they prevent it from scaling to a large number of input variables, because the number of second-layer units grows exponentially with the number of variables. As a solution, the authors propose training a forest of decision trees, each based on fewer variables, but this goes against the model interpretability.

To conclude, let me cite the authors:
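
A quick back-of-the-envelope calculation shows why this does not scale. Assuming every variable uses the same number of cut points, the Kronecker product produces (cut points + 1) raised to the number of variables; `leaf_count` is just a hypothetical helper for this arithmetic.

```python
def leaf_count(n_features, n_cut_points):
    """Number of leaves produced by the all-to-all (Kronecker) construction."""
    return (n_cut_points + 1) ** n_features

# Even with a single cut point per variable the growth is exponential.
for d in (2, 4, 10, 20):
    print(d, "features ->", leaf_count(d, n_cut_points=1), "leaves")
```

With 20 variables and one cut point each there are already more than a million leaves, which explains the idea of several smaller trees, each built on a subset of the variables.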

We introduced a neural network based tree model DNDT.
It has better performance than NNs for certain tabular
datasets, while providing an interpretable decision tree.
Meanwhile compared to conventional DTs, DNDT is simpler
to implement, simultaneously searches tree structure
and parameters with SGD, and is easily GPU accelerated.

Yang, Morillas, Hospedales

Written by Giovanni

January 15, 2023 at 2:53 pm

Posted in Varie
