API

Subpackages

Submodules

coxkan.CoxKAN module

Main module for CoxKAN class.

class coxkan.CoxKAN.CoxKAN(**kwargs)

Bases: KAN

CoxKAN class

Attributes:

act_fun: a list of KANLayer

KANLayers

depth: int

depth of KAN

width: list

number of neurons in each layer, e.g., [2,5,5,3] means 2D inputs, 3D outputs, and 2 hidden layers of 5 neurons each.

grid: int

the number of grid intervals

k: int

the order of the piecewise polynomial

base_fun: fun

residual function b(x). Each activation function is phi(x) = sb_scale * b(x) + sp_scale * spline(x).

symbolic_fun: a list of Symbolic_KANLayer

Symbolic_KANLayers
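To make the activation formula under base_fun concrete, the following minimal sketch evaluates phi(x) = sb_scale * b(x) + sp_scale * spline(x) for a single edge. It assumes b(x) is SiLU and replaces the learnable order-k B-spline with a degree-1 (piecewise-linear) spline on a uniform grid, so it illustrates the structure of one activation rather than CoxKAN's KANLayer implementation; all function names and coefficient values are hypothetical.

```python
import bisect
import math


def silu(x: float) -> float:
    """Assumed residual function b(x): SiLU, x * sigmoid(x), written overflow-free."""
    return x * 0.5 * (1.0 + math.tanh(0.5 * x))


def linear_spline(x: float, grid: list, coefs: list) -> float:
    """Assumed stand-in for the B-spline: degree-1 interpolation, constant outside the grid."""
    if x <= grid[0]:
        return coefs[0]
    if x >= grid[-1]:
        return coefs[-1]
    m = bisect.bisect_right(grid, x) - 1  # index of the knot to the left of x
    t = (x - grid[m]) / (grid[m + 1] - grid[m])
    return (1.0 - t) * coefs[m] + t * coefs[m + 1]


def phi(x, grid, coefs, sb_scale=1.0, sp_scale=1.0):
    """One edge: phi(x) = sb_scale * b(x) + sp_scale * spline(x)."""
    return sb_scale * silu(x) + sp_scale * linear_spline(x, grid, coefs)


# grid = 5 intervals on [-1, 1] gives 6 knots; the coefficients are made-up values
grid = [-1.0 + 0.4 * m for m in range(6)]
coefs = [0.0, 0.5, 1.0, 0.5, 0.0, -0.5]
print(phi(0.0, grid, coefs))  # about 0.75: b(0) = 0 and the spline is halfway between 1.0 and 0.5
```

In a trained model, sb_scale, sp_scale and the spline coefficients are the learnable parts; grid and k only fix the spline's shape family.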

Methods:

__init__():

initialize a CoxKAN model

process_data():

preprocess dataset and register metadata

train():

train the model

cindex():

compute concordance index

predict():

predict the log-partial hazard

predict_partial_hazard():

predict the partial hazard (exp of log-partial hazard)

prune_edges():

prune edges (activation functions) of the model

prune_nodes():

prune nodes (neurons) of the model

fix_symbolic():

set (l,i,j) activation to be symbolic (specified by fun_name)

plot():

plot the model

plot_act():

plot a specific activation function

suggest_symbolic():

find the best symbolic function for a specific activation (highest r2)

auto_symbolic():

automatic symbolic fitting

symbolic_formula():

obtain the symbolic formula of the full model

symbolic_rank_terms():

calculate the standard deviation of each term in the symbolic formula

auto_symbolic(min_r2=0, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1)

Automatic symbolic regression: using best suggestion from suggest_symbolic to replace activations with symbolic functions. This method is just slightly adapted from the original KAN.auto_symbolic().

Args:

min_r2 : float

minimum r2 to accept the symbolic formula

lib : None or a list of function names

the symbolic library

a_range : tuple

sweeping range of a

b_range : tuple

sweeping range of b

verbose : int

verbosity

Returns:

bool: True if all activations are successfully replaced by symbolic functions, False otherwise

cindex(df, duration_col=None, event_col=None)

Compute model’s concordance index on a dataset.

Args:

df : pd.DataFrame

dataset

duration_col : str

column name for duration

event_col : str

column name for event

Returns:

cindex : float

concordance index
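For intuition about what cindex() measures, here is a minimal pure-Python sketch of Harrell's concordance index for a risk score such as the log-partial hazard. It is an illustration, not CoxKAN's implementation (which may compute the index differently, for example through a survival-analysis library), and the name harrell_cindex is hypothetical. A pair (i, j) is comparable when subject i has an observed event and a strictly shorter duration than j; it is concordant when i also has the higher predicted risk, and tied risks count as one half.

```python
def harrell_cindex(durations, events, risk):
    """Illustrative Harrell's C-index: a higher risk should pair with a shorter survival time."""
    concordant = 0.0
    comparable = 0
    n = len(durations)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot be the earlier member of a comparable pair
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


durations = [2, 4, 6, 8]
events = [1, 1, 0, 1]
log_ph = [1.5, 0.3, 0.8, -1.0]  # made-up log-partial hazards
print(harrell_cindex(durations, events, log_ph))  # 0.8 (4 of 5 comparable pairs concordant)
```

Because only the ordering of the risk scores matters, the log-partial hazard and its exponential (the partial hazard) give the same index.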

fix_symbolic(l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False)

Set (l,i,j) activation to be symbolic (specified by fun_name).

Args:

l : int

layer index

i : int

input neuron index

j : int

output neuron index

fun_name : str

function name

fit_params_bool : bool

obtain the affine parameters by fitting (True) or use default values (False)

a_range : tuple

sweeping range of a

b_range : tuple

sweeping range of b

verbose : bool

If True, more information is printed.

random : bool

initialize the affine parameters randomly (True) or as [1,0,1,0] (False)

Returns:

None or r2 (coefficient of determination)

load_ckpt(ckpt_path, verbose=True)

Load the model from a checkpoint.

plot(show_vars=False, **kwargs)

Plot the model.

Args:

show_vars : bool

If True, show the registered covariates on the plot. Default: False

**kwargs : Keyword arguments to be passed to KAN.plot()

Keyword Args:

folder : str

the folder to store PNGs

beta : float

positive number controlling the transparency of each activation: transparency = tanh(beta*l1).

mask : bool

If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.

mode : str

“supervised” or “unsupervised”. If “supervised”, l1 is measured by absolute value (not subtracting the mean); if “unsupervised”, l1 is measured by standard deviation (subtracting the mean).

scale : float

control the size of the diagram

in_vars : None or list of str

the name(s) of input variables

out_vars : None or list of str

the name(s) of output variables

title : None or str

title of the plot

Returns:

fig : Figure

the figure
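The beta and mode keyword arguments interact as sketched below. This assumes l1 is computed over a batch of one activation's outputs, as the mean absolute value in "supervised" mode and as the (population) standard deviation in "unsupervised" mode; the helper names are hypothetical and not part of KAN.plot().

```python
import math
import statistics


def activation_l1(outputs, mode="supervised"):
    """Assumed scale of one activation's outputs, following the plot() 'mode' description."""
    if mode == "supervised":
        return statistics.fmean(abs(y) for y in outputs)  # absolute value, mean not subtracted
    if mode == "unsupervised":
        return statistics.pstdev(outputs)  # standard deviation, mean subtracted
    raise ValueError(f"unknown mode: {mode!r}")


def transparency(l1, beta):
    """Opacity used when drawing an activation: tanh(beta * l1)."""
    return math.tanh(beta * l1)


outputs = [0.2, 0.2, 0.2, 0.2]  # made-up outputs of a constant activation
print(transparency(activation_l1(outputs, "supervised"), beta=3))    # about 0.54
print(transparency(activation_l1(outputs, "unsupervised"), beta=3))  # 0.0
```

The example shows why the mode matters: a constant activation is drawn clearly in "supervised" mode but is invisible in "unsupervised" mode, where only variation counts.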

plot_act(l, i, j)

Plot activation function phi_(l,i,j)

Args:

l : int

layer index

i : int

input neuron index

j : int

output neuron index

plot_best_suggestion(i, j, lib=None, a_range=(-10, 10), b_range=(-10, 10), verbose=1)

Plot the best symbolic suggestion for activation function phi_(l,i,j)

Args:

l : int

layer index

i : int

input neuron index

j : int

output neuron index

lib : None or a list of function names

the symbolic library

a_range : tuple

sweeping range of a

b_range : tuple

sweeping range of b

verbose : int

verbosity

Returns:

fig : Figure

the figure

predict(df)

Predict log-partial hazard for all samples in a dataset.

Args:

df : pd.DataFrame

dataset

Returns:

log_ph : pd.Series

log-partial hazard

predict_partial_hazard(df)

Predict partial hazard for all samples in a dataset (exp of log-partial hazard).

Args:

df : pd.DataFrame

dataset

Returns:

partial_hazard : pd.Series

partial hazard

process_data(df_train, df_test, duration_col, event_col, covariates=None, categorical_covariates=True, normalization='minmax')
Preprocess the datasets and register metadata via the following steps:
  • Encode categorical covariates via label encoding (if categorical_covariates is not None)

  • Normalize covariates

  • Register metadata: duration_col, event_col, covariates, normalizer, categorical_covariates and category_maps (maps from the encoded values of each category to the original names)

Args:

df_train : pd.DataFrame

training dataset

df_test : pd.DataFrame

testing dataset

duration_col : str

column name for duration

event_col : str

column name for event

covariates : list

list of covariates. If None, all columns except duration_col and event_col are used.

categorical_covariates : bool or list

If True, categorical covariates are automatically detected and label encoded. If a list is provided, only the covariates in the list are label encoded.

normalization : str

normalization method: ‘minmax’ for \((x - \min(x))/(\max(x) - \min(x))\), ‘standard’ for \((x - \operatorname{mean}(x))/\operatorname{std}(x)\), or ‘none’

Returns:

df_train : pd.DataFrame

training dataset with processed covariates

df_test : pd.DataFrame

testing dataset with processed covariates
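The preprocessing steps above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not CoxKAN's code: rows are dicts rather than DataFrames, label codes follow the sorted order of the training categories, and the normalizer statistics are fit on the training set only and reused for the test set. The function names label_encode and normalize are hypothetical.

```python
import statistics


def label_encode(train_rows, test_rows, col):
    """Replace category names in `col` by integer codes; return the code -> name map."""
    names = sorted({row[col] for row in train_rows})  # assumption: codes follow sorted order
    to_code = {name: code for code, name in enumerate(names)}
    for row in train_rows + test_rows:
        row[col] = to_code[row[col]]  # a category unseen in training raises KeyError
    return {code: name for name, code in to_code.items()}


def normalize(train_rows, test_rows, col, method="minmax"):
    """Scale `col` with statistics from the training rows only; return (shift, scale)."""
    values = [row[col] for row in train_rows]
    if method == "minmax":
        shift, scale = min(values), max(values) - min(values)
    elif method == "standard":
        shift, scale = statistics.fmean(values), statistics.pstdev(values)
    elif method == "none":
        return 0.0, 1.0
    else:
        raise ValueError(f"unknown normalization: {method!r}")
    scale = scale or 1.0  # added guard against constant columns
    for row in train_rows + test_rows:
        row[col] = (row[col] - shift) / scale
    return shift, scale


train = [{"age": 40, "stage": "II"}, {"age": 60, "stage": "I"}, {"age": 80, "stage": "III"}]
test = [{"age": 50, "stage": "I"}]
category_map = label_encode(train, test, "stage")
normalizer = normalize(train, test, "age", "minmax")
print(category_map, normalizer, test)  # {0: 'I', 1: 'II', 2: 'III'} (40, 40) [{'age': 0.25, 'stage': 0}]
```

Fitting the statistics on the training rows alone keeps test-set information out of preprocessing; the returned category map plays the role of the registered category_maps metadata.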

prune_edges(threshold=0.02, verbose=True)

Prune edges (activation functions) of the model based on a threshold of the L1 norm of that activation.

Args:

threshold : float

any activation with L1 norm less than this threshold will be pruned

verbose : bool

If True, print pruned activations

Returns:

None
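A minimal sketch of the edge-pruning rule, assuming for illustration that an activation's L1 norm is the mean absolute value of its outputs over the samples. The nested-list layout scores[l][i][j] (layer l, input neuron i, output neuron j) and both helper names are hypothetical stand-ins for CoxKAN's internal tensors.

```python
def edge_l1(outputs):
    """Illustrative L1 score of one activation: mean |phi(x)| over the samples (assumption)."""
    return sum(abs(y) for y in outputs) / len(outputs)


def prune_edge_mask(scores, threshold=0.02, verbose=True):
    """mask[l][i][j] is 0 when activation (l, i, j) scores below `threshold`, else 1."""
    mask = []
    for l, layer in enumerate(scores):
        layer_mask = []
        for i, row in enumerate(layer):
            row_mask = []
            for j, score in enumerate(row):
                keep = score >= threshold
                if verbose and not keep:
                    print(f"pruned activation ({l},{i},{j}): L1 norm {score:.4f}")
                row_mask.append(1 if keep else 0)
            layer_mask.append(row_mask)
        mask.append(layer_mask)
    return mask


# one layer with 2 inputs and 2 outputs; the scores are made-up values
scores = [[[edge_l1([0.4, -0.6]), 0.01], [0.003, 0.2]]]
print(prune_edge_mask(scores))  # [[[1, 0], [0, 1]]] after printing the two pruned edges
```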

prune_nodes(threshold=0.01, mode='auto', active_neurons_id=None)

Prune nodes (neurons) of the model based on a threshold of the L1 norm of the incoming and outgoing activations of that neuron. This method is just slightly adapted from the original KAN.prune().

Args:

threshold : float

any neuron which has all incoming and outgoing activations with L1 norm less than this threshold will be pruned

mode : str

“auto” or “manual”. If “auto”, the threshold is used to automatically prune away nodes. If “manual”, active_neurons_id is needed to specify which neurons are kept (all others are removed).

active_neurons_id : list of id lists

For example, [[0,1],[0,2,3]] keeps neurons 0 and 1 in the 1st hidden layer and neurons 0, 2 and 3 in the 2nd hidden layer. Pruning input and output neurons is not supported yet.

Returns:

model2 : CoxKAN

pruned model
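In "manual" mode, active_neurons_id determines the width of the pruned model. The sketch below shows that bookkeeping for the example above, assuming a hypothetical starting width of [4, 3, 5, 1]; since input and output neurons cannot be pruned, only the hidden entries change. The helper pruned_width is illustrative and not part of CoxKAN.

```python
def pruned_width(width, active_neurons_id):
    """Hypothetical helper: layer widths after keeping only the listed hidden neurons."""
    hidden = width[1:-1]  # input and output layers are never pruned
    if len(active_neurons_id) != len(hidden):
        raise ValueError("need one id list per hidden layer")
    for ids, n in zip(active_neurons_id, hidden):
        if any(not 0 <= k < n for k in ids):
            raise ValueError(f"neuron id out of range for a layer of width {n}")
    return [width[0]] + [len(set(ids)) for ids in active_neurons_id] + [width[-1]]


print(pruned_width([4, 3, 5, 1], [[0, 1], [0, 2, 3]]))  # [4, 2, 3, 1]
```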

save_ckpt(save_path='ckpt.pt', verbose=True)

Save the current model as a checkpoint.

suggest_symbolic(l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True)

Suggest the symbolic candidates of activation function phi_(l,i,j)

Args:

l : int

layer index

i : int

input neuron index

j : int

output neuron index

lib : dict

library of symbolic bases. If lib = None, the global default library will be used.

topk : int

display the top k symbolic functions (according to r2)

verbose : bool

If True, more information will be printed.

Returns:

fun_name : str

suggested symbolic function name

fun : fun

suggested symbolic function

r2 : float

coefficient of determination of the best suggestion
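The ranking by r2 can be illustrated with a simplified sketch. Assuming each candidate f is fitted in the affine form y ≈ c * f(a * x + b) + d (the four affine parameters referred to in fix_symbolic), the sketch sweeps a and b over a coarse grid within a_range and b_range, solves for c and d by ordinary least squares, and keeps each candidate's best r2. The candidate library, grid resolution, and function names are hypothetical; CoxKAN's own fitting routine may differ.

```python
import math


def fit_cd(u, y):
    """Least-squares c, d for y ≈ c * u + d; returns (c, d, r2)."""
    n = len(y)
    mu, my = sum(u) / n, sum(y) / n
    suu = sum((ui - mu) ** 2 for ui in u)
    suy = sum((ui - mu) * (yi - my) for ui, yi in zip(u, y))
    syy = sum((yi - my) ** 2 for yi in y)
    if suu == 0 or syy == 0:
        return 0.0, my, 0.0  # constant feature or target: no explanatory power
    c = suy / suu
    r2 = suy * suy / (suu * syy)  # r2 of a simple linear fit = squared correlation
    return c, my - c * mu, r2


def linspace(lo, hi, steps):
    return [lo + (hi - lo) * k / (steps - 1) for k in range(steps)]


def suggest(x, y, lib, a_range=(-10, 10), b_range=(-10, 10), steps=41, topk=5):
    """Rank candidates f by the best r2 of c * f(a * x + b) + d over an (a, b) grid."""
    ranked = []
    for name, f in lib.items():
        best_r2, best_params = -1.0, None
        for a in linspace(*a_range, steps):
            for b in linspace(*b_range, steps):
                try:
                    u = [f(a * xi + b) for xi in x]
                except (ValueError, OverflowError, ZeroDivisionError):
                    continue  # f is undefined or overflows for this (a, b)
                c, d, r2 = fit_cd(u, y)
                if r2 > best_r2:
                    best_r2, best_params = r2, (a, b, c, d)
        ranked.append((name, best_r2, best_params))
    ranked.sort(key=lambda item: item[1], reverse=True)
    return ranked[:topk]


x = [k / 10 for k in range(-10, 11)]
y = [3 * math.sin(2 * xi + 0.5) - 1 for xi in x]  # made-up activation outputs to identify
lib = {"x": lambda t: t, "x^2": lambda t: t * t, "sin": math.sin, "exp": math.exp}
for name, r2, params in suggest(x, y, lib, topk=3):
    print(name, round(r2, 4))  # "sin" ranks first with r2 close to 1
```

Solving for c and d in closed form keeps the search two-dimensional; only a and b need the sweep, which is why a_range and b_range are the only ranges exposed.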

symbolic_formula(floating_digit=None, var=None, normalizer=None, simplify=False, output_normalizer=None)

Obtain the symbolic formula.

Args:

floating_digit : int

the number of digits to display

var : list of str

the names of the variables (if not provided, [‘x_1’, ‘x_2’, …] is used by default)

normalizer : [mean array (floats), variance array (floats)]

the normalization applied to inputs

simplify : bool

If True, simplify the equation at each step (usually quite slow); therefore False by default.

output_normalizer : [mean array (floats), variance array (floats)]

the normalization applied to outputs

Returns:

symbolic formula : sympy function

the symbolic formula

x0 : list of sympy symbols

the list of input variables

symbolic_rank_terms(floating_digit=5, z_score_threshold=5, normalizer=None)

Calculate the standard deviation of each term in the symbolic expression of CoxKAN.

Standard deviation can be used as a measure of importance of each term in the symbolic expression. The terms with higher standard deviation are more important. A caveat here is that terms with outliers in their outputs may have higher standard deviation, which may not necessarily mean they are more important. To address this, we remove outliers iteratively based on Z-score until no outliers are left.

Args:

floating_digit : int

the number of digits to display

z_score_threshold : int

the threshold of Z-score for removing outliers

normalizer : [mean array (floats), variance array (floats)]

the normalization applied to inputs

Returns:

terms_std : dict

dictionary of terms and their standard deviations
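The iterative outlier removal described above can be sketched for the outputs of a single term. Assumptions: Z-scores use the population standard deviation, a point is an outlier when |z| is strictly greater than the threshold, and all outliers found in one pass are removed before recomputing. The function name robust_std is hypothetical.

```python
import statistics


def robust_std(values, z_score_threshold=5):
    """Std of one term's outputs after iteratively removing |z| > threshold outliers."""
    values = list(values)
    while len(values) > 1:
        mean = statistics.fmean(values)
        std = statistics.pstdev(values)  # assumption: population standard deviation
        if std == 0:
            break
        kept = [v for v in values if abs(v - mean) / std <= z_score_threshold]
        if len(kept) == len(values):
            break  # no outliers left
        values = kept
    return statistics.pstdev(values)


# made-up outputs of one term: small variation plus a single extreme value
term_outputs = [0.0, 0.1] * 50 + [1000.0]
print(statistics.pstdev(term_outputs))  # about 99, dominated by the outlier
print(robust_std(term_outputs))         # about 0.05 once the outlier is removed
```

Applying this to every term and sorting by the result gives an importance ranking that a single extreme prediction cannot dominate.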

train(df_train, df_val=None, duration_col='duration', event_col='event', covariates=None, opt='Adam', lr=0.01, steps=100, batch=-1, early_stopping=False, stop_on='cindex', log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=0.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, stop_grid_update_step=50, small_mag_threshold=1e-16, small_reg_factor=1.0, metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu', progress_bar=True)

Train the model.

Args:

df_train : pd.DataFrame

training dataset

df_val : pd.DataFrame

validation dataset

duration_col : str

column name for duration

event_col : str

column name for event

covariates : list

list of covariates. If None, all columns except duration_col and event_col are used.

opt : str

optimizer: ‘Adam’ or ‘LBFGS’

lr : float

learning rate

steps : int

number of training steps

batch : int

batch size. If -1, use all samples.

log : int

logging frequency

lamb : float

overall regularization strength

lamb_l1 : float

L1 regularization strength

lamb_entropy : float

entropy regularization strength

lamb_coef : float

spline coefficient regularization strength

lamb_coefdiff : float

spline coefficient difference regularization strength

update_grid : bool

If True, update the grid regularly before stop_grid_update_step

grid_update_num : int

the number of grid updates before stop_grid_update_step

stop_grid_update_step : int

no grid updates after this training step

small_mag_threshold : float

threshold to distinguish large from small numbers (a larger penalty may be applied to smaller numbers)

small_reg_factor : float

penalty strength applied to small factors relative to large factors

metrics : list

additional metrics to log

sglr_avoid : bool

avoid NaN in SGLR

save_fig : bool

save figures

beta : float

beta for plotting

save_fig_freq : int

save figure frequency

img_folder : str

folder to save figures

device : str

device to use (usually no need to change it, as GPU is typically slower)

Returns:

log : dict

  • log[‘train_loss’]: 1D array of training losses (Cox loss)

  • log[‘val_loss’]: 1D array of validation losses (Cox loss)

  • log[‘train_cindex’]: 1D array of training concordance index

  • log[‘val_cindex’]: 1D array of validation concordance index

  • log[‘reg’]: 1D array of regularization (the regularization term in the total loss)
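The update_grid, grid_update_num and stop_grid_update_step arguments combine into a grid-refit schedule. The sketch below assumes CoxKAN keeps the rule of the original KAN training loop (refit every stop_grid_update_step // grid_update_num steps, and only before stop_grid_update_step); the helper grid_update_steps and its max(1, ...) guard are hypothetical additions for illustration.

```python
def grid_update_steps(steps=100, update_grid=True, grid_update_num=10, stop_grid_update_step=50):
    """0-based training steps at which the spline grids would be refit (assumed rule)."""
    if not update_grid:
        return []
    # Assumed rule: refit every `freq` steps while the step is below stop_grid_update_step.
    freq = max(1, stop_grid_update_step // grid_update_num)  # max(1, ...) is an added guard
    return [s for s in range(steps) if s % freq == 0 and s < stop_grid_update_step]


print(grid_update_steps())  # [0, 5, 10, 15, 20, 25, 30, 35, 40, 45]
```

With the defaults, the grid is refit ten times during the first 50 steps and then held fixed for the rest of training.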

coxkan.utils module

CoxKAN Utility Functions

coxkan.utils.FastCoxLoss(log_h, labels, eps=1e-07)

Simple and fast implementation of Cox Proportional Hazards Loss. Credit: https://github.com/havakv/pycox

The risk-set sums are computed with a single cumulative sum. With tied durations, this may not give the exact risk set; this is a limitation accepted in exchange for speed.

Return type:

Tensor

Args:

log_h : Tensor

Log-partial hazard.

labels : Tensor

Labels tensor: first column is time, second column is event indicator.

eps : float

Small value to prevent log(0).
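A pure-Python sketch of the cumulative-sum idea, modeled on the sorted Cox loss in pycox: sort subjects by descending duration so that the running sum at subject i covers every subject with an equal or longer duration, accumulate exp(log_h) shifted by the maximum for numerical stability, and average the negative log partial likelihood over events. It works on lists instead of tensors, and the function name cox_loss is hypothetical.

```python
import math


def cox_loss(log_h, durations, events, eps=1e-7):
    """Negative mean Cox log partial likelihood computed with one running sum."""
    n_events = sum(1 for e in events if e)
    if n_events == 0:
        raise ValueError("at least one event is required")
    # Longest durations first: the running sum at subject i then covers every
    # subject with duration >= T_i (exactly the risk set when there are no ties).
    order = sorted(range(len(durations)), key=lambda i: durations[i], reverse=True)
    gamma = max(log_h)  # shift before exp() for numerical stability
    running, total = 0.0, 0.0
    for i in order:
        running += math.exp(log_h[i] - gamma)
        if events[i]:
            total += log_h[i] - (math.log(running + eps) + gamma)
    return -total / n_events


log_h = [0.5, -0.2, 1.0, 0.0]  # made-up log-partial hazards
durations = [5.0, 3.0, 1.0, 4.0]
events = [1, 0, 1, 1]
print(cox_loss(log_h, durations, events))
```

A single pass after sorting costs O(n log n), versus O(n^2) for summing each risk set separately; the price is the tie approximation noted above.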

class coxkan.utils.Logger(early_stopping=False, stop_on='cindex')

Bases: object

Logger class to store training and testing metrics.

plot()

coxkan.utils.add_symbolic(name, fn, sympy_fn)

Add a symbolic function to the symbolic library.

Args:

name : str

Name of the symbolic function.

fn : callable

Function (a lambda or torch function)

sympy_fn : callable

Sympy function

Returns:

None

coxkan.utils.bootstrap_metric(metric_fn, df, N=100)

Bootstrap the confidence interval of a metric.

Args:

metric_fn : callable

Metric function that takes a DataFrame as input.

df : pd.DataFrame

DataFrame to bootstrap.

N : int

Number of bootstrap samples. The default is 100.

Returns:

results : dict

  • results[‘full’]: metric of the full dataset

  • results[‘mean’]: mean of the bootstrap samples

  • results[‘confidence_interval’]: 95% confidence interval of the metric

  • results[‘formatted’]: formatted string of the metric and confidence interval
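A minimal sketch of the bootstrap procedure, assuming a percentile bootstrap (resample rows with replacement, take the 2.5th and 97.5th percentiles of the resampled metric) and a fixed seed. Rows are a plain list rather than a DataFrame, the function name bootstrap and the formatted-string layout are hypothetical, and CoxKAN's interval method may differ.

```python
import random
import statistics


def bootstrap(metric_fn, rows, N=100, seed=0):
    """Percentile-bootstrap 95% confidence interval of metric_fn(rows)."""
    rng = random.Random(seed)  # fixed seed so the interval is reproducible
    full = metric_fn(rows)
    samples = [metric_fn(rng.choices(rows, k=len(rows))) for _ in range(N)]
    cuts = statistics.quantiles(samples, n=40)  # 39 cut points in steps of 2.5%
    lo, hi = cuts[0], cuts[-1]  # assumed percentile method: 2.5th and 97.5th percentiles
    return {
        "full": full,
        "mean": statistics.fmean(samples),
        "confidence_interval": (lo, hi),
        "formatted": f"{full:.3f} ({lo:.3f}, {hi:.3f})",  # hypothetical layout
    }


rows = list(range(1, 21))  # made-up data; the metric here is simply the mean
results = bootstrap(statistics.fmean, rows)
print(results["formatted"])
```

In practice metric_fn would wrap a model evaluation, for example a function that computes the concordance index of a fitted model on the resampled rows.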

coxkan.utils.categorical_fun(inputs, outputs, category_map)

Create a categorical (discrete) function.

Primary purpose is for creating symbolic activation functions for categorical covariates. The function accepts an array of inputs and outputs of a given (non-symbolic) activation function, as well as a dictionary mapping the encoded values of a categorical variable to the original category name. It returns a function that maps the inputs to the outputs (exactly as the activation function did), as well as a discrete Sympy function that represents the categorical mapping.

Args:

inputs : torch.Tensor

Inputs to the function.

outputs : torch.Tensor

Outputs of the function.

category_map : dict

Dictionary mapping encoded values of a category to the original category name.

Returns:

func : callable

Function that maps inputs to outputs based on the categorical mapping.

sympy_func : callable

Discrete Sympy function representing the categorical mapping.
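The numeric half of this behaviour can be sketched as a lookup table: each encoded category value is mapped to the activation output observed for it, and the category map gives a readable name for each level. This is an illustration only; the Sympy part is omitted, inputs are plain lists instead of tensors, and the function name categorical_lookup is hypothetical.

```python
def categorical_lookup(inputs, outputs, category_map):
    """Return f(code) -> output and a readable {category name: output} table."""
    table = {}
    for x, y in zip(inputs, outputs):
        table.setdefault(x, y)  # a deterministic activation gives one output per code

    def func(code):
        return table[code]  # unseen codes raise KeyError

    readable = {category_map[code]: y for code, y in table.items()}
    return func, readable


inputs = [0, 1, 2, 1, 0]  # made-up encoded category values
outputs = [-0.3, 0.1, 0.8, 0.1, -0.3]  # made-up activation outputs
f, readable = categorical_lookup(inputs, outputs, {0: "I", 1: "II", 2: "III"})
print(f(2), readable)  # 0.8 {'I': -0.3, 'II': 0.1, 'III': 0.8}
```

A lookup reproduces the activation exactly on every observed category, which a smooth symbolic fit could not guarantee for a discrete covariate.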

coxkan.utils.count_parameters(model)

Count the number of trainable parameters in a model.

coxkan.utils.set_seed(seed)

Set seed for reproducibility.

Module contents