from collections import defaultdict
import requests
import h5py
import pandas as pd
from ._base import _BaseDataset, DATA_DIR
# Base URL of the raw experiment data files in the DeepSurv GitHub repository.
DEEPSURV_URL = "https://raw.githubusercontent.com/jaredleekatzman/DeepSurv/master/experiments/data/"

# Dataset name -> HDF5 file path, relative to ``DEEPSURV_URL``.
DEEPSURV_DATASETS = dict(
    support="support/support_train_test.h5",
    metabric="metabric/metabric_IHC4_clinical_train_test.h5",
    gbsg="gbsg/gbsg_cancer_train_test.h5",
    whas="whas/whas_train_test.h5",
)
[docs]
def download_from_deepsurv(name, covariates=None):
    """Download one of the DeepSurv HDF5 datasets and return its train/test splits.

    The file ``DEEPSURV_URL + DEEPSURV_DATASETS[name]`` is saved temporarily
    as ``DATA_DIR / f"{name}.h5"``, read fully into memory, and then deleted.

    Parameters
    ----------
    name : str
        Key of ``DEEPSURV_DATASETS`` (e.g. ``'support'`` or ``'gbsg'``).
    covariates : list of str, optional
        Column names for the covariate matrix, passed on to ``_make_df``.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_train, df_test)`` built from the file's ``train`` and ``test`` groups.

    Raises
    ------
    KeyError
        If ``name`` is not a key of ``DEEPSURV_DATASETS``.
    requests.HTTPError
        If the download responds with an HTTP error status.
    ValueError
        If the file contains a ``valid`` group, which this loader does not handle.
    """
    url = DEEPSURV_URL + DEEPSURV_DATASETS[name]
    path = DATA_DIR / f"{name}.h5"
    with requests.Session() as s:
        r = s.get(url)
    # Without this check, an HTTP error page would be written to disk and only
    # fail later inside h5py with an opaque "file signature not found" error.
    r.raise_for_status()
    with open(path, 'wb') as f:
        f.write(r.content)
    data = defaultdict(dict)
    try:
        # Explicit read-only mode: h5py < 3.0 defaulted to append mode.
        with h5py.File(path, 'r') as f:
            for ds in f:
                for array in f[ds]:
                    data[ds][array] = f[ds][array][:]
    finally:
        # Always remove the temporary file, even if it could not be parsed.
        path.unlink()
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if 'valid' in data:
        raise ValueError(f"Dataset {name} has a validation set.")
    df_train = _make_df(data['train'], covariates)
    df_test = _make_df(data['test'], covariates)
    return df_train, df_test
def _make_df(data, covariates=None):
    """Build a survival DataFrame from one DeepSurv split.

    Parameters
    ----------
    data : mapping
        Must provide ``'x'`` (2-D covariate array; its second dimension is the
        number of covariates), ``'t'`` (durations) and ``'e'`` (event indicators).
    covariates : list of str, optional
        Column names for ``x``, in column order. Defaults to ``x0, x1, ...``.

    Returns
    -------
    pandas.DataFrame
        The covariate columns followed by ``duration`` and ``event``.
    """
    features, durations, events = data['x'], data['t'], data['e']
    names = covariates
    if names is None:
        names = [f"x{i}" for i in range(features.shape[1])]
    frame = pd.DataFrame(features, columns=names)
    # Keyword order fixes the column order: duration first, then event.
    return frame.assign(duration=durations, event=events)
