API
Subpackages
Submodules
coxkan.CoxKAN module
Main module for CoxKAN class.
- class coxkan.CoxKAN.CoxKAN(**kwargs)
Bases: KAN
CoxKAN class
Attributes:
- act_fun: a list of KANLayer
KANLayers
- depth: int
depth of KAN
- width: list
number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
- grid: int
the number of grid intervals
- k: int
the order of piecewise polynomial
- base_fun: fun
residual function b(x). Each activation function is phi(x) = sb_scale * b(x) + sp_scale * spline(x)
- symbolic_fun: a list of Symbolic_KANLayer
Symbolic_KANLayers
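To make the grid, k and base_fun attributes concrete, here is a minimal pure-Python sketch of a single activation phi(x) = sb_scale * b(x) + sp_scale * spline(x). It assumes, as in the original KAN, that b(x) is SiLU, that k is the polynomial degree (k=3 gives cubic splines), and that the grid of `grid` intervals on [-1, 1] is extended by k knots on each side, giving grid + k B-spline coefficients. The helper names (silu, extended_grid, bspline_basis, phi) are hypothetical and are not part of the CoxKAN API.

```python
import math


def silu(x: float) -> float:
    """Residual function b(x); SiLU is assumed here (the original KAN's default base_fun)."""
    return x * 0.5 * (1.0 + math.tanh(x / 2.0))  # x * sigmoid(x), in a numerically stable form


def extended_grid(grid: int, k: int, lo: float = -1.0, hi: float = 1.0) -> list[float]:
    """Uniform knots for `grid` intervals on [lo, hi], extended by k extra knots on each side."""
    h = (hi - lo) / grid
    return [lo + (i - k) * h for i in range(grid + 2 * k + 1)]


def bspline_basis(x: float, knots: list[float], k: int) -> list[float]:
    """Cox-de Boor recursion: the len(knots) - k - 1 B-spline basis values of degree k at x."""
    # degree 0: indicator of the knot interval that contains x
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    for d in range(1, k + 1):
        # raise the degree by one, combining neighbouring lower-degree bases
        basis = [
            (x - knots[i]) / (knots[i + d] - knots[i]) * basis[i]
            + (knots[i + d + 1] - x) / (knots[i + d + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(knots) - d - 1)
        ]
    return basis


def phi(x, coef, grid, k, sb_scale=1.0, sp_scale=1.0):
    """phi(x) = sb_scale * b(x) + sp_scale * spline(x), with spline(x) = sum_i coef[i] * B_i(x)."""
    knots = extended_grid(grid, k)
    spline = sum(c * b for c, b in zip(coef, bspline_basis(x, knots, k)))
    return sb_scale * silu(x) + sp_scale * spline
```

With grid=5 and k=3 there are 8 basis functions, and inside [-1, 1] they sum to one, so setting every coefficient to 1.0 simply shifts phi up by sp_scale.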
Methods:
- __init__():
initialize a CoxKAN model
- process_data():
preprocess dataset and register metadata
- train():
train the model
- cindex():
compute concordance index
- predict():
predict the log-partial hazard
- predict_partial_hazard():
predict the partial hazard (exp of log-partial hazard)
- prune_edges():
prune edges (activation functions) of the model
- prune_nodes():
prune nodes (neurons) of the model
- fix_symbolic():
set (l,i,j) activation to be symbolic (specified by fun_name)
- plot():
plot the model
- plot_act():
plot a specific activation function
- suggest_symbolic():
find the best symbolic function for a specific activation (highest r2)
- auto_symbolic():
automatic symbolic fitting
- symbolic_formula():
obtain the symbolic formula of the full model
- symbolic_rank_terms():
calculate the standard deviation of each term in the symbolic formula
- auto_symbolic(min_r2=0, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1)
Automatic symbolic regression: uses the best suggestion from suggest_symbolic to replace each activation with a symbolic function. This method is slightly adapted from the original KAN.auto_symbolic().
Args:
- min_r2: float
minimum r2 required to accept a symbolic formula
- lib: None or a list of function names
the symbolic library
- a_range: tuple
sweeping range of a
- b_range: tuple
sweeping range of b
- verbose: int
verbosity
Returns:
- bool
True if all activations are successfully replaced by symbolic functions, False otherwise
- cindex(df, duration_col=None, event_col=None)
Compute the model’s concordance index on a dataset.
Args:
- df: pd.DataFrame
dataset
- duration_col: str
column name for duration
- event_col: str
column name for event
Returns:
- cindex: float
concordance index
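As an illustration of what the concordance index measures, here is a minimal sketch of Harrell's C-index in pure Python. It assumes that a higher log-partial hazard means higher risk (an earlier event): a pair is comparable when the sample with the shorter duration had an event, concordant when that sample also has the higher predicted log-partial hazard, and ties in prediction count as one half. The function name harrell_cindex is hypothetical; CoxKAN.cindex() may rely on a library implementation whose tie handling differs.

```python
def harrell_cindex(durations, events, log_hazards):
    """Fraction of comparable pairs whose predicted risk ordering matches the observed ordering."""
    concordant = 0.0
    comparable = 0
    n = len(durations)
    for i in range(n):
        if not events[i]:
            continue  # a censored sample cannot anchor a comparable pair
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                if log_hazards[i] > log_hazards[j]:
                    concordant += 1.0  # earlier event predicted as higher risk
                elif log_hazards[i] == log_hazards[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

For durations [1, 2, 3, 4] with all events observed, predictions [4, 3, 2, 1] give 1.0, and swapping the last two predictions gives 5/6.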
- fix_symbolic(l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False)
Set the (l,i,j) activation to be symbolic (specified by fun_name).
Args:
- l: int
layer index
- i: int
input neuron index
- j: int
output neuron index
- fun_name: str
function name
- fit_params_bool: bool
obtain the affine parameters by fitting (True) or set them to default values (False)
- a_range: tuple
sweeping range of a
- b_range: tuple
sweeping range of b
- verbose: bool
If True, more information is printed.
- random: bool
initialize the affine parameters randomly (True) or as [1,0,1,0] (False)
Returns:
- None or r2 (coefficient of determination)
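The a_range and b_range arguments suggest the affine fit used by the original KAN, where a symbolic activation has the form c * f(a*x + b) + d. The sketch below assumes that form: it sweeps a and b over a grid, fits c and d in closed form by least squares, and scores each candidate by r2 (the squared correlation between f(a*x + b) and the target, which equals the R² of the best linear fit). It is a pure-Python illustration with hypothetical names (linspace, fit_affine), not CoxKAN's implementation, and the grid density num is an assumption.

```python
import math


def linspace(lo: float, hi: float, num: int) -> list[float]:
    """Evenly spaced values from lo to hi inclusive (num >= 2)."""
    return [lo + (hi - lo) * t / (num - 1) for t in range(num)]


def fit_affine(xs, ys, f, a_range=(-10, 10), b_range=(-10, 10), num=41):
    """Fit y ~ c * f(a * x + b) + d; returns (a, b, c, d, r2) or None if nothing is usable."""
    n = len(xs)
    my = sum(ys) / n
    syy = sum((y - my) ** 2 for y in ys)
    if syy == 0.0:
        return None  # constant target: r2 is undefined
    best = None
    for a in linspace(*a_range, num):
        for b in linspace(*b_range, num):
            try:
                us = [f(a * x + b) for x in xs]
            except (ValueError, OverflowError, ZeroDivisionError):
                continue  # f is undefined somewhere for this (a, b)
            mu = sum(us) / n
            suu = sum((u - mu) ** 2 for u in us)
            if suu < 1e-12:
                continue  # f(a*x + b) is (nearly) constant; no slope to fit
            suy = sum((u - mu) * (y - my) for u, y in zip(us, ys))
            c = suy / suu                  # least-squares slope
            d = my - c * mu                # least-squares intercept
            r2 = suy * suy / (suu * syy)   # squared Pearson correlation
            if best is None or r2 > best[4]:
                best = (a, b, c, d, r2)
    return best
```

Several (a, b) pairs can fit equally well (for example, sin(-3x - 1) = -sin(3x + 1)), so only the fitted curve, not the individual parameters, should be relied on.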
- plot(show_vars=False, **kwargs)
Plot the model.
Args:
- show_vars: bool
If True, show the registered covariates on the plot. Default: False
- **kwargs
keyword arguments passed to KAN.plot()
Keyword Args:
- folder: str
the folder in which to store PNGs
- beta: float
positive number controlling the transparency of each activation: transparency = tanh(beta*l1).
- mask: bool
If True, plot with the mask (prune() must be run first to obtain it). If False (default), plot all activation functions.
- mode: str
“supervised” or “unsupervised”. If “supervised”, l1 is measured by absolute value (without subtracting the mean); if “unsupervised”, l1 is measured by standard deviation (after subtracting the mean).
- scale: float
controls the size of the diagram
- in_vars: None or list of str
the name(s) of input variables
- out_vars: None or list of str
the name(s) of output variables
- title: None or str
title
Returns:
- fig: Figure
the figure
- plot_act(l, i, j)
Plot activation function phi_(l,i,j).
Args:
- l: int
layer index
- i: int
input neuron index
- j: int
output neuron index
- plot_best_suggestion(i, j, lib=None, a_range=(-10, 10), b_range=(-10, 10), verbose=1)
Plot the best symbolic suggestion for activation function phi_(l,i,j).
Args:
- l: int
layer index
- i: int
input neuron index
- j: int
output neuron index
- lib: None or a list of function names
the symbolic library
- a_range: tuple
sweeping range of a
- b_range: tuple
sweeping range of b
- verbose: int
verbosity
Returns:
- fig: Figure
the figure
- predict(df)
Predict the log-partial hazard for all samples in a dataset.
Args:
- df: pd.DataFrame
dataset
Returns:
- log_ph: pd.Series
log-partial hazard
- predict_partial_hazard(df)
Predict the partial hazard for all samples in a dataset (exp of the log-partial hazard).
Args:
- df: pd.DataFrame
dataset
Returns:
- partial_hazard: pd.Series
partial hazard
- process_data(df_train, df_test, duration_col, event_col, covariates=None, categorical_covariates=True, normalization='minmax')
Preprocess the dataset and register metadata via the following steps:
- Encode categorical covariates via label encoding (if categorical_covariates is not None).
- Normalize covariates.
- Register metadata: duration_col, event_col, covariates, normalizer, categorical_covariates and category_maps (maps from the encoded values of each category to the original names).
Args:
- df_train: pd.DataFrame
training dataset
- df_test: pd.DataFrame
testing dataset
- duration_col: str
column name for duration
- event_col: str
column name for event
- covariates: list
list of covariates. If None, all columns except duration_col and event_col are used.
- categorical_covariates: bool or list
If True, categorical covariates are automatically detected and label encoded. If a list is provided, only the covariates in the list are label encoded.
- normalization: str
normalization method: ‘minmax’ for \((x - \min(x))/(\max(x) - \min(x))\), ‘standard’ for \((x - \mathrm{mean}(x))/\mathrm{std}(x)\), or ‘none’
Returns:
- df_train: pd.DataFrame
training dataset with processed covariates
- df_test: pd.DataFrame
testing dataset with processed covariates
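The label-encoding and ‘minmax’ steps can be sketched in pure Python as follows. Two assumptions are labeled here because the documentation above does not state them: categories are coded in sorted order, and the normalizer is fitted on the training data and then applied unchanged to the test data (so test values may fall outside [0, 1]). label_encode and fit_minmax are hypothetical helpers, not functions exported by coxkan.

```python
def label_encode(values):
    """Map each distinct category to an integer code; also return code -> original name."""
    categories = sorted(set(values))  # assumption: codes follow sorted category order
    to_code = {cat: code for code, cat in enumerate(categories)}
    category_map = {code: cat for cat, code in to_code.items()}  # like one entry of category_maps
    return [to_code[v] for v in values], category_map


def fit_minmax(train_values):
    """Fit (x - min) / (max - min) on training values; return a function that applies it."""
    lo, hi = min(train_values), max(train_values)
    span = (hi - lo) or 1.0  # guard against a constant column
    return lambda xs: [(x - lo) / span for x in xs]
```

For example, label_encode(["b", "a", "b", "c"]) yields codes [1, 0, 1, 2] with category map {0: "a", 1: "b", 2: "c"}, and a normalizer fitted on [2, 4, 6] maps a test value of 8 to 1.5. A category that appears only in the test set would raise a KeyError in this sketch.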
- prune_edges(threshold=0.02, verbose=True)
Prune edges (activation functions) of the model based on a threshold on the L1 norm of each activation.
Args:
- threshold: float
any activation with L1 norm less than this threshold will be pruned
- verbose: bool
If True, print pruned activations
Returns:
None
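The thresholding rule above can be sketched as follows. The exact definition of an activation's L1 norm is not given here; this sketch assumes it is the mean absolute value of the activation's outputs over the data, as in the original KAN. Edges are keyed by hypothetical (l, i, j) tuples, and edge_l1 and edge_mask are illustrative helpers, not CoxKAN methods.

```python
def edge_l1(outputs):
    """Assumed L1 norm of an activation: mean absolute value of its outputs over the data."""
    return sum(abs(v) for v in outputs) / len(outputs)


def edge_mask(activations, threshold=0.02, verbose=True):
    """Return {(l, i, j): keep}; an edge is pruned when its L1 norm is below threshold."""
    mask = {}
    for edge, outputs in activations.items():
        keep = edge_l1(outputs) >= threshold
        mask[edge] = keep
        if verbose and not keep:
            print(f"pruned activation {edge}")
    return mask
```

An activation whose outputs stay within ±0.01 has L1 norm 0.01, so it falls below the default threshold of 0.02 and is pruned.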
- prune_nodes(threshold=0.01, mode='auto', active_neurons_id=None)
Prune nodes (neurons) of the model based on a threshold on the L1 norm of the incoming and outgoing activations of each neuron. This method is slightly adapted from the original KAN.prune().
Args:
- threshold: float
any neuron whose incoming and outgoing activations all have L1 norm less than this threshold will be pruned
- mode: str
“auto” or “manual”. If “auto”, the threshold is used to prune nodes automatically. If “manual”, active_neurons_id must specify which neurons are kept (all others are removed).
- active_neurons_id: list of id lists
For example, [[0,1],[0,2,3]] keeps neurons 0 and 1 in the 1st hidden layer and neurons 0, 2 and 3 in the 2nd hidden layer. Pruning input and output neurons is not supported yet.
Returns:
- model2: CoxKAN
pruned model
- suggest_symbolic(l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True)
Suggest symbolic candidates for activation function phi_(l,i,j).
Args:
- l: int
layer index
- i: int
input neuron index
- j: int
output neuron index
- lib: dict
library of symbolic bases. If lib = None, the global default library will be used.
- topk: int
display the top k symbolic functions (ranked by r2)
- verbose: bool
If True, more information will be printed.
Returns:
- fun_name: str
suggested symbolic function name
- fun: fun
suggested symbolic function
- r2: float
coefficient of determination of the best suggestion
- symbolic_formula(floating_digit=None, var=None, normalizer=None, simplify=False, output_normalizer=None)
Obtain the symbolic formula.
Args:
- floating_digit: int
the number of digits to display
- var: list of str
the names of the variables (if not provided, [‘x_1’, ‘x_2’, …] is used by default)
- normalizer: [mean array (floats), variance array (floats)]
the normalization applied to inputs
- simplify: bool
If True, simplify the equation at each step. This is usually slow, so it is False by default.
- output_normalizer: [mean array (floats), variance array (floats)]
the normalization applied to outputs
Returns:
- symbolic formula: sympy function
the symbolic formula
- x0: list of sympy symbols
the list of input variables
- symbolic_rank_terms(floating_digit=5, z_score_threshold=5, normalizer=None)
Calculate the standard deviation of each term in the symbolic expression of CoxKAN.
The standard deviation of a term's output can be used as a measure of that term's importance: terms with higher standard deviation are more important. A caveat is that outliers in a term's outputs can inflate its standard deviation without making the term more important. To address this, outliers are removed iteratively based on their Z-score until none remain.
Args:
- floating_digit: int
the number of digits to display
- z_score_threshold: int
the Z-score threshold for removing outliers
- normalizer: [mean array (floats), variance array (floats)]
the normalization applied to inputs
Returns:
- terms_std: dict
dictionary of terms and their standard deviations
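The iterative outlier removal described above can be sketched with the standard library. The sketch takes each term's outputs as a plain list, repeatedly drops values whose absolute Z-score exceeds z_score_threshold, and stops when a pass removes nothing. Using the population standard deviation is an assumption, and robust_std and rank_terms are hypothetical names rather than CoxKAN internals.

```python
import statistics


def robust_std(values, z_score_threshold=5.0):
    """Standard deviation after iteratively removing points with |Z| > z_score_threshold."""
    vals = list(values)
    while len(vals) > 2:
        mu = statistics.fmean(vals)
        sd = statistics.pstdev(vals)  # assumption: population standard deviation
        if sd == 0:
            break
        kept = [v for v in vals if abs(v - mu) / sd <= z_score_threshold]
        if len(kept) == len(vals):
            break  # no outliers left
        vals = kept
    return statistics.pstdev(vals)


def rank_terms(term_outputs, z_score_threshold=5.0):
    """Map each term name to its robust standard deviation, most important term first."""
    stds = {name: robust_std(vals, z_score_threshold) for name, vals in term_outputs.items()}
    return dict(sorted(stds.items(), key=lambda kv: kv[1], reverse=True))
```

A single extreme value such as 1000 among outputs alternating between 0 and 1 has a Z-score of about 10, so it is removed and the term's standard deviation drops from about 99 to 0.5.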
- train(df_train, df_val=None, duration_col='duration', event_col='event', covariates=None, opt='Adam', lr=0.01, steps=100, batch=-1, early_stopping=False, stop_on='cindex', log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=0.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, stop_grid_update_step=50, small_mag_threshold=1e-16, small_reg_factor=1.0, metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu', progress_bar=True)
Train the model.
Args:
- df_train: pd.DataFrame
training dataset
- df_val: pd.DataFrame
validation dataset
- duration_col: str
column name for duration
- event_col: str
column name for event
- covariates: list
list of covariates. If None, all columns except duration_col and event_col are used.
- opt: str
optimizer: ‘Adam’ or ‘LBFGS’
- lr: float
learning rate
- steps: int
number of training steps
- batch: int
batch size. If -1, use all samples.
- log: int
logging frequency
- lamb: float
overall regularization strength
- lamb_l1: float
L1 regularization strength
- lamb_entropy: float
entropy regularization strength
- lamb_coef: float
spline coefficient regularization strength
- lamb_coefdiff: float
spline coefficient difference regularization strength
- update_grid: bool
If True, update the grid periodically until stop_grid_update_step
- grid_update_num: int
the number of grid updates before stop_grid_update_step
- stop_grid_update_step: int
no grid updates after this training step
- small_mag_threshold: float
threshold separating small from large numbers (a larger penalty may be applied to smaller numbers)
- small_reg_factor: float
penalty strength applied to small numbers relative to large numbers
- metrics: list
additional metrics to log
- sglr_avoid: bool
avoid nan in SGLR
- save_fig: bool
If True, save figures during training
- beta: float
beta for plotting
- save_fig_freq: int
figure-saving frequency
- img_folder: str
folder in which to save figures
- device: str
device to use (there is usually no need to change this, as GPU is typically slower)
Returns:
- log: dict
training log with the following entries:
log[‘train_loss’]: 1D array of training losses (Cox loss)
log[‘val_loss’]: 1D array of validation losses (Cox loss)
log[‘train_cindex’]: 1D array of training concordance index
log[‘val_cindex’]: 1D array of validation concordance index
log[‘reg’]: 1D array of regularization values (the regularization term in the total loss)
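To show how lamb, lamb_l1, lamb_entropy, small_mag_threshold and small_reg_factor might combine, here is a pure-Python sketch modelled on the regularizer of the original KAN trainer as we understand it: for each layer, an L1 term over the per-edge activation magnitudes plus an entropy term over their normalized distribution, with the total loss being the Cox loss plus lamb times the regularization. Whether CoxKAN computes it identically, including the base-2 logarithm and the 1e-4 offset, is an assumption. The spline-coefficient terms (lamb_coef, lamb_coefdiff) are omitted, and all names here are hypothetical.

```python
import math


def kan_style_reg(acts_scale, lamb_l1=1.0, lamb_entropy=0.0,
                  small_mag_threshold=1e-16, small_reg_factor=1.0):
    """Sum over layers of lamb_l1 * L1 + lamb_entropy * entropy of per-edge magnitudes.

    acts_scale: one list per layer, holding the (non-negative) L1 magnitude of each edge.
    """
    def nonlinear(x):
        # below the threshold, magnitudes are scaled by small_reg_factor;
        # above it, they are shifted so that both pieces meet at the threshold
        if x < small_mag_threshold:
            return x * small_reg_factor
        return x + (small_reg_factor - 1.0) * small_mag_threshold

    reg = 0.0
    for layer in acts_scale:
        total = sum(layer)
        if total == 0.0:
            continue  # all-zero layer: nothing to penalize, entropy undefined
        probs = [v / total for v in layer]
        l1 = sum(nonlinear(v) for v in layer)
        entropy = -sum(p * math.log2(p + 1e-4) for p in probs)
        reg += lamb_l1 * l1 + lamb_entropy * entropy
    return reg


def total_loss(cox_loss, acts_scale, lamb=0.0, **reg_kwargs):
    """Training objective: Cox loss plus lamb times the regularization term."""
    return cox_loss + lamb * kan_style_reg(acts_scale, **reg_kwargs)
```

With the defaults lamb=0.0 the regularizer has no effect on training; it only matters once lamb is set to a positive value, which scales all of the individual lamb_* terms at once.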
coxkan.utils module
CoxKAN Utility Functions
- coxkan.utils.FastCoxLoss(log_h, labels, eps=1e-07)
Simple and fast implementation of the Cox proportional hazards loss. Credit: https://github.com/havakv/pycox
The risk-set sums are computed with a single cumulative sum. With tied times, this may not give the exact risk set; this is a limitation, accepted in exchange for speed.
- Return type:
Tensor
Args:
- log_h: Tensor
Log-partial hazard.
- labels: Tensor
Labels tensor: the first column is time, the second column is the event indicator.
- eps: float
Small value to prevent log(0).
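The cumulative-sum trick can be sketched in pure Python, following the structure of pycox's cox_ph_loss: sort samples by duration in descending order so that a running sum of exp(log_h) over the samples seen so far approximates each sample's risk set, subtract the maximum log-hazard before exponentiating for numerical stability, and average the negative log-likelihood contributions over the events. It takes durations and events as separate lists instead of a two-column labels tensor, and fast_cox_loss is a hypothetical name. Tied durations are handled only approximately, as noted above.

```python
import math


def fast_cox_loss(log_h, durations, events, eps=1e-7):
    """Negative Cox partial log-likelihood averaged over events, via one cumulative sum."""
    order = sorted(range(len(durations)), key=lambda i: durations[i], reverse=True)
    lh = [log_h[i] for i in order]
    ev = [events[i] for i in order]
    gamma = max(lh)  # shift for numerical stability (log-sum-exp trick)
    running = 0.0
    total = 0.0
    for h, e in zip(lh, ev):
        running += math.exp(h - gamma)  # sum over samples with duration >= the current one
        log_risk = math.log(running + eps) + gamma
        total += (h - log_risk) * e  # only observed events contribute
    return -total / sum(ev)
```

For two samples with equal predicted hazards and both events observed, the loss is close to log(2)/2; giving the earlier event the higher log-partial hazard lowers the loss.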
- class coxkan.utils.Logger(early_stopping=False, stop_on='cindex')
Bases: object
Logger class to store training and testing metrics.
- coxkan.utils.add_symbolic(name, fn, sympy_fn)
Add a symbolic function to the symbolic library.
Args:
- name: str
Name of the symbolic function.
- fn: callable
Function (lambda or torch)
- sympy_fn: callable
Sympy function
Returns:
None
- coxkan.utils.bootstrap_metric(metric_fn, df, N=100)
Bootstrap the confidence interval of a metric.
Args:
- metric_fn: callable
Metric function that takes a DataFrame as input.
- df: pd.DataFrame
DataFrame to bootstrap.
- N: int
Number of bootstrap samples. The default is 100.
Returns:
- results: dict
results[‘full’]: metric of the full dataset
results[‘mean’]: mean of the bootstrap samples
results[‘confidence_interval’]: 95% confidence interval of the metric
results[‘formatted’]: formatted string of the metric and confidence interval
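A minimal sketch of the bootstrap idea, using a plain list of rows instead of a DataFrame: resample the rows with replacement N times, evaluate the metric on each resample, and take the 2.5th and 97.5th percentiles as a 95% interval. The percentile method, the fixed seed and the exact "formatted" string are assumptions for illustration; bootstrap_metric_sketch is a hypothetical name and may not match the coxkan.utils implementation.

```python
import random
import statistics


def bootstrap_metric_sketch(metric_fn, rows, N=100, seed=0):
    """Percentile-bootstrap 95% interval for metric_fn evaluated on resampled rows."""
    rng = random.Random(seed)
    full = metric_fn(rows)
    # each bootstrap sample has the same size as the data, drawn with replacement
    samples = [metric_fn([rng.choice(rows) for _ in rows]) for _ in range(N)]
    cuts = statistics.quantiles(samples, n=40)  # 39 cut points at 2.5%, 5%, ..., 97.5%
    lo, hi = cuts[0], cuts[-1]
    return {
        "full": full,
        "mean": statistics.fmean(samples),
        "confidence_interval": (lo, hi),
        "formatted": f"{full:.3f} ({lo:.3f}, {hi:.3f})",  # hypothetical format
    }
```

In practice metric_fn would be something like a concordance index evaluated on a resampled dataset; with only N=100 resamples the interval endpoints themselves are somewhat noisy.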
- coxkan.utils.categorical_fun(inputs, outputs, category_map)
Create a categorical (discrete) function.
The primary purpose is to create symbolic activation functions for categorical covariates. The function accepts arrays of inputs and outputs of a given (non-symbolic) activation function, as well as a dictionary mapping the encoded values of a categorical variable to the original category names. It returns a function that maps the inputs to the outputs (exactly as the activation function did), as well as a discrete Sympy function that represents the categorical mapping.
Args:
- inputs: torch.Tensor
Inputs to the function.
- outputs: torch.Tensor
Outputs of the function.
- category_map: dict
Dictionary mapping encoded values of a category to the original category name.
Returns:
- func: callable
Function that maps inputs to outputs based on the categorical mapping.
- sympy_func: callable
Discrete Sympy function representing the categorical mapping.
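The lookup at the core of such a categorical function can be sketched in pure Python: since a trained activation is deterministic, each encoded input value has a single output, so the observed (input, output) pairs define a lookup table. This sketch uses plain lists instead of torch tensors and returns a readable name-to-output dictionary in place of the Sympy expression, which is omitted here. categorical_lookup is a hypothetical name, not the coxkan.utils function.

```python
def categorical_lookup(inputs, outputs, category_map):
    """Build a discrete function from observed (input, output) pairs of one activation.

    Returns (func, by_name): func maps an encoded input to the activation's output;
    by_name maps each original category name to that output.
    """
    table = {}
    for x, y in zip(inputs, outputs):
        table.setdefault(x, y)  # deterministic activation: first output per input suffices

    def func(x):
        return table[x]  # raises KeyError for encoded values never observed

    by_name = {category_map[x]: y for x, y in table.items()}
    return func, by_name
```

The by_name view is what makes the fitted activation interpretable: each category of the covariate contributes a fixed amount to the log-partial hazard.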