How to interpret output of .predict() from fitted scikit-survival model in python?

I'm confused about how to interpret the output of .predict from a fitted CoxnetSurvivalAnalysis model in scikit-survival. I've read through the notebook Intro to Survival Analysis in scikit-survival and the API reference, but can't find an explanation. Below is a minimal example of what leads to my confusion:

import pandas as pd
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.linear_model import CoxnetSurvivalAnalysis

# load data
data_X, data_y = load_veterans_lung_cancer()

# one-hot-encode categorical columns in X
categorical_cols = ['Celltype', 'Prior_therapy', 'Treatment']

X = data_X.copy()
for c in categorical_cols:
    dummy_matrix = pd.get_dummies(X[c], prefix=c, drop_first=False)
    X = pd.concat([X, dummy_matrix], axis=1).drop(c, axis=1)

# display final X to fit Cox Elastic Net model on
del data_X
print(X.head(3))

So here's the X going into the model:

   Age_in_years  Celltype  Karnofsky_score  Months_from_Diagnosis  \
0          69.0  squamous             60.0                    7.0   
1          64.0  squamous             70.0                    5.0   
2          38.0  squamous             60.0                    3.0   

  Prior_therapy Treatment  
0            no  standard  
1           yes  standard  
2            no  standard  

...moving on to fitting the model and generating predictions:

# Fit an elastic-net penalized Cox model
coxnet = CoxnetSurvivalAnalysis()
coxnet.fit(X, data_y)

# What are these predictions?
preds = coxnet.predict(X)

preds has the same number of records as X, but its values are very different from the values in data_y, even though the predictions were made on the same data the model was fit on.

print(preds.mean()) 
print(data_y['Survival_in_days'].mean())

output:

-0.044114643249153422
121.62773722627738

So what exactly are preds? Clearly .predict means something pretty different here than in scikit-learn, but I can't figure out what. The API Reference says it returns "The predicted decision function," but what does that mean? And how do I get a predicted survival time yhat (in days, like Survival_in_days) for a given X? I'm new to survival analysis so I'm obviously missing something.

Max Power asked Nov 13 '17 22:11



1 Answer

I also posted this question on GitHub (the author retitled the issue).

I got some helpful explanation of what the predict output is, but I'm still not sure how to get a set of predicted survival times, which is what I really want. Here are a couple of helpful explanations from that GitHub thread:

predictions are risk scores on an arbitrary scale, which means you can 
usually only determine the sequence of events, but not their exact time.

-sebp (library author)
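
To make the "arbitrary scale" point concrete, here is a minimal sketch in plain NumPy (not scikit-survival's own code) of Harrell's concordance index, the usual way to score risk predictions. The toy arrays are made up and the helper `concordance_index` is hypothetical; scikit-survival ships its own version as `sksurv.metrics.concordance_index_censored`. Only the ordering of the risk scores matters, so rescaling them leaves the result unchanged.

```python
import numpy as np

def concordance_index(event, time, risk):
    """Simplified Harrell's C: share of comparable pairs that risk orders correctly.

    A pair (i, j) is comparable when subject i had an observed event
    and did so before subject j's recorded time. Ties in time are ignored.
    """
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # censored subjects cannot be the "earlier" member of a pair
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # higher risk failed earlier: correct order
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied risk scores count as half
    return concordant / comparable

# Made-up toy data: event indicator, observed time in days, model risk score
event = np.array([True, True, False, True, False])
time = np.array([5.0, 30.0, 42.0, 80.0, 120.0])
risk = np.array([0.9, 0.1, 0.4, -0.2, -0.7])

c_raw = concordance_index(event, time, risk)
# Any increasing transformation keeps the ranking, and therefore the score
c_rescaled = concordance_index(event, time, 100.0 * risk + 7.0)
print(c_raw, c_rescaled)
```

This is why comparing preds.mean() with the mean survival time is not meaningful: the scores carry ranking information, not days.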

It [predict] returns a type of risk score. Higher value means higher
risk of your event (class value = True)...You were probably looking
for a predicted time. You can get the predicted survival function with
estimator.predict_survival_function as in the example 00
notebook...EDIT: Actually, I’m trying to extract this but it’s been a
bit of a pain to munge

-pavopax.
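
As a rough illustration of how a risk score relates to a survival function, the sketch below applies the standard Cox proportional hazards relation S(t | x) = exp(-H0(t) * exp(score)) to a made-up baseline cumulative hazard. The arrays and the helper `survival_curve` are hypothetical, not library output. As far as I know, CoxnetSurvivalAnalysis has to be constructed with fit_baseline_model=True before predict_survival_function can be used, so check the API reference for your installed version.

```python
import numpy as np

# Hypothetical baseline cumulative hazard H0(t) on a grid of days (made-up values)
times = np.array([10.0, 50.0, 100.0, 200.0, 400.0])
baseline_cum_hazard = np.array([0.05, 0.30, 0.70, 1.20, 2.00])

def survival_curve(risk_score):
    """Survival probabilities S(t | x) = exp(-H0(t) * exp(risk_score))."""
    return np.exp(-baseline_cum_hazard * np.exp(risk_score))

low_risk = survival_curve(-0.5)   # a subject with a low risk score
high_risk = survival_curve(0.5)   # a subject with a high risk score

for t, lo, hi in zip(times, low_risk, high_risk):
    print(f"day {t:5.0f}: S_low={lo:.3f}  S_high={hi:.3f}")
```

A higher risk score pushes the whole curve down, which matches the "higher value means higher risk" description quoted above.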

There's more explanation in the GitHub thread, though I wasn't really able to follow all of it. I need to play around with predict_survival_function and predict_cumulative_hazard_function and see if I can get a most likely survival time for each row in X, which is what I really want.
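
One common way to turn a survival curve into a single time per row is the median survival time: the first time at which the predicted survival probability drops to 0.5 or below. The sketch below assumes you already have each curve as a pair of arrays (time points and probabilities). The values and the helper `median_survival_time` are made up for illustration. In scikit-survival the step functions returned by predict_survival_function expose these arrays (I believe as .x and .y, but check your version).

```python
import numpy as np

def median_survival_time(times, surv_prob):
    """Return the first time where survival probability <= 0.5, or inf if it never is."""
    times = np.asarray(times, dtype=float)
    surv_prob = np.asarray(surv_prob, dtype=float)
    below = np.flatnonzero(surv_prob <= 0.5)
    return times[below[0]] if below.size else np.inf

# Made-up survival curves evaluated on a shared grid of days
times = np.array([10.0, 50.0, 100.0, 200.0, 400.0])
curve_a = np.array([0.95, 0.74, 0.50, 0.30, 0.14])  # crosses 0.5 at day 100
curve_b = np.array([0.99, 0.90, 0.80, 0.70, 0.60])  # never reaches 0.5

median_a = median_survival_time(times, curve_a)
median_b = median_survival_time(times, curve_b)
print(median_a, median_b)
```

Returning inf when the curve never reaches 0.5 is a design choice for this sketch: the median is simply not estimable beyond the last observed time, and a finite placeholder would be misleading.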

I'm not going to accept this answer here, in case anyone else has a better one.

Max Power answered Oct 23 '22 13:10