class _DeepSurvDataset(_BaseDataset):
    """Base class for the survival datasets published with DeepSurv.

    Subclasses provide ``name`` (a key of ``DEEPSURV_DATASETS``),
    ``covariates`` (column names for the covariate matrix, in order) and
    ``categorical_covariates`` (columns cast to pandas ``category`` dtype).
    Downloaded splits are cached as feather files at ``self.path_train`` and
    ``self.path_test``.
    """
    duration_col = 'duration'
    event_col = 'event'

    def _download(self):
        """Download the DeepSurv train/test splits and cache them as feather files."""
        df_train, df_test = download_from_deepsurv(self.name, self.covariates)
        df_train.to_feather(self.path_train)
        df_test.to_feather(self.path_test)

    def _cast_categorical(self, df):
        """Cast ``self.categorical_covariates`` to ``category`` dtype in place and return ``df``."""
        for cat in self.categorical_covariates:
            df[cat] = df[cat].astype('category')
        return df

    def load(self, split=False):
        """Load the dataset, downloading it first when the local cache is incomplete.

        Parameters
        ----------
        split : bool, default False
            If True, return the train/test split used in the DeepSurv paper as a
            ``(df_train, df_test)`` tuple. Otherwise return both splits
            concatenated into one DataFrame with a fresh integer index.

        Returns
        -------
        pandas.DataFrame or tuple of pandas.DataFrame
        """
        # Download when *either* cached file is missing (e.g. an earlier download
        # stopped after writing only the train split). Requiring both to be
        # missing would leave a half-written cache permanently unreadable.
        if not (self.path_train.exists() and self.path_test.exists()):
            print(f"Downloading dataset '{self.name}' from DeepSurv ...")
            self._download()
            print("Done")
        df_train = self._label_cols_at_end(pd.read_feather(self.path_train))
        df_test = self._label_cols_at_end(pd.read_feather(self.path_test))
        if split:
            df_train = self._cast_categorical(df_train)
            df_test = self._cast_categorical(df_test)
            print('Using default train-test split (used in DeepSurv paper).')
            return df_train, df_test
        # Concatenate *before* casting: pd.concat of categorical columns whose
        # category sets differ falls back to a non-categorical dtype.
        df = pd.concat([df_train, df_test], ignore_index=True)
        return self._cast_categorical(df)
[docs]
class SUPPORT(_DeepSurvDataset):
    """Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT).

    A study of survival for seriously ill hospitalized adults.

    This is the processed data set used in the DeepSurv paper (Katzman et al. 2018);
    details can be found at https://doi.org/10.1186/s12874-018-0482-1.
    The original data are available at
    https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data.
    Covariate names were restored from https://hbiostat.org/data/repo/support2csv.zip.

    Variables:
        age:
            age in years.
        sex:
            patient sex (1: female, 0: male).
        race:
            patient race (the original data do not specify which value encodes which race).
        comorbidity:
            number of comorbidities.
        diabetes:
            presence of diabetes.
        dementia:
            presence of dementia.
        cancer:
            presence of cancer (2: yes, 1: no, 0: metastatic).
        meanbp:
            mean arterial blood pressure.
        hr:
            heart rate.
        rr:
            respiration rate.
        temp:
            temperature.
        sodium:
            serum sodium.
        wbc:
            white blood cell count.
        creatinine:
            serum creatinine.
        duration: (duration)
            the right-censored event times.
        event: (event)
            death indicator {1: death, 0: censoring}.
    """
    name = 'support'
    # Column names for the DeepSurv covariate matrix ``x``, in column order.
    covariates = [
        'age', 'sex', 'race', 'comorbidity', 'diabetes', 'dementia', 'cancer',
        'meanbp', 'hr', 'rr', 'temp', 'sodium', 'wbc', 'creatinine',
    ]
    # Subset of ``covariates`` that the loader casts to pandas ``category`` dtype.
    categorical_covariates = [
        'sex', 'race', 'diabetes', 'dementia', 'cancer',
    ]
[docs]
class GBSG(_DeepSurvDataset):
    """Rotterdam & German Breast Cancer Study Group (GBSG).

    A combination of the Rotterdam tumor bank and the German Breast Cancer
    Study Group data.

    This is the processed data set used in the DeepSurv paper (Katzman et al. 2018);
    details can be found at https://doi.org/10.1186/s12874-018-0482-1.
    The original data are available at
    https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data.
    Covariate names were restored from
    https://www.kaggle.com/datasets/utkarshx27/breast-cancer-dataset-used-royston-and-altman

    Variables:
        hormon:
            hormonal therapy (0: no, 1: yes).
        size:
            tumor size (0: < 20 mm, 1: [20 mm to 50 mm], 2: > 50 mm).
        meno:
            menopausal status (0: premenopausal, 1: postmenopausal).
        age:
            age in years.
        nodes:
            number of positive lymph nodes.
        pgr:
            progesterone receptors (fmol/l).
        er:
            estrogen receptors (fmol/l).
        duration: (duration)
            the right-censored event times.
        event: (event)
            event indicator {1: event, 0: censoring}.
    """
    name = 'gbsg'
    # Column names for the DeepSurv covariate matrix ``x``, in column order.
    covariates = [
        'hormon', 'size', 'meno', 'age', 'nodes', 'pgr', 'er',
    ]
    # Subset of ``covariates`` that the loader casts to pandas ``category`` dtype.
    categorical_covariates = [
        'hormon', 'size', 'meno',
    ]