Fast optimization of weighted sparse decision trees for use
in optimal treatment regimes and optimal policy design
Ali Behrouz¹, Mathias Lécuyer¹, Cynthia Rudin² and Margo Seltzer¹
¹ University of British Columbia, Vancouver, British Columbia, Canada
² Duke University, Durham, North Carolina, USA


Abstract
Sparse decision trees are one of the most common forms of interpretable models. While recent advances have produced algorithms that fully optimize sparse decision trees for prediction, that work does not address policy design, because the algorithms cannot handle weighted data samples. Specifically, they rely on the discreteness of the loss function, which means real-valued weights cannot be directly used. For example, none of the existing techniques produce policies that incorporate inverse propensity weighting on individual data points. We present three algorithms for efficient sparse weighted decision tree optimization. The first approach directly optimizes the weighted loss function but is computationally inefficient. Our second approach scales better by transforming weights to integer values and using data duplication to transform the weighted decision tree optimization problem into an unweighted, but larger, counterpart. Our third algorithm, which scales to much larger datasets, uses a randomized procedure that samples each data point with a probability proportional to its weight. We present theoretical bounds on the error of the two fast methods and show experimentally that these methods can be two orders of magnitude faster than the direct optimization of the weighted loss, without losing significant accuracy.

Keywords
Optimal Sparse Decision Trees, Interpretable Machine Learning, Explainability, Optimal Treatment Regimes



AIMLAI: Advances in Interpretable Machine Learning and Artificial Intelligence, October 21, 2022, Atlanta, GA
alibez@cs.ubc.ca (A. Behrouz); mathias.lecuyer@ubc.ca (M. Lécuyer); cynthia@cs.duke.edu (C. Rudin); mseltzer@cs.ubc.ca (M. Seltzer)
© 2022 Copyright for this paper by its authors. Use permitted under Creative Commons License Attribution 4.0 International (CC BY 4.0). CEUR Workshop Proceedings (CEUR-WS.org)


1. Introduction

Sparse decision trees are a leading class of interpretable machine learning models that are commonly used for policy decisions [e.g., 1, 2, 3]. Historically, decision tree optimization has involved greedy tree induction, where trees are built from the top down [4, 5, 6], but more recently there have been several approaches that fully optimize sparse trees to yield the best combination of performance and interpretability [7, 8, 9, 10]. Optimization of sparse optimal trees is NP-hard, and recent work has leveraged the fact that the loss takes on a discrete number of values to provide a computational advantage [11, 12, 13, 14]. However, if one were to try to create a policy tree or estimate causal effects using one of these algorithms, it would become immediately apparent that such algorithms are not able to handle weighted data, because the weights do not come in a small number of discrete values. This means that common weighting schemes, such as inverse propensity weighting or simply weighting some samples more than others [15, 16], are not directly possible with these algorithms.

For example, consider developing a decision tree for describing medical treatment regimes. Here, the cost for misclassification of patients in different stages of the disease could be different. To create an optimal policy, we weight the loss from each patient and minimize the sum of the weighted losses. While it is possible to construct a model using CART's suboptimal greedy splitting procedure [5], the current fastest optimal decision tree method, GOSDT [14], does not support this approach.

We extend the framework of GOSDT-with-Guesses [13] to support weighted samples. GOSDT-with-Guesses produces sparse decision trees with closeness-to-optimality guarantees in seconds or minutes for most datasets; we refer to this algorithm as GOSDTwG. Our work introduces three approaches to allow weighted samples.

A key contributor to GOSDTwG's performance is its use of bitvectors to compute the loss function. However, introducing weights requires multiplying the weights by this bitvector representation, which introduces a runtime penalty of one to two orders of magnitude. We demonstrate this effect in our first approach. Our second approach introduces a normalization and data duplication technique to mitigate the slowdown due to real-valued weights. Here, we transform the weights to small integer values and then duplicate each sample by its transformed weight. Our third approach, which scales to much larger sample sizes, uses a stochastic procedure, where we sample each data point with a probability proportional to its weight. Our experimental results show that: (1) the second and third techniques decrease run time by up to two orders of magnitude relative to that achieved by the direct approach; (2) we can bound the accuracy loss that data duplication introduces; and (3) the weighted optimal
decision tree technique can outperform natural baselines in terms of running time, sparsity, and accuracy.


2. Related Work

Decision trees are one of the most popular forms of interpretable models [17]. While full decision tree optimization is NP-hard [18], it is possible to make assumptions, e.g., feature independence, that simplify the hard optimization to cases where greedy methods suffice [19]. However, these assumptions are unrealistic in practice. Other approaches [20, 21] assume that the data can be perfectly separated with zero error and use SAT solvers to find optimal decision trees; however, real data are generally not separable.

Recent work has addressed optimizing accuracy with soft or hard sparsity constraints on the tree size. Such decision tree optimization problems can be formulated using mixed integer programming (MIP) [9, 10, 12, 22, 23, 24], but MIP solvers tend to be slow. Several new algorithms use customized dynamic programming algorithms with branch-and-bound techniques to improve decision tree optimization scalability. In particular, analytical bounds combined with bitvector-based computation efficiently reduce the search space and improve runtime [25, 26, 27]. Lin et al. [14] extend this approach to use dynamic programming, which leads to even better scalability. Demirović et al. [28] introduce constraints on both depth and the number of nodes to improve scalability. Recently, McTavish et al. [13] proposed smart guessing strategies, based on knowledge gleaned from black-box models, that can be applied to any optimal branch-and-bound-based decision tree algorithm to reduce the run time by multiple orders of magnitude. While these studies focus on improving runtime and accuracy, they handle only uniform sample importance and do not consider weighted data points. Our work fills this gap; our weighted objective function, data duplication method, and sampling approach enable us to find near-optimal decision trees quickly.

Several studies focus on learning tree- and list-based treatment regimes from data [29, 30, 31, 32, 33, 34, 35]. However, none of these methods fully optimize the policy, because the techniques used for optimization were not known when the work was done.


3. Methodology

Let {(x_i, y_i, w_i)}_{i=1}^N represent our training dataset, where x_i are M-vectors of features, y_i ∈ {0, 1, . . . , K} are labels, w_i ∈ ℝ_{≥0} is the weight associated with data point x_i, and N is the size of the dataset. Also, let x be the N × M covariate matrix, w be the N-vector of weights, y be the N-vector of labels, and let x_{ij} denote the j-th feature of x_i. To handle continuous features, we binarize them either by using all possible split points to create dummy variables [25] or by using a subset of these splits as done by McTavish et al. [13]. We let x̃ denote the binarized covariate matrix, with entries x̃_{ij} ∈ {0, 1}.

3.1. Objective

Let 𝒯 be a decision tree that gives predictions {ŷ^i_𝒯}_{i=1}^N. The weighted loss of 𝒯 on the training data is:

    \mathcal{L}_{\mathbf{w}}(\mathcal{T}, \tilde{\mathbf{x}}, \mathbf{y}) = \frac{1}{\sum_{i=1}^{N} w_i} \sum_{i=1}^{N} \mathbb{1}[y_i \neq \hat{y}^i_{\mathcal{T}}] \times w_i .    (1)

To achieve interpretability and prevent overfitting, we provide the option to use either soft sparsity regularization on the number of leaves, hard regularization on the tree depth, or both [see 13]:

    \min_{\mathcal{T}} \; \mathcal{L}_{\mathbf{w}}(\mathcal{T}, \tilde{\mathbf{x}}, \mathbf{y}) + \lambda H_{\mathcal{T}} \quad \text{s.t. } \operatorname{depth}(\mathcal{T}) \leq d ,    (2)

where H_𝒯 is the number of leaves in 𝒯 and λ is a per-leaf regularization parameter. We define R_w(𝒯, x̃, y) = ℒ_w(𝒯, x̃, y) + λH_𝒯, and for simplicity we refer to 1[y_i ≠ ŷ^i_𝒯] as I_i(𝒯). While in practice depth constraints between 2 and 5 are usually sufficient, McTavish et al. [13] provide theoretically-proven guidance to select a depth constraint so that a single tree has the same expressive power (VC dimension) as an ensemble of smaller trees (e.g., a random forest or a boosted decision tree). The parameter λ trades off between the weighted training loss and the number of leaves in the tree.
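To make the objective concrete, the following is a minimal NumPy sketch of Eq. (1) and Eq. (2); the function names and array-based interface are illustrative only and are not part of GOSDTwG's API.

```python
import numpy as np

def weighted_loss(y_true, y_pred, w):
    """Normalized weighted 0-1 loss, Eq. (1)."""
    miss = (y_true != y_pred).astype(float)      # 1[y_i != yhat_i]
    return float(np.dot(w, miss)) / w.sum()

def weighted_objective(y_true, y_pred, w, n_leaves, lam):
    """Regularized objective R_w = L_w + lambda * (number of leaves), Eq. (2)."""
    return weighted_loss(y_true, y_pred, w) + lam * n_leaves
```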
3.2. Learning Weighted Trees

We present three approaches for handling sample weights. The first is the direct approach, where we calculate the weighted loss directly. Implementing this approach requires multiplying each misclassification by its corresponding weight, which is computationally expensive in any algorithm that uses bitvectors to optimize loss computation. This overhead is due to replacing fast bitvector operations with slower vector multiplications. The direct approach slows GOSDTwG down by two orders of magnitude. To avoid this computational penalty, our second approach, data-duplication, transforms the weights; specifically, we normalize, scale, and round the weights to small integer values. We then duplicate samples, where the number of duplicates is the value of the rounded weights, and use this larger unweighted dataset to learn the tree. This method avoids costly vector multiplications and does not substantially increase run time compared to the unweighted GOSDTwG. Finally, to scale to even larger datasets, we present a randomized procedure, called weighted sampling, where we sample each data point with a probability proportional to its weight. This process introduces variance (not bias) and scales to large numbers of samples.

Direct Approach. We begin with the branch-and-bound algorithm of McTavish et al. [13] and adapt it to support weighted samples. Given a reference model T, they prune the search space using three "guessing" techniques: (1) guess how to transform continuous features into binary features, (2) guess tree depth for depth-constrained models, and (3) guess tight lower bounds on the objective for subsets of points to allow faster time-to-completion. It is straightforward to see that the first two techniques apply directly to our weighted loss function. However, we need to adapt the third guessing technique to have an effective and tight lower bound for the weighted loss function. Let ŷ^i_T be the predictions of a potentially complex reference model (e.g., a boosted decision tree model) on training observation i. The reference model is used as an upper bound on the performance of the sparse decision tree we are optimizing. Let s_a be the subset of training observations that satisfy a boolean assertion a:

    s_a := \{ i : a(\tilde{\mathbf{x}}_i) = \text{True}, \; i \in \{1, \ldots, N\} \},
    \tilde{\mathbf{x}}(s_a) := \{ \tilde{\mathbf{x}}_i : i \in s_a \}, \quad \mathbf{y}(s_a) := \{ y_i : i \in s_a \}, \quad \mathbf{w}(s_a) := \{ w_i : i \in s_a \}.

Motivated by McTavish et al. [13], we define our guessed lower bound on the achievable loss on subset s_a as:

    \mathit{lb}_{\text{guess}}(s_a) := \frac{1}{\sum_{i=1}^{N} w_i} \sum_{i \in s_a} \mathbb{1}[y_i \neq \hat{y}^i_{T}] \times w_i + \lambda .    (3)

Eq. 3 is a lower bound guess for R_w(t, x̃(s_a), y(s_a)), because we assume that the (possibly black box) reference model T has a loss less than or equal to that of tree t on data s_a, and we know that any tree has at least one node (hence the regularization term's lower bound of λ × 1).
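A small sketch of the guessed lower bound in Eq. (3), assuming the reference model's predictions on the training set are precomputed; the variable names are illustrative.

```python
import numpy as np

def lb_guess(subset_idx, y_true, ref_pred, w, lam):
    """Eq. (3): weighted loss of the reference model on subset s_a
    (normalized by the total weight of the full training set), plus one
    leaf's worth of regularization."""
    s = np.asarray(subset_idx)
    miss = (y_true[s] != ref_pred[s]).astype(float)
    return float(np.dot(w[s], miss)) / w.sum() + lam
```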
Accordingly, in the branch-and-bound algorithm, to optimize the weighted loss function introduced in Equation 2, we consider a subproblem to be solved if we find a subtree that achieves an objective less than or equal to its lb_guess. If we find such a subtree, our training performance will be at least as good as that of the reference model. For a subset of observations s_a, we let t_a be the subtree used to classify points in s_a, and H_{t_a} be the number of leaves in that subtree. We can define the subset's contribution to the objective as:

    R_{\mathbf{w}(s_a)}(t_a, \tilde{\mathbf{x}}(s_a), \mathbf{y}(s_a)) = \frac{1}{\sum_{i=1}^{N} w_i} \sum_{i \in s_a} \mathbb{1}[y_i \neq \hat{y}^i_{t_a}] \times w_i + \lambda H_{t_a} .

For any dataset partition A, where a ∈ A corresponds to the data handled by a given subtree of t:

    R_{\mathbf{w}}(t, \tilde{\mathbf{x}}, \mathbf{y}) = \sum_{a \in A} R_{\mathbf{w}(s_a)}(t_a, \tilde{\mathbf{x}}(s_a), \mathbf{y}(s_a)) .

By introducing the above-mentioned lower bound guess, we can now replace the lower bound of McTavish et al. [13] with our lower bound and proceed with branch-and-bound. Their approach is provably close to optimal when the reference model makes errors similar to those made in the optimal tree. Our approach using the weighted lower bound is also close to optimal. Let s_{T,incorrect} be the set of observations incorrectly classified by the reference model T, i.e., s_{T,incorrect} = {i | y_i ≠ ŷ^i_T}, and let t_g be a tree returned from our lower-bound guessing algorithm.

Theorem 1 (Performance Guarantee). Let R(t_g, x̃, y) denote the objective of t_g on the full binarized dataset (x̃, y) for some per-leaf penalty λ. Then for any decision tree t that satisfies the same depth constraint d, we have:

    R(t_g, \tilde{\mathbf{x}}, \mathbf{y}) \leq \frac{1}{\sum_{i=1}^{N} w_i} \left( \sum_{i \in s_{T,\text{incorrect}}} w_i + \sum_{i \in s_{T,\text{correct}}} \mathbb{1}[y_i \neq \hat{y}^i_{t}] \times w_i \right) + \lambda H_t .

That is, the objective of the guessing model t_g is no worse than the union of errors of the reference model and tree t.

Hence, the model t_g achieves a weighted objective that is as good as the error of the reference model (which should be small) plus (something smaller than) the error of the best possible tree of the same depth. The proof appears in our supplementary material [36].

Motivation for Data Duplication. Surprisingly, increasing the dataset size by replicating data is substantially faster than using the direct approach. Decision tree optimization requires repeatedly evaluating the objective. Small improvements in that computation lead to a large improvement (possibly orders of magnitude) in execution time. In the direct approach, computing the objective (2) requires computing the inner product w · ℐ, where ℐ_i = 1[y_i ≠ ŷ^i_𝒯]. In the unweighted case, as all weights are 1, this computation can be performed using bitvectors, which is extremely fast. In the weighted case, we resort to standard inner products, which are two orders of magnitude slower (see Section 4). The data-duplication approach allows us to use bitvectors as in the unweighted case, preserving fast computation.
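The cost difference can be illustrated with a short sketch: with unit weights, the loss only needs a population count over a packed bitvector of misclassifications, while real-valued weights require a dense inner product. This is only an illustration of the trade-off, not GOSDTwG's internal implementation.

```python
import numpy as np

def unweighted_loss_popcount(miss_bits, n):
    """All weights equal 1: a population count over the packed 0/1
    misclassification vector (miss_bits = np.packbits(miss)) suffices.
    Optimized C implementations use hardware popcount; unpackbits+sum
    stands in for it here."""
    return np.unpackbits(miss_bits, count=n).sum() / n

def weighted_loss_dense(miss, w):
    """Real-valued weights force a dense inner product w . I instead,
    which is the slowdown described above."""
    return float(np.dot(w, miss)) / w.sum()
```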
Data-duplication Algorithm. The data-duplication algorithm is shown in Algorithm 1. We first normalize all weights and scale them to (0, 1]. Given an integer p > 0, we then multiply each normalized weight by p and round to integers. We then duplicate each sample x_i by its corresponding integer weight w̃_i. Once the data are duplicated, we can use any optimal decision tree technique. Our experiments show that if we choose the value of p appropriately, this method improves training runtime significantly without losing too much accuracy. After data-duplication, there are no weights associated with samples, and we can use the fast bit-vector computations from the unweighted case.

Algorithm 1: Data Duplication
    Input:  Dataset (x, y, w), duplication factor p < 100
    Output: Duplicated dataset X̃, ỹ
    1  X̃ ← ∅; ỹ ← ∅
    2  Define w̃_i = round(p · (w_i / Σ_{i=1}^N w_i))
    3  for x_i ∈ x do
    4      for j = 1, 2, . . . , w̃_i do
    5          X̃ ← X̃ ∪ {x_i}; ỹ ← ỹ ∪ {y_i}
    6  return X̃, ỹ
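A NumPy sketch of Algorithm 1, following the sum-normalization in line 2; p is the duplication factor and, in practice, must be large enough that no rounded count becomes zero (otherwise that sample disappears from the duplicated dataset).

```python
import numpy as np

def duplicate_dataset(X, y, w, p):
    """Round p * (w_i / sum(w)) to an integer count and replicate each
    sample that many times (Algorithm 1, sketched)."""
    counts = np.rint(p * w / w.sum()).astype(int)
    X_dup = np.repeat(X, counts, axis=0)
    y_dup = np.repeat(y, counts)
    return X_dup, y_dup
```

The resulting pair (X_dup, y_dup) carries no weights, so it can be handed to any unweighted optimal decision tree solver, as in the unweighted GOSDTwG.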
Correctness of Data Duplication. One might ask if the data duplication approach produces suboptimal solutions, because its loss function is an approximation to the weighted loss. If the weights do not change very much when rounding to integers, the minimum of the data duplication algorithm's objective is very close to the minimum of the original weighted objective. Recall

    R(t) := \frac{1}{\sum_{i=1}^{N} w_i} \sum_{i} w_i I_i(t) + \lambda \cdot \#\text{leaves} .

Define the objective with the approximate weights as

    \tilde{R}(t) := \frac{1}{\sum_{i=1}^{N} \tilde{w}_i} \sum_{i} \tilde{w}_i I_i(t) + \lambda \cdot \#\text{leaves} .

By design, the rounding phase rounds amplified weights, ensuring that the absolute change in weights remains small. That is, we know that ‖w − w̃‖_∞ ≤ ϵ. Note that multiplying the w_i by a scalar cannot change the value of the objective function. Accordingly, normalizing or scaling weights by p does not change the value of R(t). Therefore, without loss of generality, we can assume that the w_i are the weights right before rounding.

Theorem 2. Let t* be a minimizer of the objective, t* ∈ arg min_t R(t), and let 𝒯̃ be a minimizer of the approximate loss function, 𝒯̃ ∈ arg min_t R̃(t). If ‖w − w̃‖_∞ ≤ ϵ, then

    |R(t^*) - \tilde{R}(\tilde{\mathcal{T}})| \leq \max\left\{ \frac{(\zeta - 1)\psi + \epsilon}{\zeta}, \; \frac{(\eta - 1)\psi + \epsilon}{\eta} \right\},

where \eta = \max_{1 \leq i \leq N} \{ w_i / \tilde{w}_i \}, \zeta = \max_{1 \leq i \leq N} \{ \tilde{w}_i / w_i \}, and \psi = \max_i \{ w_i, \tilde{w}_i \} / \min_i \{ w_i, \tilde{w}_i \}.

In other words, the rounded solution provably will not lose substantial performance, as long as both the additive and multiplicative changes in weights due to rounding are small. The values of η and ζ are usually small and near 1 if the original weights do not have extreme imbalances. If the value of ψ is large, then the direct approach is more efficient, so we should duplicate data only when the value of ψ is small. The proof is in our supplementary material [36].
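For reference, the quantities in Theorem 2 can be computed directly from the original and rounded weights; the helper below is a sketch and assumes all weights (original and rounded) are strictly positive.

```python
import numpy as np

def rounding_bound(w, w_tilde):
    """eta, zeta, psi, and the resulting bound from Theorem 2,
    with eps = ||w - w_tilde||_inf."""
    eta = np.max(w / w_tilde)
    zeta = np.max(w_tilde / w)
    psi = np.max(np.maximum(w, w_tilde)) / np.min(np.minimum(w, w_tilde))
    eps = np.max(np.abs(w - w_tilde))
    bound = max(((zeta - 1) * psi + eps) / zeta,
                ((eta - 1) * psi + eps) / eta)
    return eta, zeta, psi, bound
```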
Weighted Sampling. When the ratio between the biggest and smallest weights is large, data duplication might be inefficient if it requires creating many samples. To address this issue, we present a stochastic sampling process based on weights. Given an arbitrary amplification number r, we sample S = r × N data points such that the probability of choosing x_i is w_i / Σ_{i=1}^N w_i. After this step, we can use any unweighted optimal decision tree algorithm on the sampled dataset.
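A sketch of weighted sampling, assuming points are drawn independently with replacement with probability proportional to their weights; r is the amplification factor from the text.

```python
import numpy as np

def weighted_sample(X, y, w, r, seed=0):
    """Draw S = r * N indices with probability w_i / sum(w) and return
    the corresponding unweighted dataset."""
    rng = np.random.default_rng(seed)
    n = len(y)
    s = int(r * n)
    idx = rng.choice(n, size=s, replace=True, p=w / w.sum())
    return X[idx], y[idx]
```

Larger values of r shrink the variance introduced by sampling (see Theorem 3 below) at the cost of a larger training set.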
Quality Guarantee of Weighted Sampling. Let ℒ̃(·) be the loss function on the sampled dataset; it is not hard to see that E[ℒ̃] = ℒ_w, where ℒ_w is the value of the misclassification loss (Eq. 1) on the weighted dataset. Based on this fact, we have the following theorem:

Theorem 3. Given a weighted dataset D = {(x_i, y_i, w_i)}_{i=1}^N, an arbitrary positive real number r > 0, an arbitrary positive real number ε > 0, and a tree 𝒯, if we sample S = r × N data points from D to obtain D̃ = {(x̃_i, ỹ_i)}_{i=1}^S, we have:

    P\left( \left| \tilde{\mathcal{L}}(\mathcal{T}, \tilde{\mathbf{x}}, \tilde{\mathbf{y}}) - \mathcal{L}_{\mathbf{w}}(\mathcal{T}, \mathbf{x}, \mathbf{y}) \right| \geq \varepsilon \right) \leq 2 \exp\left( -2 \varepsilon^2 S \right) .


4. Experiments

Our evaluation addresses the following questions: (1) When is the direct approach more efficient than data-duplication and weighted sampling? (2) In practice, how well do the second and third proposed methods perform relative to the direct approach? (3) How sparse and fast are our weighted models relative to state-of-the-art optimal decision trees? (4) How can our approach be used for policy making? We use sparsity as a proxy for interpretability, because it can be quantified, thus providing an objective means of comparison [17].

4.1. Datasets

Table 1: Dataset sizes.
    Dataset        samples    features    binary features
    Lalonde            723          7                447
    Broward           1954         38                588
    Coupon            2653         21                 87
    Diabetes          5000         34                532
    COMPAS            6907          7                134
    FICO             10459         23               1917
    Netherlands      20000          9              53890

We use seven publicly available real-world datasets; Table 1 shows the sizes of these datasets: the Lalonde dataset [37, 38], Broward [39], the coupon dataset, which was collected on Amazon Mechanical Turk via a survey [40], Diabetes [41], which is a health care related dataset, the Fair Isaac (FICO) credit risk dataset [42] from the Explainable ML Challenge, and the COMPAS [43] and Netherlands [44] datasets, which are recidivism datasets. Unless stated otherwise, we use the inverse propensity score with respect to one of the features as our weights.

We ran the experiments with different depth bounds and regularization; each point in each plot shows the results for one setting. A full description of the datasets and configurations appears in our supplement [36].
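The exact weight construction is given in the supplement [36]; the sketch below shows one common way to build inverse propensity weights with respect to a single binary feature, under the assumption that the propensity is estimated with a logistic regression on the remaining binarized features.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def inverse_propensity_weights(X_bin, col):
    """Fit e(x) = P(feature col = 1 | other features), then weight a sample
    by 1/e(x) if its feature is 1 and by 1/(1 - e(x)) otherwise."""
    z = X_bin[:, col]
    X_rest = np.delete(X_bin, col, axis=1)
    e = LogisticRegression(max_iter=1000).fit(X_rest, z).predict_proba(X_rest)[:, 1]
    e = np.clip(e, 1e-3, 1 - 1e-3)        # guard against extreme weights
    return np.where(z == 1, 1.0 / e, 1.0 / (1.0 - e))
```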
4.2. Baselines

We compared our methods with the following baseline models: (1) CART [5], (2) DL8.5 [45], and (3) Gradient Boosted Decision Trees (GBDT) [46, 47]. CART and GBDT can both handle weighted datasets, so we use their default weighted implementations as the baselines. As DL8.5 does not support weighted datasets, we use the data-duplication approach with it.

4.3. Results

Data duplication. We begin by demonstrating how much the direct approach penalizes runtime relative to data-duplication. We use the unweighted FICO dataset and randomly pick q% of the data points of the original dataset 𝒮. We assign each selected point weight = 2 by duplicating it, producing a dataset of size (1 + q/100) × N, where N = |𝒮| is the size of the original dataset. We then compare runtimes for this data-duplicated dataset and the original dataset in which we assign weight = 2 to the selected samples and weight = 1 to the remaining samples. We run this experiment on two machines, with different processors and amounts of memory, to show the consistency of the results across machines. The full machine descriptions appear in our supplementary material [36]. Figure 1 shows that when the size of the duplicated dataset is less than 100 times the original dataset, the data-duplication approach is always faster.

[Figure 1: Training time of the model with and without data duplication on different machines. Two panels (without data duplication; with data duplication), each plotting q (%) against running time (s) on log scales.]

Comparison of our approaches. We next compare the relative accuracy achieved using all three approaches. The star-shaped points in Figures 2 and 3 show the result of this comparison. These results suggest a trade-off between accuracy and running time. Weighted sampling is the fastest approach, but it has the worst accuracy, because it uses only subsets of the data. Data duplication, while slower than weighted sampling, is faster than the direct method, without losing much accuracy.

Sparsity vs. accuracy. The dotted line and round and diamond shapes in Figures 2 and 3(a) illustrate the accuracy-sparsity tradeoff for different decision tree models (the black line represents accuracy for GBDT). GOSDTwG produces excellent training and test accuracy with a small number of leaves, and, compared to other decision tree models, achieves higher accuracy for every level of sparsity. Results for the other datasets can be found in [36].

[Figure 2: Sparsity vs. training accuracy: All methods but CART and GBDT use guessed thresholds. DL8.5 frequently times out, so there are fewer markers for it. GOSDTwG achieves the highest accuracy for every level of sparsity. Panels plot training accuracy (%) against the number of leaves (log scale) for the Lalonde, Broward, COMPAS, Coupon, FICO, Netherlands, and Diabetes datasets.]
Training time vs. test accuracy. Figures 3(b) and 3(c) show the training time and accuracy for different methods. While the training times of GOSDTwG and CART are almost the same, GOSDTwG achieves the highest training and test accuracy in almost all cases. As DL8.5 timed out at one hour on all datasets except Lalonde, it did not reach optimality and was outperformed by both CART and GOSDTwG. Results for the other datasets can be found in our supplement [36].

[Figure 3: GOSDTwG achieves the highest test accuracy for almost every level of sparsity. While CART is the fastest algorithm, GOSDTwG uses its additional runtime to produce models with higher accuracy and that generalize better. Panels: (a) sparsity vs. test accuracy, (b) training time vs. training accuracy, and (c) training time vs. test accuracy, shown for the Lalonde, Broward, and Diabetes datasets (training time on a log scale).]

Lalonde Case Study. The Lalonde dataset is from the National Supported Work Demonstration [38, 37], a labour market experiment in which participants were randomized between treatment (9–12 month on-the-job training) and control groups. Each unit U_i has a pre-treatment covariate vector X_i and observed assigned treatment Z_i. Let Y_i^1 be the outcome if unit U_i received the treatment and Y_i^2 be the outcome if it was not treated. When a unit is treated, we do not observe the outcome had it not been treated, and vice versa. We use the MALTS model [48] to estimate these missing values by matching, producing an estimate of the conditional average treatment effect. We classify participants into three groups ("should be treated," "should be treated if budget allows," and "should not be treated") based on their conditional average treatment effect estimate. We then label the data points as 2, 1, and 0 if the estimated treatment effect is larger than 2000, between −5000 and 2000, or less than −5000, respectively. We define the penalty for each misclassification as (i) cost = 0 if correctly classified, (ii) cost = 200 + 3 × age if label = 0 and misclassified, (iii) cost = 100 + 3 × age if label = 1 and misclassified, and (iv) cost = 300 if label = 2 and misclassified. We linearly scale the above costs to the range from 1 to 100, and in the case of data-duplication, we round them to integers and treat them as weights of the dataset. Figure 4 shows the tree produced by GOSDTwG with a depth limit of 3; trees with other depth limits appear in [36].

[Figure 4: The tree generated by GOSDTwG (depth limit 3) on the Lalonde dataset. The root splits on education ≤ 11.5; internal nodes split on age, re75, and hispanic; each leaf predicts class 0, 1, or 2.]
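A sketch of the labeling and weighting scheme described above, assuming the conditional average treatment effect (CATE) estimates from MALTS and participant ages are available as arrays; the linear rescaling to [1, 100] is one straightforward reading of the description.

```python
import numpy as np

def lalonde_labels_and_weights(cate, age):
    """Label 2 if CATE > 2000, 1 if -5000 <= CATE <= 2000, 0 if CATE < -5000.
    Misclassification cost depends on the label (and age for labels 0 and 1);
    costs are then linearly rescaled to [1, 100] and used as sample weights."""
    labels = np.where(cate > 2000, 2, np.where(cate >= -5000, 1, 0))
    cost = np.select(
        [labels == 0, labels == 1, labels == 2],
        [200 + 3 * age, 100 + 3 * age, np.full_like(age, 300)],
    )
    lo, hi = cost.min(), cost.max()
    weights = 1 + 99 * (cost - lo) / (hi - lo)    # linear rescale to [1, 100]
    return labels, weights
```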
5. Conclusions

To find the optimal weighted decision tree, we first suggest directly optimizing a weighted loss function. To improve efficiency, we present the data-duplication approach, which rounds all weights to integers and then duplicates each sample by its weight. To further improve efficiency, we present a stochastic process in which we sample an unweighted dataset from our weighted dataset. Our results suggest a trade-off of accuracy and runtime among these approaches.


Acknowledgments

We acknowledge the following grant support: NIH/NIDA under grant number DA054994 and NSF under grant number IIS-2147061. This research was enabled in part by support provided by WestGrid (https://www.westgrid.ca) and The Digital Research Alliance (https://alliancecan.ca/en). We acknowledge the support of the Natural Sciences and Engineering Research Council of Canada (NSERC).
References

[1] D. Ernst, P. Geurts, L. Wehenkel, Tree-based batch mode reinforcement learning, Journal of Machine Learning Research 6 (2005) 503–556.
[2] A. Silva, M. Gombolay, T. Killian, I. Jimenez, S.-H. Son, Optimization methods for interpretable differentiable decision trees applied to reinforcement learning, in: International Conference on Artificial Intelligence and Statistics (AISTATS), 2020, pp. 1855–1865.
[3] Y. Dhebar, K. Deb, S. Nageshrao, L. Zhu, D. Filev, Interpretable-AI policies using evolutionary nonlinear decision trees for discrete action systems, arXiv e-print arXiv:2009.09521 (2020).
[4] J. R. Quinlan, C4.5: Programs for Machine Learning, Morgan Kaufmann, 1993.
[5] L. Breiman, J. Friedman, C. J. Stone, R. A. Olshen, Classification and Regression Trees, CRC Press, 1984.
[6] D. Dobkin, T. Fulton, D. Gunopulos, S. Kasif, S. Salzberg, Induction of shallow decision trees, IEEE Trans. on Pattern Analysis and Machine Intelligence (1997).
[7] A. Farhangfar, R. Greiner, M. Zinkevich, A fast way to produce near-optimal fixed-depth decision trees, in: Proceedings of the 10th International Symposium on Artificial Intelligence and Mathematics (ISAIM-2008), 2008.
[8] S. Nijssen, E. Fromont, Mining optimal decision trees from itemset lattices, in: 13th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2007, pp. 530–539.
[9] D. Bertsimas, J. Dunn, Optimal classification trees, Machine Learning 106 (2017) 1039–1082.
[10] O. Günlük, J. Kalagnanam, M. Li, M. Menickelly,
[14] J. Lin, C. Zhong, D. Hu, C. Rudin, M. Seltzer, Generalized and scalable optimal sparse decision trees, in: International Conference on Machine Learning (ICML), 2020, pp. 6150–6160.
[15] A. Linden, P. R. Yarnold, Estimating causal effects for survival (time-to-event) outcomes by combining classification tree analysis and propensity score weighting, Journal of Evaluation in Clinical Practice 24 (2018) 380–387.
[16] D. A. Cieslak, N. V. Chawla, Learning decision trees for unbalanced data, in: Joint European Conference on Machine Learning and Knowledge Discovery in Databases, Springer, 2008, pp. 241–256.
[17] C. Rudin, C. Chen, Z. Chen, H. Huang, L. Semenova, C. Zhong, Interpretable machine learning: Fundamental principles and 10 grand challenges, Statistics Surveys 16 (2022) 1–85.
[18] H. Laurent, R. L. Rivest, Constructing optimal binary decision trees is NP-complete, Information Processing Letters 5 (1976) 15–17.
[19] A. R. Klivans, R. A. Servedio, D. Ron, Toward attribute efficient learning of decision lists and parities, Journal of Machine Learning Research 7 (2006).
[20] N. Narodytska, A. Ignatiev, F. Pereira, J. Marques-Silva, I. RAS, Learning optimal decision trees with SAT, in: 27th International Joint Conference on Artificial Intelligence (IJCAI), 2018, pp. 1362–1368.
[21] H. Hu, M. Siala, E. Hébrard, M.-J. Huguet, Learning optimal decision trees with MaxSAT and its integration in AdaBoost, in: 29th International Joint Conference on Artificial Intelligence and the 17th Pacific Rim International Conference on Artificial Intelligence (IJCAI-PRICAI), 2020.
[22] S. Verwer, Y. Zhang, Learning optimal classification trees using a binary linear program formulation, in: AAAI Conference on Artificial Intelligence,
     K. Scheinberg, Optimal decision trees for categori-          volume 33, 2019, pp. 1625–1632.
     cal data via integer programming, Journal of Global     [23] C. Rudin, S. Ertekin, Learning customized and op-
     Optimization (2021) 1–28.                                    timized lists of rules with mathematical program-
[11] S. Aghaei, M. J. Azizi, P. Vayanos,            Learn-        ming, Mathematical Programming C (Computation)
     ing optimal and fair decision trees for non-                 10 (2018) 659–702.
     discriminative decision-making, Proceedings of          [24] M. G. Vilas Boas, H. G. Santos, L. H. d. C. Mer-
     the AAAI Conference on Artificial Intelligence 33            schmann, G. Vanden Berghe, Optimal decision trees
     (2019) 1418–1426. URL: https://ojs.aaai.org/index.           for the algorithm selection problem: integer pro-
     php/AAAI/article/view/3943. doi:10.1609/aaai.                gramming based approaches, International Transac-
     v33i01.33011418.                                             tions in Operational Research 28 (2021) 2759–2781.
[12] S. Aghaei, A. Gómez, P. Vayanos, Strong optimal         [25] X. Hu, C. Rudin, M. Seltzer, Optimal sparse de-
     classification trees, arXiv preprint arXiv:2103.15965        cision trees, in: Advances in Neural Information
     (2021).                                                      Processing Systems, 2019, pp. 7267–7275.
[13] H. McTavish, C. Zhong, R. Achermann, I. Karimalis,      [26] E. Angelino, N. Larus-Stone, D. Alabi, M. Seltzer,
     J. Chen, C. Rudin, M. Seltzer, Fast sparse decision          C. Rudin, Learning certifiably optimal rule lists for
     tree optimization via reference ensembles, in: AAAI          categorical data, Journal of Machine Learning Re-
     Conference on Artificial Intelligence, volume 36,            search 18 (2018) 1–78. URL: http://jmlr.org/papers/
     2022.                                                        v18/17-716.html.
[14] J. Lin, C. Zhong, D. Hu, C. Rudin, M. Seltzer, Gen-     [27] C. Chen, C. Rudin, An optimization approach to
     learning falling rule lists, in: International Confer-        measurement on hospital readmission rates: Anal-
     ence on Artificial Intelligence and Statistics (AIS-          ysis of 70,000 clinical database patient records,
     TATS), 2018.                                                  BioMed Research International 2014 (2014) 781670.
[28] E. Demirović, A. Lukina, E. Hebrard, J. Chan,                 URL: https://doi.org/10.1155/2014/781670. doi:10.
     J. Bailey, C. Leckie, K. Ramamohanarao, P. J.                 1155/2014/781670.
     Stuckey, Murtree: Optimal classification trees via       [42] FICO, Google, Imperial College London,
     dynamic programming and search, arXiv preprint                MIT, University of Oxford, UC Irvine, UC
     arXiv:2007.12652 (2020).                                      Berkeley,      Explainable Machine Learning
[29] H. Lakkaraju, C. Rudin, Learning cost-effective               Challenge,            https://community.fico.com/s/
     and interpretable treatment regimes, in: Artificial           explainable-machine-learning-challenge, 2018.
     intelligence and statistics, PMLR, 2017, pp. 166–175.    [43] J. Larson, S. Mattu, L. Kirchner, J. Angwin, How
[30] Y. Zhang, E. B. Laber, A. Tsiatis, M. Davidian, Using         we analyzed the COMPAS recidivism algorithm,
     decision lists to construct interpretable and parsi-          ProPublica (2016).
     monious treatment regimes, Biometrics 71 (2015)          [44] N. Tollenaar, P. Van der Heijden, Which method
     895–904.                                                      predicts recidivism best?: a comparison of statisti-
[31] F. Wang, C. Rudin, Causal falling rule lists, arXiv           cal, machine learning and data mining predictive
     preprint arXiv:1510.05189 (2015).                             models, Journal of the Royal Statistical Society:
[32] E. B. Laber, Y.-Q. Zhao, Tree-based methods for               Series A (Statistics in Society) 176 (2013) 565–584.
     individualized treatment regimes, Biometrika 102         [45] G. Aglin, S. Nijssen, P. Schaus, Learning opti-
     (2015) 501–514.                                               mal decision trees using caching branch-and-bound
[33] Y. Cui, R. Zhu, M. Kosorok, Tree based weighted               search, in: AAAI Conference on Artificial Intelli-
     learning for estimating individualized treatment              gence, volume 34, 2020, pp. 3146–3153.
     rules with censored data, Electronic Journal of          [46] Y. Freund, R. E. Schapire, A desicion-theoretic gen-
     Statistics 11 (2017) 3927–3953.                               eralization of on-line learning and an application to
[34] K. Doubleday, H. Zhou, H. Fu, J. Zhou, An algorithm           boosting, in: Conference on Computational Learn-
     for generating individualized treatment decision              ing Theory, Springer, 1995, pp. 23–37.
     trees and random forests, Journal of Computational       [47] J. H. Friedman, Greedy function approximation:
     and Graphical Statistics 27 (2018) 849–860.                   a gradient boosting machine, Annals of Statistics
[35] Y. Sun, L. Wang, Stochastic tree search for estimat-          (2001) 1189–1232.
     ing optimal dynamic treatment regimes, Journal           [48] H. Parikh, C. Rudin, A. Volfovsky, MALTS: Match-
     of the American Statistical Association 116 (2021)            ing after learning to stretch, arXiv preprint
     421–432.                                                      arXiv:1811.07415 (2018).
[36] A. Behrouz, M. Lecuyer, C. Rudin, M. Seltzer, Fast
     optimization of weighted sparse decision trees for
     use in optimal treatment regimes and optimal policy
     design, 2022. URL: https://arxiv.org/abs/2210.06825.
     doi:10.48550/ARXIV.2210.06825.
[37] R. Lalonde, Evaluating the econometric evaluations
     of training programs with experiment data, Ameri-
     can Economic Review 76 (1986) 604–20.
[38] R. H. Dehejia, S. Wahba, Causal effects in nonex-
     perimental studies: Reevaluating the evaluation of
     training programs, Journal of the American statis-
     tical Association 94 (1999) 1053–1062.
[39] C. Wang, B. Han, B. Patel, F. Mohideen, C. Rudin, In
     pursuit of interpretable, fair and accurate machine
     learning for criminal recidivism prediction, Journal
     of Quantitative Criminology (2022).
[40] T. Wang, C. Rudin, F. Doshi-Velez, Y. Liu, E. Klampfl,
     P. MacNeille, A bayesian framework for learning
     rule sets for interpretable classification, Journal of
     Machine Learning Research 18 (2017) 1–37. URL:
     http://jmlr.org/papers/v18/16-003.html.
[41] B. Strack, J. P. DeShazo, C. Gennings, J. L. Olmo,
     S. Ventura, K. J. Cios, J. N. Clore, Impact of hba1c