<!DOCTYPE article PUBLIC "-//NLM//DTD JATS (Z39.96) Journal Archiving and Interchange DTD v1.0 20120330//EN" "JATS-archivearticle1.dtd">
<article xmlns:xlink="http://www.w3.org/1999/xlink">
  <front>
    <journal-meta />
    <article-meta>
      <title-group>
        <article-title>Distilling a Neural Network Into a Soft Decision Tree</article-title>
      </title-group>
      <contrib-group>
        <contrib contrib-type="author">
          <string-name>Nicholas Frosst</string-name>
        </contrib>
        <contrib contrib-type="author">
          <string-name>Geoffrey Hinton</string-name>
        </contrib>
      </contrib-group>
      <abstract>
        <p>Deep neural networks have proved to be a very effective way to perform classification tasks. They excel when the input data is high dimensional, the relationship between the input and the output is complicated, and the number of labeled training examples is large [Szegedy et al., 2015, Wu et al., 2016, Jozefowicz et al., 2016, Graves et al., 2013]. But it is hard to explain why a learned network makes a particular classification decision on a particular test case. This is due to their reliance on distributed hierarchical representations. If we could take the knowledge acquired by the neural net and express the same knowledge in a model that relies on hierarchical decisions instead, explaining a particular decision would be much easier. We describe a way of using a trained neural net to create a type of soft decision tree that generalizes better than one learned directly from the training data.</p>
      </abstract>
    </article-meta>
  </front>
  <body>
    <sec id="sec-1">
      <title>Introduction</title>
      <p>The excellent generalization abilities of deep neural nets depend on their use
of distributed representations [LeCun et al., 2015] in their hidden layers, but
these representations are hard to understand. For the first hidden layer we can
understand what causes an activation of a unit and for the last hidden layer we
can understand the effects of activating a unit, but for the other hidden layers it
is much harder to understand the causes and effects of a feature activation in
terms of variables that are meaningful such as the input and output variables.
Also, the units in a hidden layer factor the representation of the input vector
into a set of feature activations in such a way that the combined effects of the
active features can cause an appropriate distributed representation in the next
hidden layer. This makes it very difficult to understand the functional role of any
particular feature activation in isolation since its marginal effect depends on the
effects of all the other units in the same layer.</p>
      <p>These difficulties are further compounded by the fact that deep neural nets
can make reliable decisions by modeling a very large number of weak statistical
regularities in the relationship between the inputs and outputs of the training data
and there is nothing in the neural network to distinguish the weak regularities
that are true properties of the data from the spurious regularities that are created
by the sampling peculiarities of the training set. Faced with all these difficulties,
it seems wise to abandon the idea of trying to understand how a deep neural
network makes a classification decision by understanding what the individual
hidden units do.</p>
      <p>By contrast, it is easy to explain how a decision tree makes any particular
classification because this depends on a relatively short sequence of decisions and
each decision is based directly on the input data. Decision trees, however, do not
usually generalize as well as deep neural nets. Unlike the hidden units in a neural
net, a typical node at the lower levels of a decision tree is only used by a very
small fraction of the training data so the lower parts of the decision tree tend to
overfit unless the size of the training set is exponentially large compared with
the depth of the tree.</p>
      <p>
        In this paper, we propose a novel way of resolving the tension between
generalization and interpretability. Instead of trying to understand how a deep
neural network makes its decisions, we use the deep neural network to train a
decision tree that mimics the input-output function discovered by the neural
network but works in a completely different way. If there is a large amount of
unlabelled data, the neural net can be used to create a much larger labelled
data set to train a decision tree, thus overcoming the statistical inefficiency of
decision trees. Even if unlabelled data is unavailable, it may be possible to use
recent advances in generative modeling [Goodfellow et al., 2014, <xref ref-type="bibr" rid="ref7">Kingma and Welling, 2013</xref>] to generate synthetic unlabelled data from a distribution that is
close to the data distribution. Without using unlabelled data, it is still possible
to transfer the generalization abilities of the neural net to a decision tree by using
a technique called distillation [Hinton et al., 2015, Buciluă et al., 2006] and a
type of decision tree that makes soft decisions.
      </p>
      <p>At test time, we use the decision tree as our model. This may perform slightly
worse than the neural network but it will often be much faster and we now have
a model whose decisions we can explain and engage with directly.</p>
      <p>We start by describing the particular type of decision tree we use. This choice
was made to facilitate easy distillation of the knowledge acquired by a deep neural
net into a decision tree.
</p>
    </sec>
    <sec id="sec-2">
      <title>The Hierarchical Mixture of Bigots</title>
      <p>We use soft binary decision trees trained with mini-batch gradient descent, where
each inner node i has a learned filter w_i and a bias b_i, and each leaf node ℓ
has a learned distribution Q_ℓ. At each inner node, the probability of taking the
rightmost branch is:</p>
      <p>p_i(x) = σ(x·w_i + b_i)   (1)</p>
      <p>where x is the input to the model and σ is the sigmoid logistic function.</p>
      <p>This model is a hierarchical mixture of experts [Jordan and Jacobs, 1994], but
each expert is actually a bigot who does not look at the data after training, and
therefore always produces the same distribution. The model learns a hierarchy of
filters that are used to assign each example to a particular bigot with a particular
path probability, and each bigot learns a simple, static distribution over the
possible output classes, k.</p>
      <p>Q_ℓ^k = exp(φ_ℓ^k) / Σ_{k′} exp(φ_ℓ^{k′})   (2)</p>
      <p>where Q_ℓ denotes the probability distribution at the ℓth leaf, and each φ_ℓ is a
learned parameter at that leaf.</p>
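As a concrete sketch, each bigot's static distribution is just a softmax over its learned per-leaf parameters (called `phi` below; the numerical values are illustrative, not learned):

```python
import math

def leaf_distribution(phi):
    """Softmax over the learned per-leaf parameters phi_l (Eq. 2)."""
    m = max(phi)                      # subtract the max for numerical stability
    exps = [math.exp(p - m) for p in phi]
    total = sum(exps)
    return [e / total for e in exps]

# A leaf whose learned parameters favour class 1 (illustrative values).
q = leaf_distribution([0.0, 2.0, 1.0])
```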
      <p>In order to avoid very soft decisions in the tree, we introduced an inverse
temperature to the filter activations prior to calculating the sigmoid. Thus the
probability of taking the right branch at node i becomes pi(x) = ( (xwi + bi)).</p>
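The branching probability with the inverse temperature can be sketched as follows (the weights, bias, and temperature values here are illustrative, not learned):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def right_branch_prob(x, w, b, beta=1.0):
    """p_i(x) = sigmoid(beta * (x . w_i + b_i)); beta > 1 sharpens the decision."""
    activation = sum(xj * wj for xj, wj in zip(x, w)) + b
    return sigmoid(beta * activation)

x = [1.0, -0.5]
w = [0.8, 0.4]
b = 0.1
p_soft = right_branch_prob(x, w, b, beta=1.0)
p_sharp = right_branch_prob(x, w, b, beta=4.0)
```

With a positive activation, a larger β pushes the probability closer to 1, which is the intended effect of avoiding very soft decisions.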
      <p>This model can be used to give a predictive distribution over classes in two
different ways, namely by using the distribution from the leaf with the greatest
path probability or averaging the distributions over all the leaves, weighted by
their respective path probabilities. If we take the predictive distribution from
the leaf with the greatest path probability, the explanation for that prediction is
simply the list of all the filters along the path from the root to the leaf together
with the binary activation decisions. If we average the leaf distributions weighted
by their respective path probabilities, we find that the model achieves marginally
better test accuracy, but this leads to an exponential increase in the complexity
of the explanation of the model’s predictive distribution on a particular case
because it involves the filters at all of the nodes. For this reason, for the remainder
of this paper when we refer to the output of the model, we will be referring to
the distribution at the leaf with the maximum path probability.</p>
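Both prediction modes can be illustrated on a tiny hand-built tree with a single inner node; the filter along the chosen path is the explanation (all numbers are made up for illustration):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Depth-1 tree: one inner node and two static leaf distributions over 2 classes.
w, b = [1.0, -1.0], 0.0
q_left, q_right = [0.9, 0.1], [0.2, 0.8]

def predict(x):
    p_right = sigmoid(sum(xi * wi for xi, wi in zip(x, w)) + b)
    # Max-path-probability mode: follow the more probable branch only.
    leaf = q_right if p_right >= 0.5 else q_left
    # Averaged mode: mix both leaves, weighted by their path probabilities.
    mixed = [(1 - p_right) * l + p_right * r for l, r in zip(q_left, q_right)]
    return leaf, mixed

leaf, mixed = predict([2.0, 0.0])
```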
      <p>We train the soft decision tree using a loss function that seeks to minimize
the cross entropy between each leaf, weighted by its path probability, and the
target distribution. For a single training case with input vector x and target
distribution T , the loss is:</p>
      <p>L(x) = −log( Σ_{ℓ∈LeafNodes} P^ℓ(x) Σ_k T_k log Q_ℓ^k )   (3)</p>
      <p>where T is the target distribution and P^ℓ(x) is the probability of arriving at leaf
node ℓ given the input x.</p>
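The path-probability-weighted cross entropy inside this loss can be sketched as follows (the leaf distributions and targets are illustrative, and the outer log of Eq. 3 is omitted here for simplicity):

```python
import math

def weighted_cross_entropy(path_probs, leaf_dists, target):
    """Cross entropy of each leaf with the target, weighted by path probability."""
    return -sum(
        p * sum(t * math.log(q) for t, q in zip(target, q_leaf))
        for p, q_leaf in zip(path_probs, leaf_dists)
    )

# Two leaves; in the first case the high-probability path reaches a leaf that
# matches the target, in the second case it reaches a mismatched leaf.
loss_good = weighted_cross_entropy([0.9, 0.1], [[0.8, 0.2], [0.3, 0.7]], [1.0, 0.0])
loss_bad  = weighted_cross_entropy([0.1, 0.9], [[0.8, 0.2], [0.3, 0.7]], [1.0, 0.0])
```

Routing probability mass toward leaves that agree with the target lowers the loss, which is what drives both the filters and the leaf distributions during training.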
      <p>Unlike most decision trees, our soft decision trees use decision boundaries that
are not aligned with the axes defined by the components of the input vector. Also,
they are trained by first picking the size of the tree and then using mini-batch
gradient descent to update all of their parameters simultaneously, rather than
the more standard greedy approach that decides the splits one node at a time
[Friedman et al., 2001].
</p>
    </sec>
    <sec id="sec-3">
      <title>Regularizers</title>
      <p>To avoid getting stuck at poor solutions during the training, we introduced a
penalty term that encouraged each internal node to make equal use of both
left and right sub-trees. Without this penalty, the tree tended to get stuck on
plateaus in which one or more of the internal nodes always assigned almost all
the probability to one of its sub-trees and the gradient of the logistic for this
decision was always very close to zero. The penalty is the cross entropy between
the desired average distribution (0.5, 0.5) for the two sub-trees and the actual
average distribution (α_i, 1 − α_i), where α_i for node i is given by:</p>
      <p>α_i = Σ_x P^i(x) p_i(x) / Σ_x P^i(x)   (4)</p>
      <p>where P^i(x) is the path probability from the root node to node i. The penalty
summed over all internal nodes is then:</p>
      <p>C = −λ Σ_{i∈InnerNodes} 0.5 log(α_i) + 0.5 log(1 − α_i)   (5)</p>
      <p>where λ is a hyper-parameter that determines the strength of the penalty and
is set prior to training. This penalty was based on the assumption that a tree
making fairly equal use of alternative sub-trees would usually be better suited to
any particular classification task and in practice it did increase accuracy. However,
this assumption is less and less valid as one descends the tree; a penultimate node
in the tree may only be responsible for two classes of input, in some non-equal
proportion, and penalizing the node for a non-equal split in this case could hurt
the accuracy of the model. We found that we achieved better test accuracy results
when the strength of the penalty decayed exponentially with the depth d of the
node in the tree so that it was proportional to 2^−d.</p>
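A sketch of this penalty for a single node: α_i is estimated from a batch as the path-probability-weighted mean of the right-branch probabilities, and the depth decay 2^−d is applied (λ and all inputs below are illustrative):

```python
import math

def balance_penalty(path_probs, branch_probs, depth, lam=0.1):
    """Per-node penalty: -lam * 2^-d * (0.5*log(alpha) + 0.5*log(1 - alpha)),
    with alpha the path-probability-weighted mean right-branch probability."""
    alpha = sum(P * p for P, p in zip(path_probs, branch_probs)) / sum(path_probs)
    decay = lam * 2.0 ** (-depth)
    return -decay * (0.5 * math.log(alpha) + 0.5 * math.log(1.0 - alpha)), alpha

# A balanced node is penalized less than one routing almost everything right,
# and the same imbalance costs less at greater depth.
pen_balanced, a1 = balance_penalty([1.0, 1.0], [0.5, 0.5], depth=0)
pen_skewed, a2 = balance_penalty([1.0, 1.0], [0.98, 0.96], depth=0)
pen_deep, _ = balance_penalty([1.0, 1.0], [0.98, 0.96], depth=3)
```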
      <p>As one descends the tree, the expected fraction of the data that each node
sees in any given training batch decreases exponentially. This means that the
computation of the actual probabilities of using the two sub-trees becomes less
accurate. To counter this we can maintain an exponentially decaying running
average of the actual probabilities with a time window that is exponentially
proportional to the depth of the node. We found experimentally that we achieved
much better test accuracy by using both the exponential decay in the strength
of the penalty with depth and the exponential increase in the temporal scale of
the window used to compute the running average.</p>
      <p>The number of total parameters at which our soft decision trees start to overfit is
typically less than the number of total parameters at which a multi-layer neural
network starts to overfit. This is because the lower nodes of the decision tree
only receive a very small fraction of the training data.</p>
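The depth-dependent running average described above can be sketched as an exponentially decaying mean whose effective window grows with depth (the exact schedule below, window roughly 2^d, is an assumption for illustration):

```python
def make_running_alpha(depth):
    """Exponentially decaying running average of alpha_i; deeper nodes use a
    longer time window (here: window ~ 2^depth, an illustrative choice)."""
    decay = 1.0 - 1.0 / (2.0 ** depth + 1.0)
    state = {"avg": 0.5}  # start from the balanced value

    def update(batch_alpha):
        state["avg"] = decay * state["avg"] + (1.0 - decay) * batch_alpha
        return state["avg"]

    return update

shallow = make_running_alpha(depth=0)   # short window: tracks each batch quickly
deep = make_running_alpha(depth=6)      # long window: changes slowly
for _ in range(5):
    a_shallow = shallow(0.9)
    a_deep = deep(0.9)
```

The long window smooths out the noisy batch estimates that deep nodes produce from their small share of the data.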
      <p>This is reflected in performance on MNIST. With a soft decision tree of
depth 8 we were able to achieve a test accuracy of at most 94.45% when training
on the true targets. A neural net with two convolutional hidden layers and a
penultimate fully connected layer achieved a much better test accuracy of 99.21%.
We were then able to use this more accurate neural net to make a much better
soft decision tree by training with soft targets that were a composite of the true
labels and the predictions of the neural network. The soft decision tree trained
in this way achieved a test accuracy of 96.76% which is about halfway between
the neural net and the soft decision tree trained directly on the data.
</p>
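The composite soft targets can be sketched as a simple mixture of the one-hot label and the neural net's predictive distribution (the 0.5 mixing weight below is an assumption for illustration, not a value from the text):

```python
def soft_targets(one_hot, nn_probs, mix=0.5):
    """Composite target: mix * true label + (1 - mix) * neural-net prediction."""
    return [mix * t + (1.0 - mix) * p for t, p in zip(one_hot, nn_probs)]

# A confident-but-not-certain net prediction smooths the hard label,
# carrying information about which wrong classes are plausible.
target = soft_targets([0.0, 1.0, 0.0], [0.1, 0.7, 0.2])
```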
    </sec>
    <sec id="sec-4">
      <title>Explaining how a soft decision tree makes a classification</title>
      <p>
        The main motivation behind this work was to create a model whose behavior is
easy to explain; in order to fully understand why a particular example was given
a particular classification, one can simply examine all the learned filters along the
path between the root and the classification’s leaf node. The crux of this model
is that it does not rely on hierarchical features, it relies on hierarchical decisions
instead. The hierarchical features of a traditional neural network allow it to learn
robust and novel representations of the input space, but past a single level or
two, they become extremely difficult to engage with. Some current attempts at
explanations for neural networks rely on the use of gradient descent to find an
input that particularly excites a given neuron [<xref ref-type="bibr" rid="ref12">Simonyan et al., 2013</xref>, Erhan et al.,
2009], but this results in a single point on a manifold of inputs, meaning that
other inputs could yield the same pattern of neural excitement, and so it does
not reflect the entire manifold. Ribeiro et al. propose a strategy which relies on
fitting some explainable model which "acts over absence/presence of interpretable
components" to the behavior of a deep neural net around some area of interest
in the input space [Ribeiro et al., 2016]. This is accomplished by sampling from
the input space and querying the model around the area of interest and then
fitting an explainable model to the output of the model. This avoids the problem
of attempting to explain a particular output by visualizing a single point on a
manifold but introduces the problem of necessitating a new explainable model for
every area of interest in the input space, and attempting to explain changes in
the model’s behavior by first order changes in a discretized interpretation of the
input space. By relying on hierarchical decisions instead of hierarchical features
we side-step these problems, as each decision is made at a level of abstraction
that the reader can engage with directly.
      </p>
    </sec>
    <sec id="sec-5">
      <title>Other Data Sets and Results</title>
      <p>
        We tried this model on several other data sets, but focused on spatial input
for the sake of visualization. By first training a neural net and then using it to
provide soft targets for training a soft decision tree, with a tree of depth 8 we
were able to achieve a test accuracy of 80.60% on the Connect4 dataset
[<xref ref-type="bibr" rid="ref15">Lichman, 2013</xref>], comprised of board states of the popular children's game Connect 4 as input,
and the final outcome of the game (player 1 win, player 2 win, or tie) as the
target value. Without distilling from a neural net, the best test accuracy we
achieved was 78.63%. Other decision trees trained with gradient descent have
been applied to this dataset [Norouzi et al., 2015] but were only able to achieve a
maximum test accuracy of 76.50% at the equivalent depth of 8 and 77.45% at a
depth of 20. This provides an interesting example of the utility of an explainable
model: by examining the learned filters of the soft decision tree we are able to
learn something about the nature of the game. From the first learned
filter we can see that the game can be split into two distinct sub-types of games:
games where the players have placed pieces on the edges of the board, and games
where the players have placed pieces in the center of the board. These two sub-games
progress in sufficiently different manners that it was beneficial for the
decision tree to split them at the root.
      </p>
      <p>
We also ran our model on a non-spatial dataset, namely the Letter dataset
[<xref ref-type="bibr" rid="ref15">Lichman, 2013</xref>], which is comprised of primitive numerical attributes of capital
English characters. We were able to achieve a test accuracy of 78.0% with a tree
of depth 9 trained on the raw training data, and a test accuracy of 81.0% when
we distilled from an ensemble of neural nets that had a 95.9% test accuracy.</p>
      <p>We have described a method for using a trained neural net to create a more
explicable model in the form of a soft decision tree which is trained by stochastic
gradient descent using the predictions of the neural net to give more informative
targets. The soft decision tree uses learned filters to make hierarchical decisions
based on an input example and ultimately select a particular static probability
distribution over classes as its output. This soft decision tree generalizes better
than one trained on the data directly, but performs worse than the neural net
which was used to provide the soft targets for training it. So if it is essential to be
able to explain why a model classifies a particular test case in a particular way,
we can use a soft decision tree, but we can still gain some of the benefits of deep
neural networks by using them to improve the training of this explicable model.
      </p>
    </sec>
  </body>
  <back>
    <ref-list>
      <ref id="ref1">
        <mixed-citation>
          <string-name>
            <given-names>Christian</given-names>
            <surname>Szegedy</surname>
          </string-name>
          , Wei Liu, Yangqing Jia,
          <string-name>
            <given-names>Pierre</given-names>
            <surname>Sermanet</surname>
          </string-name>
          , Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and
          <string-name>
            <given-names>Andrew</given-names>
            <surname>Rabinovich</surname>
          </string-name>
          .
          <article-title>Going deeper with convolutions</article-title>
          .
          <source>In Proceedings of the IEEE conference on computer vision and pattern recognition</source>
          , pages
          <fpage>1</fpage>
          -
          <lpage>9</lpage>
          ,
          <year>2015</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref2">
        <mixed-citation>
          <string-name>
            <given-names>Yonghui</given-names>
            <surname>Wu</surname>
          </string-name>
          , Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao,
          <string-name>
            <given-names>Qin</given-names>
            <surname>Gao</surname>
          </string-name>
          ,
          <string-name>
            <given-names>Klaus</given-names>
            <surname>Macherey</surname>
          </string-name>
          , et al.
          <article-title>Google's neural machine translation system: Bridging the gap between human and machine translation</article-title>
          .
          <source>arXiv preprint arXiv:1609.08144</source>
          ,
          <year>2016</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref3">
        <mixed-citation>
          <string-name>
            <given-names>Rafal</given-names>
            <surname>Jozefowicz</surname>
          </string-name>
          , Oriol Vinyals, Mike Schuster, Noam Shazeer, and
          <string-name>
            <given-names>Yonghui</given-names>
            <surname>Wu</surname>
          </string-name>
          .
          <article-title>Exploring the limits of language modeling</article-title>
          .
          <source>arXiv preprint arXiv:1602.02410</source>
          ,
          <year>2016</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref4">
        <mixed-citation>
          <string-name>
            <given-names>Alex</given-names>
            <surname>Graves</surname>
          </string-name>
          , Abdel-rahman Mohamed, and
          <string-name>
            <given-names>Geoffrey</given-names>
            <surname>Hinton</surname>
          </string-name>
          .
          <article-title>Speech recognition with deep recurrent neural networks</article-title>
          .
          <source>In 2013 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)</source>
          , pages
          <fpage>6645</fpage>
          -
          <lpage>6649</lpage>
          . IEEE,
          <year>2013</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref5">
        <mixed-citation>
          <string-name>
            <given-names>Yann</given-names>
            <surname>LeCun</surname>
          </string-name>
          , Yoshua Bengio, and
          <string-name>
            <given-names>Geoffrey</given-names>
            <surname>Hinton</surname>
          </string-name>
          .
          <article-title>Deep learning</article-title>
          .
          <source>Nature</source>
          ,
          <volume>521</volume>
          (
          <issue>7553</issue>
          ):
          <fpage>436</fpage>
          -
          <lpage>444</lpage>
          ,
          <year>2015</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref6">
        <mixed-citation>
          <string-name>
            <given-names>Ian</given-names>
            <surname>Goodfellow</surname>
          </string-name>
          , Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and
          <string-name>
            <given-names>Yoshua</given-names>
            <surname>Bengio</surname>
          </string-name>
          .
          <article-title>Generative adversarial nets</article-title>
          .
          <source>In Advances in neural information processing systems</source>
          , pages
          <fpage>2672</fpage>
          -
          <lpage>2680</lpage>
          ,
          <year>2014</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref7">
        <mixed-citation>
          <string-name>
            <given-names>Diederik P</given-names>
            <surname>Kingma</surname>
          </string-name>
          and
          <string-name>
            <given-names>Max</given-names>
            <surname>Welling</surname>
          </string-name>
          .
          <article-title>Auto-encoding variational bayes</article-title>
          .
          <source>arXiv preprint arXiv:1312.6114</source>
          ,
          <year>2013</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref8">
        <mixed-citation>
          <string-name>
            <given-names>Geoffrey</given-names>
            <surname>Hinton</surname>
          </string-name>
          , Oriol Vinyals, and
          <string-name>
            <given-names>Jeff</given-names>
            <surname>Dean</surname>
          </string-name>
          .
          <article-title>Distilling the knowledge in a neural network</article-title>
          .
          <source>arXiv preprint arXiv:1503.02531</source>
          ,
          <year>2015</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref9">
        <mixed-citation>
          <string-name>
            <given-names>Cristian</given-names>
            <surname>Buciluă</surname>
          </string-name>
          , Rich Caruana, and
          <string-name>
            <given-names>Alexandru</given-names>
            <surname>Niculescu-Mizil</surname>
          </string-name>
          .
          <article-title>Model compression</article-title>
          .
          <source>In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining</source>
          , pages
          <fpage>535</fpage>
          -
          <lpage>541</lpage>
          . ACM,
          <year>2006</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref10">
        <mixed-citation>
          <string-name>
            <given-names>Michael I</given-names>
            <surname>Jordan</surname>
          </string-name>
          and
          <string-name>
            <given-names>Robert A</given-names>
            <surname>Jacobs</surname>
          </string-name>
          .
          <article-title>Hierarchical mixtures of experts and the em algorithm</article-title>
          .
          <source>Neural computation</source>
          ,
          <volume>6</volume>
          (
          <issue>2</issue>
          ):
          <fpage>181</fpage>
          -
          <lpage>214</lpage>
          ,
          <year>1994</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref11">
        <mixed-citation>
          <string-name>
            <given-names>Jerome</given-names>
            <surname>Friedman</surname>
          </string-name>
          , Trevor Hastie, and
          <string-name>
            <given-names>Robert</given-names>
            <surname>Tibshirani</surname>
          </string-name>
          .
          <article-title>The elements of statistical learning</article-title>
          , volume
          <volume>1</volume>
          . Springer series in statistics New York,
          <year>2001</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref12">
        <mixed-citation>
          <string-name>
            <given-names>Karen</given-names>
            <surname>Simonyan</surname>
          </string-name>
          , Andrea Vedaldi, and
          <string-name>
            <given-names>Andrew</given-names>
            <surname>Zisserman</surname>
          </string-name>
          .
          <article-title>Deep inside convolutional networks: Visualising image classification models and saliency maps</article-title>
          .
          <source>arXiv preprint arXiv:1312.6034</source>
          ,
          <year>2013</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref13">
        <mixed-citation>
          <string-name>
            <given-names>Dumitru</given-names>
            <surname>Erhan</surname>
          </string-name>
          , Yoshua Bengio, Aaron Courville, and
          <string-name>
            <given-names>Pascal</given-names>
            <surname>Vincent</surname>
          </string-name>
          .
          <article-title>Visualizing higher-layer features of a deep network</article-title>
          . University of Montreal,
          <volume>1341</volume>
          :
          <fpage>3</fpage>
          ,
          <year>2009</year>
          .
        </mixed-citation>
      </ref>
      <ref id="ref14">
        <mixed-citation>
          <string-name>
            <given-names>Marco Túlio</given-names>
            <surname>Ribeiro</surname>
          </string-name>
          ,
          <string-name>
            <given-names>Sameer</given-names>
            <surname>Singh</surname>
          </string-name>
          , and
          <string-name>
            <given-names>Carlos</given-names>
            <surname>Guestrin</surname>
          </string-name>
          .
          <article-title>"why should I trust you?": Explaining the predictions of any classifier</article-title>
          .
          <source>CoRR, abs/1602.04938</source>
          ,
          <year>2016</year>
          . URL http://arxiv.org/abs/1602.04938.
        </mixed-citation>
      </ref>
      <ref id="ref15">
        <mixed-citation>
          <string-name>
            <given-names>M.</given-names>
            <surname>Lichman</surname>
          </string-name>
          .
          <article-title>UCI machine learning repository</article-title>
          ,
          <year>2013</year>
          . URL http://archive.ics.uci.edu/ml.
        </mixed-citation>
      </ref>
      <ref id="ref16">
        <mixed-citation>
          <string-name>
            <given-names>Mohammad</given-names>
            <surname>Norouzi</surname>
          </string-name>
          , Maxwell Collins, Matthew A Johnson, David J Fleet, and
          <string-name>
            <given-names>Pushmeet</given-names>
            <surname>Kohli</surname>
          </string-name>
          .
          <article-title>Efficient non-greedy optimization of decision trees</article-title>
          .
          <source>In Advances in Neural Information Processing Systems</source>
          , pages
          <fpage>1729</fpage>
          -
          <lpage>1737</lpage>
          ,
          <year>2015</year>
          .
        </mixed-citation>
      </ref>
    </ref-list>
  </back>
</article>