How to interpret the output of .predict() from a fitted scikit-survival model in Python?

I'm confused about how to interpret the output of .predict from a fitted CoxnetSurvivalAnalysis model in scikit-survival. I've read through the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Below is a minimal example of what leads to my confusion:

import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis

# load data
data_X, data_y = load_veterans_lung_cancer()

# one-hot-encode categorical columns in X
categorical_cols = ['Celltype', 'Prior_therapy', 'Treatment']

X = data_X.copy()
for c in categorical_cols:
    dummy_matrix = pd.get_dummies(X[c], prefix=c, drop_first=False)
    X = pd.concat([X, dummy_matrix], axis=1).drop(c, axis=1)

# display final X to fit Cox Elastic Net model on
del data_X

So here's the X going into the model:

   Age_in_years  Celltype  Karnofsky_score  Months_from_Diagnosis  \
0          69.0  squamous             60.0                    7.0   
1          64.0  squamous             70.0                    5.0   
2          38.0  squamous             60.0                    3.0   

  Prior_therapy Treatment  
0            no  standard  
1           yes  standard  
2            no  standard  

...moving on to fitting model and generating predictions:

# Fit Model
coxnet = CoxnetSurvivalAnalysis()
coxnet.fit(X, data_y)    

# What are these predictions?    
preds = coxnet.predict(X)

preds has the same number of records as X, but its values are very different from the values in data_y, even when predicted on the same data the model was fit on.

So what exactly are preds? Clearly .predict means something pretty different here than in scikit-learn, but I can't figure out what. The API Reference says it returns "The predicted decision function," but what does that mean? And how do I get to the predicted estimate in months yhat for a given X? I'm new to survival analysis so I'm obviously missing something.

Max Power asked Nov 13 '17 22:11

1 Answer

I posted this question on GitHub, though the author renamed the issue question.

I got some helpful explanation of what the predict output is, but I'm still not sure how to get to a set of predicted survival times, which is what I really want. Here are a couple of helpful explanations from that GitHub thread:

predictions are risk scores on an arbitrary scale, which means you can 
usually only determine the sequence of events, but not their exact time.

-sebp (library author)

It [predict] returns a type of risk score. Higher value means higher
risk of your event (class value = True)...You were probably looking
for a predicted time. You can get the predicted survival function with
estimator.predict_survival_function as in the example 00
notebook...EDIT: Actually, I’m trying to extract this but it’s been a
bit of a pain to munge
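
To make "you can usually only determine the sequence of events" a bit more concrete, here is a small sketch in plain Python. The numbers and the helper `concordance` are made up for illustration; they are not from the veterans dataset and not part of scikit-survival. The idea is that a risk score is judged by whether it ranks subjects correctly (higher score, earlier event), not by how close it is to the observed time.

```python
from itertools import combinations

def concordance(times, events, risk_scores):
    """Fraction of comparable pairs that the risk score orders correctly.

    A pair is comparable when the subject with the shorter time had an
    observed event (was not censored). Pairs with tied times are skipped
    for simplicity; tied risk scores count as half.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(times)), 2):
        if times[i] == times[j]:
            continue
        first, second = (i, j) if times[i] < times[j] else (j, i)
        if not events[first]:
            continue  # shorter time is censored, so the true order is unknown
        comparable += 1
        if risk_scores[first] > risk_scores[second]:
            concordant += 1.0
        elif risk_scores[first] == risk_scores[second]:
            concordant += 0.5
    return concordant / comparable

# Hypothetical data: observed time, event flag, and a model's risk score
times = [2.0, 5.0, 9.0, 14.0]
events = [True, True, False, True]
risk_scores = [1.7, 0.4, 0.9, -1.2]

c_index = concordance(times, events, risk_scores)
print(c_index)  # 0.8
```

The scores 1.7, 0.4, 0.9, -1.2 look nothing like the times, yet they rank four of the five comparable pairs correctly. scikit-survival has a fuller version of this metric in `sksurv.metrics.concordance_index_censored`, which is probably what you'd use in practice instead of the toy helper above.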


There's more explanation in the GitHub thread, though I wasn't really able to follow all of it. I need to play around with predict_survival_function and predict_cumulative_hazard_function and see if I can get to a set of predictions for the most likely survival time for each row in X, which is what I really want.
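
One common way to reduce a predicted survival function to a single time per row is the median survival time: the first time point at which the estimated survival probability drops to 0.5 or below. Below is a minimal sketch of that step in plain Python. The helper `median_survival_time` and the curve values are hypothetical. On the scikit-survival side, my understanding is that the model has to be created with `CoxnetSurvivalAnalysis(fit_baseline_model=True)` before `predict_survival_function` can be called, and that in recent versions each returned step function exposes its time points and probabilities as `.x` and `.y`; treat both as assumptions to check against your installed version.

```python
def median_survival_time(times, survival_probs):
    """Return the first time at which S(t) <= 0.5, or None if never reached.

    `times` must be sorted ascending, and `survival_probs` must hold the
    non-increasing survival probabilities at those time points.
    """
    for t, s in zip(times, survival_probs):
        if s <= 0.5:
            return t
    return None  # curve never drops to 0.5, so the median is not estimable

# Hypothetical survival curve for one row of X
times = [1.0, 3.0, 6.0, 12.0, 24.0]
survival_probs = [0.95, 0.80, 0.55, 0.40, 0.10]

median = median_survival_time(times, survival_probs)
print(median)  # 12.0

# A low-risk subject whose curve stays above 0.5 has no estimable median
flat = median_survival_time(times, [0.99, 0.97, 0.90, 0.85, 0.70])
print(flat)  # None
```

Note that this is a median, not a "most likely" time, and for rows whose curve never crosses 0.5 within the observed follow-up there is no answer at all, which is part of why the library returns a risk score from `predict` rather than a time.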

I'm not going to accept this answer here, in case anyone else has a better one.

Max Power answered Oct 23 '22 13:10