PseudoNAM: A Pseudo Value Based Interpretable Neural Additive Model for Survival Analysis

Md Mahmudur Rahman, Sanjay Purushotham
Department of Information Systems, University of Maryland Baltimore County, Baltimore, Maryland, USA
mrahman6@umbc.edu, psanjay@umbc.edu

Copyright © 2021 for this paper by its authors. Use permitted under Creative Commons License Attribution 4.0 International (CC BY 4.0).

Abstract

Deep learning models have achieved state-of-the-art performance in survival analysis because they can handle censoring while learning complex nonlinear hidden representations directly from raw data. However, the covariate effects on survival probabilities are difficult to explain with deep learning models. To address this challenge, we propose PseudoNAM, an interpretable model that uses pseudo values to efficiently handle censoring and neural additive networks to capture the nonlinearity in the covariates of survival data. In particular, PseudoNAM uses neural additive models to jointly learn a linear combination of neural networks, one per covariate, and identifies the effect of each individual covariate on the output; it is therefore inherently interpretable. We show that PseudoNAM outputs can be used in other survival models, such as random survival forests, to obtain improved survival prediction performance. Our experiments on three real-world survival analysis datasets demonstrate that our proposed models achieve similar or better performance (in terms of C-index and Brier scores) than state-of-the-art survival methods. We showcase that PseudoNAM provides overall feature importance scores and feature-level interpretations (covariate effect on survival risk) for survival predictions at different time points.

Introduction

Survival analysis (Kleinbaum and Klein 2010), a well-studied problem, aims to estimate the risk of a subject's failure from an event, such as death due to breast cancer, at a particular time point.
One key challenge in survival analysis is the presence of censored subjects for whom the actual survival times remain unknown. A good survival analysis model should handle censoring, accurately discriminate the predicted risks, and be interpretable. Traditional statistical survival analysis models, such as Cox Proportional Hazards models (Cox 1972) and regression models based on pseudo-observations (Andersen, Klein, and Rosthøj 2003; Andersen and Pohar Perme 2010), are interpretable but less accurate and are limited by strong assumptions on the underlying stochastic process, such as linearity, parametric form, and proportional hazards. Recent survival approaches based on machine learning and deep learning models (Rahman et al. 2021; Zhao and Feng 2020; Ishwaran et al. 2008; Katzman et al. 2018) achieve more accurate predictions but may require specialized objective functions to handle censoring (Lee et al. 2018). Moreover, many of these approaches are not interpretable or explainable, which makes them opaque and unsuitable for medical applications.

To address the limitations of existing methods, we propose a pseudo value based neural additive model, called PseudoNAM, which directly models the complex nonlinear, time-varying effect of the covariates on the survival function. PseudoNAM uses neural additive models (NAM) (Agarwal et al. 2020) to jointly learn a linear combination of neural networks, one per covariate, and determines the magnitude of each covariate's effect on the survival outcome; it is therefore inherently interpretable. The neural networks for the individual features in PseudoNAM are independent and can thus provide each feature's contribution to the output survival prediction. As in NAM, we sum the individual feature contributions (the neural network outputs) and then apply an inverse logit transformation (a sigmoid activation) to predict the survival probability at different time points. We show two types of interpretations from PseudoNAM: 1) the mean feature contributions to the survival probability predictions at different time points (overall feature importance scores), and 2) feature-level interpretations that show the time-varying covariate effect on the survival predictions.

Our experiments on three real-world datasets demonstrate that PseudoNAM performs similar to or better than state-of-the-art survival analysis models while providing interpretable results. We further improve the performance of PseudoNAM by proposing PseudoNRSF, a random survival forest approach that takes as input the learned outputs of the individual neural networks from PseudoNAM and predicts the survival probabilities. We show that PseudoNRSF achieves state-of-the-art results in terms of C-index and Brier scores.

Related Works

Cox-based statistical and deep learning survival models (Cox 1972; Faraggi and Simon 1995; Katzman et al. 2018; Kvamme and Borgan 2019) are widely studied for analyzing time-to-event data. However, these models make strong proportional hazards and linearity assumptions that may not hold for real data, leading to less accurate survival results. Machine learning models such as Random Survival Forests (Ishwaran et al. 2008) and multi-task logistic regression (MTLR) (Yu et al. 2011; Fotso 2018) relax some of these assumptions and outperform statistics-based methods. Recently proposed deep learning models (Lee et al. 2018; Nagpal, Li, and Dubrawski 2021) and conditional generative adversarial networks (Chapfuwa et al. 2018) achieve state-of-the-art results for time-to-event analysis. However, these methods require either making assumptions on the underlying stochastic process or designing a specialized objective function to handle censoring. Moreover, these methods lack the interpretability required in the medical domain. To address the censoring challenge, Zhao and Feng (2020) and Rahman et al. (2021) have respectively proposed pseudo value based deep learning models for survival and competing risk analysis. However, even these methods are not directly interpretable and rely on off-the-shelf explainable AI methods, such as LRP (Montavon et al. 2019), to provide explanations.
Our Proposed Models

To address the censoring and interpretability challenges of existing survival analysis approaches, we propose two interpretable pseudo value based deep learning models, PseudoNAM and PseudoNRSF. Before describing our models in detail, we briefly introduce pseudo values.

What are Pseudo values? Pseudo values for the survival probability are derived from the non-parametric, population-based Kaplan-Meier (KM) estimator, an approximately unbiased estimator of the survival probability under independent censoring (Andersen et al. 2012). For the i-th subject, a jackknife pseudo value based on the KM estimate of the survival probability (Klein et al. 2008) is computed at time horizon t* as

    Ŝ_i(t*) = n Ŝ(t*) − (n − 1) Ŝ^(−i)(t*)

where Ŝ(t*) is the Kaplan-Meier estimate of the survival probability at time t* based on the full sample of n subjects, and Ŝ^(−i)(t*) is the Kaplan-Meier estimate of the survival probability at time t* based on the leave-one-out sample of (n − 1) subjects obtained by omitting the i-th subject. Pseudo values are calculated at the specified time points for both uncensored and censored (incompletely observed) subjects.
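As a concrete illustration of this computation, the minimal NumPy sketch below estimates the Kaplan-Meier curve and the leave-one-out jackknife pseudo values at a single time horizon. It is written for exposition only; the paper itself obtains the (ground-truth) pseudo values with the jackknife function of the R package prodlim (see the Implementation Details section), and the toy data here are placeholders.

```python
import numpy as np

def km_survival(time, event, t_star):
    """Kaplan-Meier estimate of S(t_star) from durations and event indicators
    (event = 1 for an observed event, 0 for censoring)."""
    surv = 1.0
    for t in np.unique(time[event == 1]):        # unique event times, in increasing order
        if t > t_star:
            break
        at_risk = np.sum(time >= t)              # subjects still at risk just before t
        deaths = np.sum((time == t) & (event == 1))
        surv *= 1.0 - deaths / at_risk
    return surv

def jackknife_pseudo_values(time, event, t_star):
    """Pseudo value S_i(t*) = n*S(t*) - (n-1)*S^(-i)(t*) for every subject i."""
    n = len(time)
    s_full = km_survival(time, event, t_star)
    pseudo = np.empty(n)
    for i in range(n):
        mask = np.arange(n) != i                 # leave subject i out
        s_loo = km_survival(time[mask], event[mask], t_star)
        pseudo[i] = n * s_full - (n - 1) * s_loo
    return pseudo

# Toy example: six subjects, two of them censored.
time = np.array([2.0, 3.0, 3.0, 5.0, 7.0, 8.0])
event = np.array([1, 1, 0, 1, 0, 1])
print(jackknife_pseudo_values(time, event, t_star=5.0))
```

Computed at each of the M evaluation time points (separately per fold), these pseudo values become the regression targets that PseudoNAM is trained on.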
PseudoNAM: Inspired by the success of the pseudo value based deep models DNNSurv (Zhao and Feng 2020) and DeepPseudo (Rahman et al. 2021) in handling censoring, we propose PseudoNAM, a multi-output neural additive model that predicts pseudo values for survival risk analysis. PseudoNAM, shown in Figure 1, learns nonlinear representations of the data and uses pseudo values to handle censoring efficiently. PseudoNAM has the following form:

    g(E[y(t|X)]) = β + f_1(X_1) + f_2(X_2) + ... + f_p(X_p)    (1)

Here, X_i = (X_i1, X_i2, ..., X_ip) is the p-dimensional covariate vector of the i-th individual, i = 1, 2, ..., n; g(·) is the link function (e.g., the logit link function), β is the bias, and each f_j(·) is parametrized by a neural network. y(t|X) denotes the pseudo value of the survival probability at time t given the covariates. Each network learns the complex shape function of a specific covariate, and all the networks are trained jointly. An n × p matrix of p baseline covariates for n individuals is used as input, and the output layer returns the survival probabilities at M evaluation time points. PseudoNAM provides interpretability because it jointly trains one neural network per covariate and returns each covariate's contribution scores to the output (i.e., the outputs of that covariate's neural network) at the M evaluation time points. We then sum the contribution scores of all covariates and apply a sigmoid activation function to obtain the final output, i.e., the survival probabilities at the M time points. The non-overlapping neural networks for the individual covariates make it possible to identify each covariate's effect on the survival probabilities.

Figure 1: Architecture of PseudoNAM. X = {X_1, X_2, ..., X_p} is a p-dimensional vector of covariates, f_i(X_i) is the neural network corresponding to covariate X_i, and β is the bias. The sigmoid (σ) acts as the inverse logit link function. y is the output, i.e., the survival probabilities at the M time points. FC Layer denotes a fully connected layer.
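A minimal PyTorch sketch of the architecture in Figure 1 and Eq. (1) is shown below. It is not the authors' released code; the layer sizes and activations follow the Implementation Details section later in the paper ([128, 64, 32] hidden units with ReLU, tanh in each feature network's output layer, and a sigmoid over the summed contributions plus bias), and the training snippet uses a plain MSE against the pseudo values where the paper uses a penalized MSE.

```python
import torch
import torch.nn as nn

class FeatureNet(nn.Module):
    """One sub-network per covariate: a scalar input -> M contribution scores."""
    def __init__(self, n_times, hidden=(128, 64, 32)):
        super().__init__()
        layers, d_in = [], 1
        for d_out in hidden:
            layers += [nn.Linear(d_in, d_out), nn.ReLU()]
            d_in = d_out
        layers += [nn.Linear(d_in, n_times), nn.Tanh()]   # tanh output layer per feature net
        self.net = nn.Sequential(*layers)

    def forward(self, x):                                  # x: (batch, 1)
        return self.net(x)                                 # (batch, M)

class PseudoNAM(nn.Module):
    """Sum of per-feature contributions plus bias, squashed by a sigmoid (inverse logit)."""
    def __init__(self, n_features, n_times):
        super().__init__()
        self.feature_nets = nn.ModuleList([FeatureNet(n_times) for _ in range(n_features)])
        self.bias = nn.Parameter(torch.zeros(n_times))

    def contributions(self, x):                            # x: (batch, p)
        per_feature = [net(x[:, j:j + 1]) for j, net in enumerate(self.feature_nets)]
        return torch.stack(per_feature, dim=1)             # (batch, p, M)

    def forward(self, x):
        logits = self.contributions(x).sum(dim=1) + self.bias
        return torch.sigmoid(logits)                       # survival probabilities at M time points

# Training sketch: regress the predicted probabilities onto the jackknife pseudo values.
model = PseudoNAM(n_features=9, n_times=6)                 # e.g., METABRIC with 6 percentile points
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)
x_batch = torch.randn(128, 9)                              # a batch of covariates (placeholder)
pseudo_batch = torch.rand(128, 6)                          # pseudo-value targets (placeholder)
optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x_batch), pseudo_batch)
loss.backward()
optimizer.step()
```

The per-feature `contributions` are what the interpretation analyses later in the paper rely on: averaging them over subjects gives the overall feature importance scores, and plotting them against the feature values gives the feature-level shape functions.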
PseudoNRSF: While PseudoNAM provides interpretable predictions, its performance is limited by the NAM model architecture. To improve the performance of PseudoNAM and to obtain global and local interpretations like Random Survival Forests (RSF) (Ishwaran et al. 2008), we propose PseudoNRSF, a two-stage deep learning model. In the first stage, a PseudoNAM model is used to learn the individual feature contribution scores for predicting pseudo values. In the second stage, these learned feature contribution scores are fed to an RSF that directly predicts the survival probabilities. PseudoNRSF thus returns the subject-specific survival probabilities at different time points as output, and the predictions are interpretable due to the use of the RSF. PseudoNRSF inherits the interpretation property of the RSF: we can easily obtain the effect (importance score) of each covariate on the overall survival probability. However, using PseudoNAM, we can see how the covariate effect on the survival probabilities changes across time points.
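A sketch of the second stage is given below, assuming the PseudoNAM class from the previous sketch and the scikit-survival library. The paper does not specify which RSF implementation it uses or how the contribution scores are arranged as RSF inputs, so the flattening of the (n, p, M) contribution tensor, the hyperparameters, and the variable names (x_train, time_train, event_train) are illustrative assumptions.

```python
import numpy as np
import torch
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv

# Stage 1: per-feature contribution scores from a trained PseudoNAM (placeholder arrays).
with torch.no_grad():
    contrib = model.contributions(torch.as_tensor(x_train, dtype=torch.float32))
features = contrib.reshape(contrib.shape[0], -1).numpy()     # flatten (n, p, M) -> (n, p*M)

# Stage 2: an RSF trained on the learned contribution scores predicts survival directly.
y_train = Surv.from_arrays(event=event_train.astype(bool), time=time_train)
rsf = RandomSurvivalForest(n_estimators=200, min_samples_leaf=10, random_state=0)
rsf.fit(features, y_train)

# Subject-specific survival curves, read off at the evaluation time points.
surv_fns = rsf.predict_survival_function(features[:5])
eval_times = np.percentile(time_train, [10, 20, 30, 40, 50, 60])   # illustrative time grid
surv_probs = np.array([[fn(t) for t in eval_times] for fn in surv_fns])
print(surv_probs)
```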
Experiments

We conducted experiments on three real-world datasets to answer the following questions: a) how well do our proposed models perform compared to the state-of-the-art survival models? and b) how well can our PseudoNAM explain its predictions?

Datasets: Table 1 shows statistics of the following datasets.
METABRIC: This dataset (Katzman et al. 2018)¹ contains patients' gene expressions and clinical variables for breast cancer survival prediction.
SUPPORT: This dataset (Knaus et al. 1995) is from the Vanderbilt University study to estimate the survival of 9,105 seriously ill hospitalized patients.
WHAS: This dataset was collected to examine the effects of a patient's factors on acute myocardial infarction (MI) survival (Hosmer and Lemeshow 2002).

¹ https://github.com/jaredleekatzman/DeepSurv

Table 1: Descriptive statistics of the three real-world survival datasets

Dataset     Observations   Uncensored (%)   Censored (%)   Features   Event time (min / max / mean / median)   Censoring time (min / max / mean / median)
METABRIC    1904           1103 (57.9)      801 (42.1)     9          0.1 / 355.2 / 100.0 / 85.9               0.1 / 337.0 / 159.6 / 158.0
SUPPORT     8873           6036 (68.0)      2837 (32.0)    14         3.0 / 1944.0 / 205.5 / 57.0              344.0 / 2029.0 / 1059.9 / 918.0
WHAS        1638           690 (42.1)       948 (57.9)     5          0.1 / 1965.9 / 696.7 / 515.5             371.0 / 1999.0 / 1298.9 / 1347.5

Implementation Details: The (ground-truth) pseudo values for the survival probabilities are obtained using the jackknife function of the R package prodlim at each evaluation time point (separately for the training and validation sets). We performed stratified 5-fold cross-validation so that the ratio of censored and uncensored subjects remained the same in each fold. We jointly train PseudoNAM's feature networks with an early stopping criterion and choose the best model based on its performance on the validation data. Each feature network consists of 3 hidden layers with [128, 64, 32] units. We used the ReLU activation function in the hidden layers of each covariate's neural network and the tanh activation function in the output layer of the feature networks. In the final output layer, we sum the outputs of the individual feature networks and apply the sigmoid activation function to obtain the survival probability at the 10th, 20th, 30th, 40th, 50th, and 60th percentiles of the maximum survival time of the training data. We did not perform hyperparameter tuning. We set the learning rate to 0.0001, the output penalty coefficient to 0.001, the weight decay coefficient to 0.000001, the dropout rate to 0.0, and the feature dropout rate to 0. We used the Adam optimizer with batch size 128 during training and a penalized mean squared error loss function to train our models. For the performance metrics, we used (a) the time-dependent concordance index (Antolini, Boracchi, and Biganzoli 2005), adjusted with an inverse propensity of censoring estimate, for evaluating discriminative ability, and (b) the integrated IPCW Brier score (denoted as Brier Score) (Gerds and Schumacher 2006) for evaluating predictive ability. To encourage reproducibility, the source code for our proposed models is available at https://github.com/umbc-sanjaylab/PseudoNAM_SA
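For readers who want to compute comparable numbers, the sketch below evaluates a matrix of predicted survival probabilities with scikit-survival. Note the caveats: concordance_index_ipcw implements Uno's IPCW concordance estimator rather than the Antolini time-dependent C-index used in the paper, and surv_probs, the fold arrays, and the risk-score choice are placeholder assumptions rather than the authors' evaluation pipeline.

```python
import numpy as np
from sksurv.metrics import concordance_index_ipcw, integrated_brier_score
from sksurv.util import Surv

# Structured (event, time) arrays for the training and test folds (placeholders).
y_train = Surv.from_arrays(event=event_train.astype(bool), time=time_train)
y_test = Surv.from_arrays(event=event_test.astype(bool), time=time_test)

# surv_probs: predicted survival probabilities, shape (n_test, len(eval_times)),
# e.g. the PseudoNAM outputs at the six percentile time points.
eval_times = np.percentile(time_train, [10, 20, 30, 40, 50, 60])

# (a) IPCW concordance index, using 1 - S(t) at the median evaluation time as a risk score.
risk = 1.0 - surv_probs[:, len(eval_times) // 2]
cindex, *_ = concordance_index_ipcw(y_train, y_test, risk, tau=eval_times[-1])

# (b) Integrated IPCW Brier score over the evaluation grid.
ibs = integrated_brier_score(y_train, y_test, surv_probs, eval_times)
print(f"C-index: {cindex:.3f}, Integrated Brier Score: {ibs:.3f}")
```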
Model Comparisons: We compared the following survival analysis models:
• Statistical models: Cox Proportional Hazards Model [CoxPH] (Cox 1972).
• Machine learning models: Random Survival Forest [RSF] (Ishwaran et al. 2008) and Multi-task Logistic Regression [MTLR] (Yu et al. 2011).
• Deep learning models: DNNSurv (Zhao and Feng 2020), DeepHit (Lee et al. 2018), DeepSurv (Katzman et al. 2018), CoxTime (Kvamme, Borgan, and Scheel 2019), Deep Survival Machines [DSM] (Nagpal, Li, and Dubrawski 2021), and Piecewise Constant Hazard [PCHazard] (Kvamme and Borgan 2019), along with our proposed models PseudoNAM and PseudoNRSF.

Results and Discussion

Table 2 shows the performance comparison of the survival models based on the time-dependent concordance index and Brier scores. From this table, we see that PseudoNAM obtains similar or comparable performance to the other survival analysis models, while PseudoNRSF outperforms all the survival models on the WHAS dataset and obtains performance similar to the state-of-the-art models on the other two datasets. We notice that the independent neural networks for the individual covariates limit PseudoNAM's ability to learn shared effects on the survival probability at different times, resulting in comparable but not the best results.

Table 2: Model comparisons of the performance metrics (mean and 95% confidence interval) evaluated on the survival datasets

Time-dependent Concordance Index
Dataset     PseudoNRSF   PseudoNAM    DNNSurv      CoxPH        CoxTime      DeepHit      DeepSurv     DSM          MTLR         PCHazard     RSF
METABRIC    0.645±0.038  0.616±0.025  0.617±0.014  0.622±0.013  0.660±0.055  0.655±0.045  0.641±0.017  0.616±0.040  0.550±0.043  0.614±0.041  0.616±0.058
SUPPORT     0.619±0.019  0.613±0.017  0.581±0.009  0.568±0.016  0.616±0.012  0.593±0.012  0.589±0.009  0.595±0.005  0.550±0.024  0.589±0.022  0.638±0.010
WHAS        0.865±0.038  0.740±0.022  0.721±0.018  0.739±0.013  0.783±0.027  0.851±0.038  0.787±0.030  0.739±0.013  0.618±0.104  0.685±0.038  0.768±0.041

Brier Score
Dataset     PseudoNRSF   PseudoNAM    DNNSurv      CoxPH        CoxTime      DeepHit      DeepSurv     DSM          MTLR         PCHazard     RSF
METABRIC    0.171±0.005  0.245±0.013  0.243±0.010  0.313±0.020  0.168±0.011  0.178±0.022  0.165±0.013  0.249±0.020  0.225±0.024  0.201±0.014  0.296±0.016
SUPPORT     0.196±0.010  0.207±0.007  0.221±0.002  0.206±0.005  0.192±0.008  0.211±0.008  0.198±0.006  0.212±0.003  0.263±0.019  0.225±0.018  0.190±0.007
WHAS        0.099±0.010  0.267±0.022  0.290±0.029  0.234±0.029  0.136±0.018  0.140±0.054  0.132±0.018  0.201±0.005  0.162±0.028  0.141±0.009  0.206±0.033

Model Interpretations

The main advantage of our PseudoNAM models is that they can provide interpretations. Here, we discuss the two ways of interpreting the PseudoNAM model predictions: overall feature importance scores and feature-level interpretations.

PseudoNAM first learns the individual feature contributions at pre-specified time points (here, we choose the 10th, 20th, ..., 60th percentiles of the event horizon). We then sum these feature contributions and apply the sigmoid transformation to obtain the final survival probabilities at the pre-specified time points. Figure 2 shows the overall feature importance scores, measured as the mean individual feature contributions to the survival probability at the pre-specified time points, for the METABRIC dataset. We see that the features can have a positive or a negative impact (overall effect) on the survival probability predictions at different time points for breast cancer patients. For example, covariates such as MKI67, radiotherapy, and chemotherapy have positive feature contributions at the initial time points (10th percentile), which means that they are associated with better survival outcomes (higher survival probabilities). However, at later time points (such as the 60th percentile), these features have negative feature contributions, meaning they are associated with mortality. This is expected because the survival probability remains high at the initial time points and decreases over time. Therefore, a treatment like chemotherapy fails to reduce the risk of death at later time points, and older people (age at diagnosis) are at greater mortality risk.

Figure 2: Mean individual feature contributions to the survival probabilities at different time points on the METABRIC dataset. Here, 10th Perc., 20th Perc., etc. denote the 10th, 20th, ... percentiles of the survival time distribution at which we obtain the PseudoNAM model predictions.

Figure 3 shows the permutation feature importance, which is measured by observing how randomly re-shuffling each covariate influences the model performance. We use the eli5² library to compute the feature importance for the PseudoNRSF model. We observe that age at diagnosis has the highest importance for the survival probability predictions and ER-positive has the lowest feature importance.

² https://github.com/eli5-org/eli5

Figure 3: Importance of each feature (mean weight ± sd) for the survival probability predictions measured by PseudoNRSF on the METABRIC dataset.

Figure 4 shows the individual feature contributions (i.e., the outputs of the individual neural networks of PseudoNAM) to the survival probability at different time points (i.e., the time-varying covariate effect on the survival predictions) for the METABRIC dataset. The x-axis shows the feature values and the y-axis shows their contributions; in other words, this plot provides feature-level interpretations. For example, the contribution of the feature age at diagnosis to the survival probability at all time points starts decreasing after 65 years, and we see that the feature chemotherapy is biased towards the patients who did not receive chemotherapy, since the data density is much higher for this group (darker brown bar). The plot also shows that the model predicted a decrease in survival probability for a few patients who received chemotherapy, especially at later time points.

Figure 4: Feature-level contributions to the survival probability at different time points (10th, 20th, 30th, 40th, 50th, and 60th percentiles of the survival time distribution) for individual features in the METABRIC dataset. Darker brown bars indicate a higher density of the data. Age at diagnosis, EGFR, ERBB2, MKI67, and PGR are continuous-valued features, while the others are discrete-valued.
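Both interpretation views can be read directly off the per-feature outputs of the PseudoNAM sketch shown earlier. The snippet below is illustrative only (x_test, feature_names, and the trained model are assumed to exist, and the feature name is a placeholder): it averages the contributions over subjects for a Figure-2-style summary and plots one feature's contribution against its observed values for a Figure-4-style panel.

```python
import numpy as np
import torch
import matplotlib.pyplot as plt

with torch.no_grad():
    contrib = model.contributions(torch.as_tensor(x_test, dtype=torch.float32)).numpy()
# contrib has shape (n_subjects, p_features, M_time_points)

# Figure-2-style overall importance: mean contribution per feature and time point.
mean_contrib = contrib.mean(axis=0)                       # (p, M)
for j, name in enumerate(feature_names):
    print(name, np.round(mean_contrib[j], 3))

# Figure-4-style shape function: one feature's contribution vs. its values at one time point.
j, m = feature_names.index("age at diagnosis"), 0         # feature index, 10th-percentile column
order = np.argsort(x_test[:, j])
plt.plot(x_test[order, j], contrib[order, j, m])
plt.xlabel("age at diagnosis")
plt.ylabel("feature contribution")
plt.show()
```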
Why is PseudoNAM suited for the healthcare domain?

As shown in Table 2, our PseudoNAM models obtain good predictive and discriminative performance on all the survival datasets. Moreover, using our models, one can visualize each covariate's contribution to the survival probability. Therefore, PseudoNAM helps to identify potential risk factors for an event, such as death due to breast cancer. The visualization of the feature-level interpretations can be a step towards transparency in deep learning models, which can inform clinical decision-making and perhaps lead to trust in the model. Thus, PseudoNAM models, with their high predictive ability and inherent interpretability, could be well suited for survival analysis in the healthcare domain.

Conclusion

In this paper, we proposed the interpretable pseudo value based deep learning approaches PseudoNAM and PseudoNRSF to model the nonlinear, time-varying covariate effect on survival predictions. Our proposed models use 1) pseudo values to handle censoring and 2) neural additive networks to capture complex nonlinear relationships and to obtain interpretable predictions. Empirical results show that our proposed models achieve similar or better performance than state-of-the-art survival methods. Our PseudoNAM model provides both overall feature importance scores and feature-level interpretations of the predicted survival probabilities at different time points. For future work, we will study and compare the interpretability of our proposed models with other parametric survival approaches.

Acknowledgements

This work is partially supported by grant IIS-1948399 from the US National Science Foundation and grant 80NSSC21M0027 from the National Aeronautics and Space Administration.

References

Agarwal, R.; Frosst, N.; Zhang, X.; Caruana, R.; and Hinton, G. E. 2020. Neural additive models: Interpretable machine learning with neural nets. arXiv preprint arXiv:2004.13912.

Andersen, P. K.; Borgan, O.; Gill, R. D.; and Keiding, N. 2012. Statistical models based on counting processes. Springer Science & Business Media.

Andersen, P. K.; Klein, J. P.; and Rosthøj, S. 2003. Generalised linear models for correlated pseudo-observations, with applications to multi-state models. Biometrika 90(1).

Andersen, P. K.; and Pohar Perme, M. 2010. Pseudo-observations in survival analysis. Statistical Methods in Medical Research 19(1): 71–99.

Antolini, L.; Boracchi, P.; and Biganzoli, E. 2005. A time-dependent discrimination index for survival data. Statistics in Medicine 24(24): 3927–3944.

Chapfuwa, P.; Tao, C.; Li, C.; Page, C.; Goldstein, B.; Duke, L. C.; and Henao, R. 2018. Adversarial time-to-event modeling. In International Conference on Machine Learning, 735–744. PMLR.

Cox, D. R. 1972. Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).

Faraggi, D.; and Simon, R. 1995. A neural network model for survival data. Statistics in Medicine 14(1): 73–82.

Fotso, S. 2018. Deep neural networks for survival analysis based on a multi-task framework. arXiv preprint arXiv:1801.05512.

Gerds, T. A.; and Schumacher, M. 2006. Consistent estimation of the expected Brier score in general survival models with right-censored event times. Biometrical Journal 48(6).

Hosmer, D. W.; and Lemeshow, S. 2002. Applied survival analysis: regression modelling of time to event data. Wiley.

Ishwaran, H.; Kogalur, U. B.; Blackstone, E. H.; Lauer, M. S.; et al. 2008. Random survival forests. The Annals of Applied Statistics 2(3): 841–860.

Katzman, J. L.; Shaham, U.; Cloninger, A.; Bates, J.; Jiang, T.; and Kluger, Y. 2018. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology 18(1).

Klein, J. P.; Gerster, M.; Andersen, P. K.; Tarima, S.; and Perme, M. P. 2008. SAS and R functions to compute pseudo-values for censored data regression. Computer Methods and Programs in Biomedicine.

Kleinbaum, D. G.; and Klein, M. 2010. Survival analysis. Springer.

Knaus, W. A.; Harrell, F. E.; Lynn, J.; Goldman, L.; Phillips, R. S.; Connors, A. F.; Dawson, N. V.; Fulkerson, W. J.; Califf, R. M.; Desbiens, N.; et al. 1995. The SUPPORT prognostic model: Objective estimates of survival for seriously ill hospitalized adults. Annals of Internal Medicine.

Kvamme, H.; and Borgan, Ø. 2019. Continuous and discrete-time survival prediction with neural networks. arXiv preprint arXiv:1910.06724.

Kvamme, H.; Borgan, Ø.; and Scheel, I. 2019. Time-to-event prediction with neural networks and Cox regression. arXiv preprint arXiv:1907.00825.

Lee, C.; Zame, W.; Yoon, J.; and van der Schaar, M. 2018. DeepHit: A deep learning approach to survival analysis with competing risks. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 32.

Montavon, G.; Binder, A.; Lapuschkin, S.; Samek, W.; and Müller, K.-R. 2019. Layer-wise relevance propagation: an overview. In Explainable AI: Interpreting, Explaining and Visualizing Deep Learning. Springer.

Nagpal, C.; Li, X. R.; and Dubrawski, A. 2021. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. IEEE Journal of Biomedical and Health Informatics.

Rahman, M. M.; Matsuo, K.; Matsuzaki, S.; and Purushotham, S. 2021. DeepPseudo: Pseudo Value Based Deep Learning Models for Competing Risk Analysis. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 479–487.

Yu, C.-N.; Greiner, R.; Lin, H.-C.; and Baracos, V. 2011. Learning patient-specific cancer survival distributions as a sequence of dependent regressors. In Advances in Neural Information Processing Systems 24, 1845–1853.

Zhao, L.; and Feng, D. 2020. Deep neural networks for survival analysis using pseudo values. IEEE Journal of Biomedical and Health Informatics 24(11): 3308–3314.