Source code for coxkan.utils

"""
CoxKAN Utility Functions
"""

import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch
from torch import Tensor, nn
import sympy
from lifelines.utils import concordance_index
import warnings
import scipy.stats as st
from kan.utils import fit_params, SYMBOLIC_LIB

# remove 'arcsin' from the symbolic library
del SYMBOLIC_LIB['arcsin']

[docs] def bootstrap_metric(metric_fn, df, N=100): """ Bootstrap the confidence interval of a metric. Args: ----- metric_fn : callable Metric function that takes a DataFrame as input. df : pd.DataFrame DataFrame to bootstrap. N : int Number of bootstrap samples. The default is 100. Returns: -------- results : dict results['full'], metric of the full dataset. results['mean'], mean of the bootstrap samples. results['confidence_interval'], 95% confidence interval of the metric. results['formatted'], formatted string of the metric and confidence interval. """ metrics = [] size = len(df) for _ in range(N): resample_idx = np.random.choice(size, size=size, replace=True) df_ = df.iloc[resample_idx] df_ = df_.reset_index(drop=True) metric = metric_fn(df_) metrics.append(metric) mean = np.mean(metrics) conf_interval = st.t.interval(0.95, len(metrics)-1, loc=mean, scale=st.sem(metrics)) return { 'full': metric_fn(df), 'mean': mean, 'confidence_interval': conf_interval, 'formatted': f"{metric_fn(df):.6f} ({conf_interval[0]:.3f}, {conf_interval[1]:.3f})" }
[docs] class Logger: """ Logger class to store training and testing metrics. """ def __init__(self, early_stopping=False, stop_on='cindex'): """ Args: ----- early_stopping : bool Whether to use early stopping. stop_on : str Metric to use for early stopping. Either 'cindex' or 'loss'. """ self.data = {} self.early_stopping = early_stopping self.stop_on = stop_on def __getitem__(self, key): return self.data[key] def __setitem__(self, key, value): self.data[key] = value
[docs] def plot(self): if not self.data: print("No data to plot.") return fig, ax = plt.subplots(1, 2, figsize=(10, 5)) if 'train_loss' in self.data: ax[0].plot(self.data['train_loss'], label='train_loss') ax[0].set_title('Loss') if 'train_cindex' in self.data: ax[1].plot(self.data['train_cindex'], label='train_cindex') ax[1].set_title('C-Index') if 'val_loss' in self.data: ax[0].plot(self.data['val_loss'], label='val_loss') if 'val_cindex' in self.data: ax[1].plot(self.data['val_cindex'], label='val_cindex') # put vertical line at highest val_cindex if 'val_cindex' in self.data and self.early_stopping: if self.stop_on == 'cindex': best_epoch = np.argmax(self.data['val_cindex']) elif self.stop_on == 'loss': best_epoch = np.argmin(self.data['val_loss']) ax[0].axvline(best_epoch, color='k', linestyle='--', label='best_model') ax[1].axvline(best_epoch, color='k', linestyle='--', label='best_model') ax[0].grid(True); ax[1].grid(True) ax[0].legend(); ax[1].legend() return fig
[docs] def FastCoxLoss(log_h: Tensor, labels: Tensor, eps=1e-7) -> Tensor: """ Simple and fast implementation of Cox Proportional Hazards Loss. Credit: https://github.com/havakv/pycox We just compute a cumulative sum. In the case of ties, this may not be the exact true Risk set. This is a limitation, but fast. Args: ----- log_h : Tensor Log-partial hazard. labels : Tensor Labels tensor: first column is time, second column is event indicator. eps : float Small value to prevent log(0). """ # Sort by time durations, events = labels[:, 0], labels[:, 1] idx = durations.sort(descending=True)[1] events = events[idx] log_h = log_h[idx] # Compute the risk set if events.dtype is torch.bool: events = events.float() events = events.view(-1) log_h = log_h.view(-1) gamma = log_h.max() log_cumsum_h = log_h.sub(gamma).exp().cumsum(0).add(eps).log().add(gamma) # Compute loss return - log_h.sub(log_cumsum_h).mul(events).sum().div(events.sum())
[docs] def set_seed(seed): """ Set seed for reproducibility. """ torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) torch.backends.cudnn.deterministic=True torch.backends.cudnn.benchmark = False return seed
[docs] def add_symbolic(name, fn, sympy_fn): """ Add a symbolic function to the symbolic library. Args: ----- name : str Name of the symbolic function. fn : callable Function (lambda or torch) sympy_fn : callable Sympy function Returns: -------- None """ globals()[name] = sympy_fn SYMBOLIC_LIB[name] = (fn, globals()[name])
[docs] def count_parameters(model): """ Count the number of trainable parameters in a model. """ return sum(p.numel() for p in model.parameters() if p.requires_grad)
[docs] def categorical_fun(inputs, outputs, category_map): """ Create a categorical (discrete) function. Primary purpose is for creating symbolic activation functions for categorical covariates. The function accepts an array of inputs and outputs of a given (non-symbolic) activation function, as well as a dictionary mapping the encoded values of a categorical variable to the original category name. It returns a function that maps the inputs to the outputs (exactly as the activation function did), as well as a discrete Sympy function that represents the categorical mapping. Args: ----- inputs : torch.Tensor Inputs to the function. outputs : torch.Tensor Outputs of the function. category_map : dict Dictionary mapping encoded values of a category to the original category name. Returns: -------- func : callable Function that maps inputs to outputs based on the categorical mapping. sympy_func : callable Discrete Sympy function representing the categorical mapping. """ # Check that all inputs are in the category map unique_inputs = [round(i.item(), 3) for i in inputs.unique()] for inpt in unique_inputs: assert inpt in category_map.keys() # Create a mapping from input to output mapping = {} for idx, x in enumerate(inputs): x = round(x.item(), 3) if x not in mapping: mapping[x] = outputs[idx].item() else: assert round(mapping[x],6) == round(outputs[idx].item(), 6) # Create the function that maps inputs to outputs def func(x): shape = x.shape x = x.flatten() try: return torch.tensor([mapping[round(x_i.item(), 3)] for x_i in x]).reshape(shape) except: print(f'\n\n\n') print(mapping) raise ValueError(f"Input not found in categorical mapping.") # Create the discrete Sympy function def sympy_func(x): conditions = [] for i in unique_inputs: out = mapping[i] value = category_map[i] if isinstance(category_map[i], float) else sympy.symbols(str(category_map[i])) conditions.append((out, sympy.Eq(x, value))) return sympy.Piecewise(*conditions, (sympy.nan, True), evaluate=False) return func, sympy_func