{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CoxKAN Introductory Demo" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from coxkan import CoxKAN\n", "from sklearn.model_selection import train_test_split\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Synthetic Dataset Example\n", "\n", "The code below generates a synthetic survival dataset under the hazard function\n", "\n", "$$\\text{Hazard, } h(t, \\mathbf{x}) = 0.01 e^{\\theta(\\mathbf{x})},$$\n", "\n", "where\n", "\n", "$$\\text{Log-Partial Hazard, }\\theta (\\mathbf{x}) = \\tanh (5x_1) + \\sin (2 \\pi x_2)$$\n", "\n", "and a **uniform censoring distribution**." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Concordance index of true expression: 0.7524\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x1x2durationevent
92540.541629-0.70625142.2706691
1561-0.526259-0.49260654.2834881
1670-0.238753-0.326589361.5699031
6087-0.5880240.74202957.3352780
6669-0.739364-0.30290795.9756681
\n", "
" ], "text/plain": [ " x1 x2 duration event\n", "9254 0.541629 -0.706251 42.270669 1\n", "1561 -0.526259 -0.492606 54.283488 1\n", "1670 -0.238753 -0.326589 361.569903 1\n", "6087 -0.588024 0.742029 57.335278 0\n", "6669 -0.739364 -0.302907 95.975668 1" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "from coxkan.datasets import create_dataset\n",
    "\n",
    "\n",
    "def log_partial_hazard(x1, x2):\n",
    "    \"\"\"True log-partial hazard: theta(x) = tanh(5*x1) + sin(2*pi*x2).\"\"\"\n",
    "    return np.tanh(5*x1) + np.sin(2*np.pi*x2)\n",
    "\n",
    "\n",
    "# baseline_hazard=0.01 is the constant factor in h(t, x) = 0.01 * exp(theta(x))\n",
    "df = create_dataset(log_partial_hazard, baseline_hazard=0.01, n_samples=10000, seed=42)\n",
    "\n",
    "# hold out 20% of the rows as the test split\n",
    "df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)\n",
    "\n",
    "df_train.head()"
   ]
  },
  { "cell_type": "markdown", "metadata": {}, "source": [ "Train CoxKAN:" ] },
  { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [
    { "name": "stderr", "output_type": "stream", "text": [ "train loss: 2.77e+00 | val loss: 2.50e+00: 100%|██████████████████| 100/100 [00:06<00:00, 15.48it/s]\n" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "\n", "CoxKAN C-Index: 0.7553786667818724\n" ] },
    { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhvElEQVR4nO3deWyb5eEH8O/rI3HuxKlJkzhpmrPpkYQ20AJVKQWaoQIaKxMMTWMakzZtoxKT4B/+3h/7r0GTYIcYh4aEtBY21kqlVCBWaCk9kqZHGueoE+e+nDi249ivn98fmZ9f3h5p0ryOj34/UiReJ3aeFD/+vs+tCCEEiIiIdGSIdQGIiCj5MFyIiEh3DBciItIdw4WIiHTHcCEiIt0xXIiISHcMFyIi0h3DhYiIdMdwISIi3TFciIhIdwwXIiLSHcOFiIh0x3AhIiLdMVyIiEh3plgXgCgRCCEwPj6OmZkZZGZmIj8/H4qixLpYRHGLLReiRbjdbjQ3N6Oqqgo2mw3r16+HzWZDVVUVmpub4Xa7Y11Eorik8LAwols7duwY9u/fD5/PB2C+9RIRabWkp6fj0KFDaGpqikkZieIVw4XoFo4dO4Z9+/ZBCIFwOHzbnzMYDFAUBUeOHGHAEC3AcCG6gdvtht1uh9/vXzRYIgwGA9LS0uByuZCbmxv9AhIlAI65EN3g/fffh8/nW1KwAEA4HIbP58MHH3wQ5ZIRJQ62XIgWEEKgqqoK3d3dWE7VUBQF5eXlcDgcnEVGBIYLkcbY2BhsNtuKnp+fn69jiYgSE7vFiP5nenoaX3755Ypfg4i4iJLuUaqqorOzE62trfKru7sbqqqu6HXffPNNbNq0CVVVVfIrMzNTp1ITJQ52i9E9YXJyEhcvXkRLSwtaWlrQ1tYGn88Hg8GA6upqNDQ0oL6+HvX19Xj88cfR09OzrDEXACgpKUFzczMcDgccDgc8Hg8AoLCwENXV1TJsSkpKYDCw04CSG8OFko6qqujo6EBraytaWlrQ2toKp9MJALBaraivr5dhsnnzZqSnp2ue39zcjNdee23ZA/oHDx7EgQMHAMxPDBgeHpZB43A44HQ6EQ6HkZqaisrKSk3rJjs7W79/AKI4wHChhDcxMSFDJNIqmZ2dhdFoRG1trWyR1NfXw26333E2V7TWuQQCAXR3d8uw6ejowNTUFACgoKBA07opLS2F0Whc1r8DUTxhuFBCCYVCaG9vl+MkLS0tcLlcAIA1a9agoaFBtko2bdoEi8VyV79nuSv0jx49ir179y7rdwghMDY2ho6ODhk4169fh6qqSElJQUVFhaZ1wwWalEgYLhTXRkdHNa2Sy5cvIxAIwGQyYePGjZoursLCQl3XmCx1b7HDhw8vO1huZ25uDj09PZrWzeTkJADAZrNpWjfr1q2DycQ5ORSfGC4UN+bm5tDe3q4Jk8HBQQDz3UaREGloaEBtbS1SU1OjXia3240PPvgAb731Frq6uuTjFRUVOHDgAF5++WXk5ORE7fcLITAxMaFp3fT09CAUCsFsNqO8vFzTurFarVErC9FyMFwoZoaGhjSD7leuXMHc3BxSUlKwadMmzQyugoKCmJY18iHv8XiQlZUFq9Uas5X4wWAQ169f10wWGBsbAwDk5+ejqqoK1dXVqKysxPr162E2m2NSTrq3MVxoVQQCAVy5ckUGSWtrK4aHhwEARUVFmrGSDRs28ANxmSYnJ2U3WmdnJ7q6uhAMBmEymVBWVqbpTuNBZ7QaGC6kOyEEBgYGNK2Sq1evIhQKwWKx3NQqWcl2K3RroVAIvb29mu60kZERAEBeXp6mdVNeXo6UlJQYl5iSDcOFVmx2dhaXLl3ShEmkm6akpEQz6F5TU8NB6Bhxu93o7OzUtG4CgQCMRiPWrVunad3YbDa2bmhFGC60LEIIuFwuzaD7tWvXoKoq0tLSUFdXp1lXwgHm+KWqKvr6+jStm6GhIQBATk6OZqJARUXFqkygoOTBcKFF+Xw
+tLW1afbgmpiYAACUlZVpZnBVVlZy4V+C83g8mmnQXV1dmJ2dhcFgQGlpqaZ1U1BQwNYN3RbDhSQhBJxOp6ZV4nA4EA6HkZGRgbq6OhkmdXV1XNR3DwiHw3C5XDJsHA4HBgYGAABZWVkyaKqrq1FeXo60tLQYl5jiBcPlHjYzM4O2tjbNDK7IdiQVFRWasZLy8nK2SgjA/Pumq6tLhk1nZyd8Ph8URUFJSYkMm6qqKt0XtlLiYLjcI8LhMK5fvy53BW5tbUVnZyeEEMjKypJjJA0NDairq0NWVlasi0wJQgiB/v5+TXdaf38/hBDIyMhAZWUlqqurUV1djYqKips2CqXkxHBJUtPT07h48aJmrMTj8UBRFFRVVWlaJWVlZdwCnnTl8/k0rRuHwwGv1wtFUVBcXKxp3RQXF7N1k4QYLklAVVV0dXVppgJ3d3cDAHJzczWtks2bN/PwKlp1QggMDg5qWjd9fX0QQiAtLU22bqqqqlBZWcn3aBJguCQgt9ut2RW4ra0NXq/3poOvGhoaUFpayrtCikt+v18eQRBp4UQOWCsqKtJMFrDb7WxdJxiGS4I4fPgwvv/+e7S2tuL69esAlnbwFVGiWOyANYvFIo8g2LVrF4qKimJdXLoDhkuCcLlcEELAYrHIL5PJxFYJJTVVVTEzM4OZmRl4PB54PB6ebZMgGC4JQgjBIKF7XuTjinUh/jFciIhId9xBcAHm7MrxjjLxsR6sHOsBw+UmX375Jc6cORPrYiScBx54AHv27Il1MUgn586dw+XLl2NdjISzceNGNDY2xroYcYHhcoPz58+js7MTO3bsiHVREsZ3330Hg8HAcEki7e3tcLlc2Lx5c6yLkjAuXboEg8HAcPkfhsstbNu2Da+88kpSNW2FEPD7/RgcHMTIyAhmZ2eRmpoKq9UKm82G7Ozsu5p9JoRAKBTC9PR0lEpOsbJhwwY8++yzCVMPhBDweDxwOp2wWCxYt24dzGbzqpRfCAFVVeH1eqP+uxIFwyXJCSEwNDSETz/9FJ9//jlcLhcCgQDC4TAURYHZbEZ2djbWrVuHrVu3YseOHdiwYQOys7MBsO+YEoMQAt3d3XjnnXfQ398Po9GIxsZG/OIXv+A+eTHCcEliwWAQR48exZ/+9Ce5TfpCQggEAgGMjo5idHQUZ8+exbvvvovi4mI89NBDePLJJ7Fp0ya5FQeDhuLVzMwM/vrXv6Kvrw/A/Eatp0+fRlZWFl5++WXu6B0DDJckFQgE8M477+C9997D3Nyc5nsmkwlGoxHhcBihUEgzOygUCsHpdMLpdOKf//wnSkpKsGvXLjz++OOoqalBRkYGAAYNxQ8hBL799ls4nc6bHv/666/xyCOPoLq6mu/ZVcZwSULBYBB//vOf8e677yIUCgEADAYDNm3ahKefflq2RmZnZzE4OIirV6/iwoUL6OjowPT0tAybUCiEnp4e9PT04KOPPoLdbsfWrVvR2Ngoz1nPycmJ5Z9KhEAggK+//lq+b9PT0xEMBhEMBjE7O4svvvgCVVVVDJdVxnBJMkII/Pvf/8bf//53GSzp6en41a9+hRdffBEZGRmaSrZlyxY8+eSTCIVCGB4extmzZ/HFF1+gpaUFbrdbVthgMCiD5vDhw0hJSUFWVhZ+//vfx+TvJIro6+tDb2+vvH7qqafgcrnw3XffAQBaW1sxNjaG++67L1ZFvCcxXJKIEALXrl1Dc3Oz7ApLT0/Hm2++iWeeeea2/c6RgX273Y7i4mI8/fTTGBoawqlTp3D8+HG0tbXB4/HIoImM1QSDQWRnZ2NwcHDV/kaKD5HZUSMjIxgbG0NGRgaKi4uRmpq6qi0EIQRaWloQDAYBABaLBTt27MD4+DjOnj0LVVXh8XjQ2tqKJ554gq2XVcRwSSKzs7Nobm7G+Pg4gPmxlV//+teLBsuNFEWByWSC3W7H888/jx/+8IcYGBjAmTNn8M033+Dq1asYGxtDIBBAamoq7HY7rl27Fs0
/i+KMEAIjIyP4+OOPceHCBczOzsJkMqGsrAwvvvgiamtrV217/FAohEuXLsnrkpISFBQUwGq1oqCgAAMDAxBC4OzZs3jsscdgMvEjb7XwXzpJCCFw4sQJfPvtt/Kx3bt346WXXrrrmTKRFs26detQWlqKH/3oR5iZmcHg4CB6e3sxPDyMwsJCvf4ESgBCCPT29uKtt95Cf3+/fDwYDMLhcODgwYP47W9/i/r6+lVpJUxOTmrKsXHjRpjNZpjNZmzZskXOkuzp6YHb7caaNWuiXiaax9N3ksT09LRmAD8/Px+vvvoqLBaLLq+vKAqMRiNycnKwYcMG7N27Fz/96U/l7DFKfkIITExM4O2339Z8oC+8efF4PHj33XcxMjKyKnuU9fb2yoWLRqMRGzduBDD/fm1oaJAtKI/Hg87OTu6btooYLklACIHjx4/D4XAAmK9YL7zwAsrLy6N698j+63uLqqr4+OOP5WF1AFBTU4M33ngDu3btku+HkZERfPLJJwiHw1EtjxACHR0dMjAyMjJQUlIiy7F+/Xp57osQQtN9RtHHcEkCXq8XH3/8sazMxcXF+PGPf8xjYUk3QgicP38ep06dko+tX78eBw4cQF1dHV5++WVs2LBBfu/06dO4fv16VFsKqqqiq6tLXhcVFcmdJQAgKysLZWVl8trhcCAQCEStPKTFT58EJ4TA6dOn0dHRIR977rnnYLPZYlgqSjZerxeHDh2Ss7IyMjLw85//HFarFYqiID09Hc8//zxSU1MBzE8uOX78eFTDxefzaWYqlpeXa7roDAYDamtr5fXw8LCc7ELRx3BJcKFQCJ988olmrOWZZ55hlxXpRgiBb775Rq4lURQFTzzxhGZhoqIoqKmp0eyifP78eYyNjUWtXCMjI/B4PPL3l5eXa76vKAqqq6vlDLFAIICenh6Ou6wShkuCczqdOHv2rLzes2cPZ3CRrjweD44dOyY/lO+77z40NTXd1O1qNBrx+OOPy9aDx+PBhQsXovJhLoRAX1+fbEmZzWbNeEtEYWGh7CqLjNHQ6mC4JLDIQP7MzAwAIDU1la0W0t2ZM2dk95OiKNi7dy/y8vJu+jlFUbBhwwZ5cyOEwHfffSdb1XpbOLEgJycH+fn5N/1MZJA/oru7O2rlIS2GSwLz+/04ceKEvK6pqcHGjRsZLqSb2dlZnDhxQtNqeeSRR277HktLS8P9998vr3t6ejAyMqJ7uVRV1Wz5UlBQgLS0tJt+zmAwoKKiQl4PDw9jampK9/LQzRguCUoIgfb2ds1smSeffFK3dS1EANDW1qb5EH/00UfvuFnptm3bYDabAczfAF25ckX3rjG/34/R0VF5XVpaesvZkYqioLKyUoah1+vldkWrhOGSwL788ku5h1hmZiYeffRRtlpIN0IIdHZ2yg/tnJwc7Ny5c9H3mKIoKC0t1cxWvHjxou7hMjk5KQfzgflwuZ3i4mLZqgmHw+ju7uag/irg9i8Jyufz4eTJk/J648aNi1YwouVSFAXPP/887r//fpw4cQJ5eXlL2j7FYrGgpqZGbr3S3d0Nj8ej6/EMg4OD8sbKZDKhuLj4tqGXm5sLm80mz3uJhAtvxKKL4ZKgHA6H5nCk3bt3c1M+0p3ZbMaGDRtQU1ODUCi0pIW5iqJgy5Yt+OqrryCEgNvthsvl0i1chBBwuVya81tuNZi/8G8oKSmR9SVy1PetxmhIP+wWS0BCCJw8eVKuNs7IyFh0kJVoJRRFgcFgQEpKypJ/vry8XH54q6qKa9eu6dYVFQmXiLy8PHkU9+3Ks3BQf2JiAm63W5ey0O0xXBLQ3NwcvvnmG3ldVVWlmW5JFGtWqxVr166V1x0dHbrtNRYKhTA0NCSvCwsLF221K4qCdevWyfU3s7Ozmo03KToYLgnI5XJpZont3LlzyXeVRKvBbDZrWgt9fX1y9+KV8nq9mJiYkNeLjbdEFBQUID09HcB8yyfa+54RwyXhCCFw5swZWVFTUlLw0EM
PsUuM4k51dbV8X05PT2N4eFiX152cnITP5wMw3yqx2+13fP9nZWVpJiMwXKKP4ZJgVFXVdIkVFxejsrIyhiUiupmiKCgrK5Mt6mAwqNu+XkNDQ3KVvclk0nS/3U5kUD9iYGCAOyRHGcMlwUxOTuLy5cvyurGxkQd2UVzKz8+H1WqV152dnSt+TSGEPLoYmJ8pdqutaG5l/fr18r8nJyc5qB9lDJcEIoTAlStX5LbhBoMBDz/8cIxLRXRrFosFdrtdXvf29sq1KXdLCKEZjM/Ly5NjKYuJLO5cOKjPlfrRxXBJMKdPn4aqqgDmF4dt2bKF4y0Ul26cAjw6Orrifb1CoZBm7KagoGDJ67sW7j8mhIDT6eS4SxQxXBJIIBDAuXPn5HV1dfWSVkwTxYKiKFi/fr1ceOn3+zVTiO+G3+/H5OSkvC4qKlryzVV2drZmsSXDJboYLgmkv79fs8349u3buSqf4lphYaFmX6+VDuq73W7NlOaioqIlP9dkMmm66fr7+1fcTUe3x3BJEEIItLa2yimYZrMZjY2N7BKjuHbjOSsrDZfR0VF5QJjRaERBQcGS60BkBlvExMQEpqen77ostDiGS4IQQuD06dOyYq5du1bTn00Uj8xms6a1MDAwcNetBSEEBgcHZR1IS0vTzEa7k8hKfT276ej2GC4JwuPx4OLFi/K6rq4OWVlZMSwR0Z3d2FoYHx9fUWshstMyMD+GstieYrdSUFAgzzwKh8Po7e3luEuUMFwSRFdXl7zLUhSFq/IpIejZWlBVVfNcm8227G2PcnJyNK0drtSPHoZLAhBC4Pvvv5d9zenp6aivr2e4UEJYu3atLq2FQCAg13gB85MFllsHUlJSNJMAXC6XXO1P+mK4JIBQKIQzZ87I67KysmXNkiGKJb1aCx6PR3P65N3UgRu76cbGxjioHyUMlwQwNjaGjo4Oed3Y2IjU1NQYloho6W5sLfT3999Va2FsbEzuB2YwGLB27dplt1wi3XSR5/n9ft021CQthkuci2z5EtkHyWg0Yvv27bEtFNEyRD7QI8bHxzUtkKUQQmBoaEieCZOSknLXC4iLiopkN52qqujr6+O4SxQwXBLA6dOnZaXKy8tDbW0tx1soYdzYWvB6vRgZGVn26yycKZaVlYXs7Oy7Kk9ubq5ms8uenp67eh1aHMMlzvn9fpw/f15e19bWLmtuP1E8KCwslF25qqrC5XItq7UQablEWK1W2fpYrtTUVE03XV9fn5wsQ/phuMQ5l8sFp9Mpr7dv3y53diVKFHl5ecjJyZHXC7cxWoq5uTmMjo7K67Vr1951PYjseRYxOjq67G46ujOGSxwTQuD8+fPw+/0A5u+4uOULJSKLxaI51Kuvr0/u7r0UXq9Xc/7KSmZLRmaMReqRz+fjSv0oYLjEsXA4jNOnT8vroqIizTRKokQROU8lYnR0VLMB5Z1MTk7KmyxFUe5qjctCNw7qczGl/hgucWxqagqXLl2S1w0NDTx1khLSja2FmZkZjI2NLem5QggMDw/Llo7JZMJ99923ovLk5eVpNtTs7u5muOiM4RKnhBBwOByyn5lbvlCiKy4uhtlsBgAEg0H09/cv+QN94dHGGRkZyM3NXVFZUlJSNBtq9vX1cft9nTFc4tipU6fkYrOsrCyeOkkJzWq1ajZbXThRZTF3e7TxYm48JXNsbExzCBmtHMMlTs3NzWm2fKmurtYMiBIlmvT0dE13Vm9vr1y/tZiVHG18O4qioLy8XM448/v9XEypM4ZLnHK5XOjq6pLXO3bskF0KRInIaDRqBvWHh4flIP1ifD7fXR9tvJjCwkK5Zb8QAp2dnSt+Tfp/DJc4JITAuXPn5Gwas9mM7du3s0uMEt7CbWCmpqYwMTFxx+dMTk7KuqAoCoqLi3WpC9nZ2SgsLJTXnZ2dy5oeTYtjuMShcDiMkydPyiZ6cXExKisrY1wqopVRFAUlJSWyS2tubu6Og/qRmWKRsUeTyaRb97DRaNT
Uq/7+fi6m1BHDJQ6Nj49rTp3ctm3bsk/cI4pHNptN0xW1lH29FgbQco82vpPq6mrN9Oi+vj7dXvtex3CJM0IIXLx4UR6KZDAYsGvXLnaJUVLIzMzUDOo7nc5FB/WFEJoPfKvVqttar8jam8jMM1VV0d7ezkF9nTBc4owQAl999ZWscGvWrEFdXR3DhZLCjYP6g4ODiw7qh0IhDA4OyuvCwsIVzxRbKC8vD8XFxfL6ypUrHHfRCcMlzkxNTeH777+X1w0NDZqVxESJrry8XP632+3WHF18o5mZGc1MsZKSEl1vtEwmE2pra+W1y+VatDy0dAyXOCKEQEtLi7xTUxQFjz32GAwG/m+i5BA522XhSv3e3t7bdkWNjo7C5/PJ59rtdt1b8Zs3b5brXbxeL65du7akrjEhBLvQFsFPrTgihMAXX3whm+VWqxUPPPAAu8QoqdhsNrlSXwihWc+1UGRlfmSmWEpKimbqsB4iYReZJBDZifxOoRFZF9PZ2XlXRzbfC/TrvKQVm5iY0OyCvHXr1hVv0EcUb9LT01FcXCzXuPT09EBV1VuOpSzcIiY7O1tzgqReMjMzUVtbK/fxa29vh9vtXnRW2tzcHD766CN0d3djy5Yt2LdvH1sxN2DLJU4IIXDq1Cl5/KvBYEBTUxO7xCjpGI1GzbjL0NAQpqenb/o5VVXR29srrwsKCpCWlqZ7eRRFwQMPPCDr2tTUFNra2m4bFkIIXL58GQ6HA4FAAGfPntWcFkvz+MkVJ0KhED777DM5S6ywsBAPPvggu8Qo6SiKgsrKSs36koGBgZt+zufzafYUKy0tjcrNlqIoqKmpwZo1awDMh8fJkydv290VDAZx9OhR+f3MzEw89thjrKs3YLjEASEErl27hgsXLsjHdu/eHZUuAKJ4UFpaKlshqqrC4XDc1FJYePzwjUcT6y0zMxONjY3yuqOj45ZnvAghcOHCBVy9elU+9tBDD3FT2VtguMQBIQQ++eQTOSsmLS0NzzzzDO+EKGnl5eWhoKBAXnd0dGgWUwoh4HQ6EQwGAczvr6f3NOSFFEXBzp075emUgUAAx44du2nNi8fjweHDh+XjWVlZaGpqYl29BYZLHHA6nfj888/l9bZt21BTU8M3LCUts9msOU/F6XRiZmZG8zMLdynOzc2V3VbREDmGua6uTj527tw5XLlyRbZeQqEQPv30U8040J49e1BUVBS1ciUyhkuMqaqKjz76SM6cMZlMePHFF7m9PiW9jRs3yhuoqakpzYd2MBjU7DtWUlKy4gPC7sRoNGLfvn2a1suHH36IoaEhhEIhHD9+HMePH5dhY7fb8dRTT3HSzW1wKnIMCSFw6dIlfPbZZ/Kx+++/Hzt27GCrhZJa5CTI9PR0eL1eqKqKS5cuYfPmzVAUBePj45rB/KqqqqjXichEg0cffRTHjh0DMH/88R//+EfYbDa0t7fLbjqLxYKXXnoJOTk5US1TImPkxtDMzAwOHjwoBy1TU1PxyiuvyDsnomRmtVpRUlIiry9duoRgMCgXVkb2HDMajZrdi6PJYDDgueeeQ1VVlXxsaGgIbW1tMlhMJhP279+PhoYG3gQuguESI8FgEH/5y19w9uxZ+dgTTzzBVgvdM0wmk2aMw+VyYWBgAEIIzTqT3Nxc2O32VSmToijIycnBb37zm1uOe6anp+MnP/kJu8OWgN1iq0wIgbm5Obz33nv48MMP5QwZu92O3/3ud7ru+EoUzxRFQX19Pf71r38hEAjIBYn5+fmaqb4VFRWrep6RoihYu3YtXn/9dfz3v/9FS0sLZmdnUVpait27d2P9+vUMliXgJ9kKLGW7B0VR5M+pqgqn04m//e1vmkVY6enpeOONN6I61ZIoHpWUlKC0tBQOhwMAcOrUKeTn52NsbAzAfP3ZunXrqtcLRVGQmZmJH/zgB2hqakI4HJabW7KOLg3D5TZmZ2cxPj6Oubk5+eX3++Hz+eSX1+uF1+vF7OwsgsGg5ijW1NR
UpKeny/GTyclJtLe34+LFi5iampK/x2Kx4LXXXsPu3bv5pqV7jtlsxs6dO9HZ2QkhBAYGBvCPf/xDtuhzcnLkIH8sKIoCRVHYUrkLDJfbOHfuHF5//XUEg0GEw2GoqgohBMLhsG4b1OXk5OC1117Dc889J++KiO4liqLgwQcfxJEjRzAyMgIhhGa9S2Njo67HGtPqYRzfhsFggNfrhd/vRyAQQCgUkgGzUikpKXj44Yfx9ttvY//+/RxnoXtabm4unn322ZtusHJycrj6PYHxU+02LBaLbBIv/DIajTAajTCbzTCbzUhJSYHZbIbJZJJNZ1VVEQwG5Vfk9Ww2GzZv3ow9e/agvr4eqamprDh0z1MUBbt27cLIyAg+//xzBAIB5OTk4Gc/+1lUDgej1cFwuY2ysjL84Q9/QEpKClJTU2GxWJCWlgaLxQKLxQKz2YzU1FSYTCYYjUYYDAZZCYQQUFUVqqrKcRiz2Yy0tDTZSmGFIfp/ZrMZL7zwAh5++GFMTEzAbrfDZrOxniQwhsstXLhwIepbTSSTCxcuaPaJouTQ0dGBo0ePxuR3R476TiTXrl1btfU4iYDhcoOGhgb4/X7NPke0OLvdjoaGhlgXg3RUXV2Nubk5zRYstLiCggJUV1fHuhhxQxE8m1PiP8XKsRsj8bEerBzrAcOFiIiigFORiYhIdwyXBBEOhzE7O6s5rY/oXqOqKmZmZm46IZLiD8MlQbS3t2Pr1q1ob2+PdVGIYqa3txe//OUvOeEmATBciIhIdwwXIiLSHcOFiIh0x3AhIiLdMVyIiEh3DBciItIdw4WIiHTHcCEiIt0xXIiISHcMFyIi0h3DhYiIdMdwISIi3TFcEoAQApOTkwgGg5icnORhTnRPEkJgYmICXq8XExMTrAdxjuESx9xuN5qbm1FVVYWdO3eis7MTO3fuRFVVFZqbm+F2u2NdRKKoW1gPGhsb8Z///AeNjY2sB3GOJ1HGqWPHjmH//v3w+XwAtEfPRo5QTU9Px6FDh9DU1BSTMhJFG+tB4mK4xKFjx45h3759EEIsejiYwWCAoig4cuQIKxYlHdaDxMZwiTNutxt2ux1+v39Jp04aDAakpaXB5XIhNzc3+gUkWgWsB4mPYy5x5v3334fP51vyccbhcBg+nw8ffPBBlEtGtHpYDxIfWy5xRAiBqqoqdHd3L2smjKIoKC8vh8PhkP3QRImK9SA5MFziyNjYGGw224qen5+fr2OJiFYf60FyYLdYHJmZmVnR8z0ej04lIYod1oPkwHCJI5mZmSt6flZWlk4lIYod1oPkwHCJI/n5+aioqFh2f7GiKKioqIDVao1SyYhWD+tBcmC4xBFFUfDqq6/e1XMPHDjAQUxKCqwHyYED+nGG8/uJWA+SAVsucSY3NxeHDh2CoigwGBb/3xNZmXz48GFWKEoqrAeJj+ESh5qamnDkyBGkpaVBUZSbmvmRx9LS0nD06FHs3bs3RiUlih7Wg8TGcIlTTU1NcLlcOHjwIMrLyzXfKy8vx8GDB9Hf388KRUmN9SBxccwlAUTOsfB4PMjKyoLVauWgJd1zWA8SC8OFiIh0x24xIiLSHcOFiIh0x3AhIiLdMVyIiEh3DBciItIdw4WIiHTHcCEiIt0xXIiISHcMFyIi0h3DhYiIdMdwISIi3TFciIhIdwwXIiLS3f8B0j6zxDTvWDoAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "# width=[2, 1]: two covariates (x1, x2) in, a single output node\n",
    "ckan = CoxKAN(width=[2, 1], grid=5, seed=42)\n",
    "\n",
    "# settings for the initial fit, gathered in one place\n",
    "adam_fit = dict(\n",
    "    duration_col='duration',\n",
    "    event_col='event',\n",
    "    opt='Adam',\n",
    "    lr=0.01,\n",
    "    steps=100,\n",
    ")\n",
    "_ = ckan.train(df_train, df_test, **adam_fit)\n",
    "\n",
    "# evaluate CoxKAN on the test split\n",
    "cindex = ckan.cindex(df_test)\n",
    "print(\"\\nCoxKAN C-Index: \", cindex)\n",
    "\n",
    "# plot CoxKAN\n",
    "fig = ckan.plot()"
   ]
  },
  { "cell_type": "markdown", "metadata": {}, "source": [ "Symbolic Fitting:" ] },
  { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [
    { "name": "stderr", "output_type": "stream", "text": [ "train loss: 2.77e+00 | val loss: 2.50e+00: 100%|████████████████████| 10/10 [00:01<00:00, 8.19it/s]\n" ] },
    { "data": { "text/latex": [ "$\\displaystyle - 1.0 \\sin{\\left(6.3 x_{2} + 9.4 \\right)} + 1.0 \\tanh{\\left(4.4 x_{1} \\right)}$" ], "text/plain": [ "-1.0*sin(6.3*x2 + 9.4) + 1.0*tanh(4.4*x1)" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "# function names handed to auto_symbolic as its `lib` argument\n",
    "symbolic_lib = ['x^2', 'sin', 'exp', 'log', 'sqrt', 'tanh']\n",
    "\n",
    "# auto-symbolic fitting\n",
    "_ = ckan.auto_symbolic(lib=symbolic_lib, verbose=False)\n",
    "\n",
    "# train affine parameters\n",
    "lbfgs_fit = dict(duration_col='duration', event_col='event', opt='LBFGS', steps=10)\n",
    "_ = ckan.train(df_train, df_test, **lbfgs_fit)\n",
    "\n",
    "# show the fitted expression, with coefficients printed to one decimal place\n",
    "formula = ckan.symbolic_formula(floating_digit=1)[0][0]\n",
    "display(formula)"
   ]
  },
  { "cell_type": "markdown", "metadata": {}, "source": [ "We see CoxKAN approximately recovers the true log-partial hazard:\n", "\n", "$\\hat{\\theta}_{KAN} = \\tanh(4.4 x_1) -\\sin(6.3 x_2 + 9.4) \\approx \\tanh(5 x_1) -\\sin(2 \\pi x_2 + 3 \\pi) = \\tanh(5 x_1)+ \\sin(2 \\pi x_2)$" ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "### Real dataset example" ] },
  { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using default train-test split (used in DeepSurv paper).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "train loss: 
2.60e+00 | val loss: 2.41e+00: 100%|██████████████████| 100/100 [00:01<00:00, 55.04it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "CoxKAN C-Index: 0.6798424912829145\n" ] }, { "data": { "text/latex": [ "$\\displaystyle \\begin{cases} 0.06 & \\text{for}\\: meno = 0 \\\\0.22 & \\text{for}\\: meno = 1.0 \\\\\\text{NaN} & \\text{otherwise} \\end{cases} + \\begin{cases} 0.23 & \\text{for}\\: hormon = 0 \\\\-0.11 & \\text{for}\\: hormon = 1.0 \\\\\\text{NaN} & \\text{otherwise} \\end{cases} + \\begin{cases} -0.16 & \\text{for}\\: size = 0 \\\\0.09 & \\text{for}\\: size = 1.0 \\\\0.36 & \\text{for}\\: size = 2.0 \\\\\\text{NaN} & \\text{otherwise} \\end{cases} + 0.759 - 1.16 e^{- 0.03 \\left(- nodes - 0.58\\right)^{2}} - 0.32 e^{- 5.59 \\left(0.02 age - 1\\right)^{2}}$" ], "text/plain": [ "Piecewise((0.06, Eq(meno, 0)), (0.22, Eq(meno, 1.0)), (nan, True)) + Piecewise((0.23, Eq(hormon, 0)), (-0.11, Eq(hormon, 1.0)), (nan, True)) + Piecewise((-0.16, Eq(size, 0)), (0.09, Eq(size, 1.0)), (0.36, Eq(size, 2.0)), (nan, True)) + 0.759 - 1.16*exp(-0.03*(-nodes - 0.58)**2) - 0.32*exp(-5.59*(0.02*age - 1)**2)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApBklEQVR4nO2deXRUVZ7Hv5VKDFsTEpZEgaACymIAEVBwASUKOnbPTHePY49td48ijr3ZiivSiKAgsgjaOt06bdvaxx489mI7TaIo4IILm7KILSAQIJCEbGRPparu/PHl+qpCJalUvVfvvcrvc05OYaxU3Vuv3v3e33o9SikFQRAEQTCRFLsHIAiCICQfIi6CIAiC6Yi4CIIgCKYj4iIIgiCYjoiLIAiCYDoiLoIgCILpiLgIgiAIpiPiIgiCIJiOiIsgCIJgOiIugiAIgumIuAiCIAimI+IiCIIgmI6IiyAIgmA6Ii6CIAiC6Yi4CIIgCKYj4iIIgiCYTqrdAxAEt1BRUYGqqipkZmaib9++dg9HEByNWC6C0AEFBQXIz89Hv379MHz4cPTr1w/5+fkoLCy0e2iC4Fg8csyxILTNkiVLMHfuXHi9XgQCga9/r/97yZIleOCBB2wcoSA4ExEXQWiDgoICXHfddVE9b+bMmQkYkSC4B3GLCUIbLF++HF6v9+v/7hbhOV6vFytWrEjcoATBJYjlIggRqKioQL9+/XAGgKkAfgzgWwBOAPg5gPcAlIQ8v7y8XIL8ghCCZIsJQihKAXv3IrBmDV4HcBmAnuCN4gfwDQBPAjgO4CCATQA+AVB97JiIiyCEIJaLIFRVAe+9B6xfD2zYABw+DOX3o7i+Hg0A+gP4AMBsANcDuB1AHYAvAZwNIAPARZdeirSLLgImTuTP8OFAinidha6LiIvQ9fD7gW3bKCQbNgCffgoEAkD//kC3bsA3vgFMnoyCP/wBY48fxx8APATAC2AIaMksBi2ZJz0eXDxmDBbfeCNw4ADw1Vd8/T59gAkTDLERq0boYoi4CF2Dw4dpmWzcSCultpYCkJcHpKcDNTVARgYwbRowZQqwZAka33oLPy4vx4unXiIdFJci0FpZDiAXQLclS3DxLbcAJSXAyZNAcTHfb+9eYP9+utqGDqXITJrE9zzjDBs+BEFIHCIuQnJSXw988IFhnRw4AHi9tCYuvRTo0YOL//HjwODBwLXXAldcAZSVAT/4AR+few5LNm36us4lNRD4Wlz8Xi/SAgGsnzoVkxsagDvuAH70I75vSQlQUUFRSUsDjh0Ddu8Gtmzh79PTgbFjKTSTJgG5uYDHY/MHJgjmIuIiJAfBIBdwHTfZsgVoaeHCfdVVwJVXAoMGUXDee49usEsuoaiMHMnF/YMPgNtuowvrpZeAc88FABQWFmLFihV4/+23vxaXy/PzMWfOHMy85hrgt78Fnn8emDEDmDeP4uH3U6BKS4GmJqBnT2DAAFpM27ZxfDt2cIwDBhhWzfjxQO/etn6UgmAGIi6CeyktpZtrwwY+VlRwEb/8crq3rrqKgrJpE1BYSEulb1/gmmuAq6+mW0zz4ovAL38JXHYZ8JvfRFzgK4qLUff55+g1ejT6DhwY/j/feQdYsICCtHw54zea6mpaM1VVtJ769wdycihoO3dSaDZvBg4dYhLAiBFGrGbUKP6NILgMERfBPTQ3Ax9/TCFZvx7Ys4cL9JgxtEyuugq46CLGM0pLgTff5KJfW0s31LXX8v+HLtYtLcD8+cDvfw/MmsV/p7aRod/UBBQVAUOGMPDfmn/8A7jnHrrDli2jMLQef2kpLZqWFgpYTg6QlcV5nDhhCM22bYwD9ezJMWvLJifHtI9TEKxExEVwLkoB+/bRMlm/HvjwQy7w2dkUkyuvBKZONTKxgkFg+3ZaKZ9+yrjK9Ol0V5155umvX10NzJ5NwVqyBLjppvbH05G4AEB5OXDffbSS5s+nldSaYBCorKQ1U1tLMRwwgPPSgf5gEPjySwr
Nli3A55/zd4MGUWQmTgQuvBDo3j3qj1MQEomIi+AsqqqA9983MruKi7ngTpliuLpGjAgPgJ88SQvlzTe5+x86lFbKpZcy/hGJffuAH/6Qf/s//wNMntzx2KIRFwDw+YDHHgMKCoBbbqGAtVXz0tBAkTlxgmKamUnrJCMj/Hn19RRObdkcP04L64ILDLEZNkxqawTHIOIi2Ivfz0UztOYkGATOP9+wTiZPPn2HrhR39oWFtGg8HsZaZs7kItse69czu2vgQMZacnOjG2u04qLH9/LLwDPP0LpasICWVFsEAhSYkhKgsZHzzc5mfKa1m04piq4Wmu3bObbMzPDamqys6OYlCBYg4iIkniNHwmtOamoYXJ86lZbJ1Klc+CPR1MS/KSxkADwnh4Jy5ZUsfmwPpYDnngMWLaK77JlngF69oh93Z8RF8/77TBQYOJCB/kjuudbU1FBkKispmv36cZ49e0Z+fksL3WZabPbu5e+HDjWsmjFjmBYtCAlCxEWwnvp6Zmxp6+Srr4yaE22djB3bflbU0aMUlI0bubOfMIGur7Fjo6sR8fmA++8H1qwBfvpT/ruzWVixiAvA+d59N/9+2TIu9NHg8xnpzD4fxTM7mzGm9txf1dXA1q2G2FRW0j04bpwhNlJbI1iMiItgPsEgd9K65mTz5vCak2nT6MLqqJ4jEODfFhSwhqV3b6YQX3NNeKpvR5SXA7feCnz2GbBiBfDd78Y2r1jFBWAs6YEHOI8HHwSuvz76v1WKf687AKSmMgEgJ6ftmFLo3x44YAjNzp1GbY0Wmosu6tjqE4ROIuIimENZWXjNSXk5YwyXX25YJ+ecE91uubISWLcOeOstLqojR9L1dcklnXft7NnDwL3PB7zwAhfSWIlHXAAu6k88Abz+OjPTfvazzgfgm5qMBAC/n3GW7Gy6FaP5bJuaKDCbN/OnqCi8tmbSJH7eUlsjxImIixAbPh9TeLWr6/PP+fsxY4yK+AkTou+hpRR39YWFwCef8O+mTqWoDBkS2xgLC+kCO/dcBu7POiu219HEKy4A57lmDbBqFRMVFi3qXNxHEwxSwEtK6HZMT6fIDBjQOQEuKwuvramtNWprdHua7OzOj0/o8oi4CNGhFJswajHZtImxjwEDDMtk2rTOd/+tr+frvfkmM6AGD6agTJ3afnZVR2N9+mng8cfpflq1KvbXCsUMcdF8/DEwdy6D9StXsn4lVurqwvuZ9etHQeisqysYZCHoli3htTWDBxtWzbhxUlsjRIWIi9A21dXhNSdHj9KimDzZqDnRfbk6y8GDjKW8/z7dO5dcQlEZNSq+QHNTE4Pnf/0rMGcOcNdd5tV+mCkuAF/r7rsZR1m6ND6XHRC5n5lOZ47lM6irC6+tKSlhvCcvz4jXDB0qtTVCRERcBAO/n3Um2jrZvp071/POM6yTKVNi37m2tLAmpaCA6bJZWayez89n7CBeSkuB//xP7r5Xrwa++c34XzMUs8UFoBvqwQfpkrr3XuDb3zbndVv3M9MdAGK9dkpxc6Gtmta1NVpszLiOQlIg4tLVOXLEEJP33+cuuk8ftp/XmV1t1ZxES2kpg/Nvv83FdMwYphFPmGBe4HjHDra893jYJywvz5zXDcUKcQGYFffkk8CrrwL/9m+0ttrqb9ZZWvczy8igyOh+ZrGia2t0expdWzNsWPi5NVJb02URcelqNDQwXqLThHXNyUUXGdbJuHHxL/rBIK2gggI+du9OsZoxI36xas3rrwO/+AUwejTb31sVgLZKXDR//jPrYMaPZ68zM1vvR+pnphMAzDi4rKrKqK3ZsoXv1a1beG3N4MFSW9OFEHFJdnTNibZOPvmEu87Bg8NrTlr3soqVmhqjz1dZGdOPr72WrezNXpCDQVa9r1oFfOc7/HdHdR/xYLW4AHSP3X8/r8eKFcDZZ5v/HvX1tGZ0P7OsLAqNWd8Bpbhp0UKja2uys43WNFJbk/SIuCQjJ04Y9SYbNhg1J5ddZlg
n555r3i5SKbpFCgtpFXk8bBo5cyYwfLg1u9WGBuDnP6dlNHcu8OMfW78rToS4AMyau/tuXsfHHouuqWYsROpnlpPDBAAz61yamui21LU1hw8zCWDkSMOFNmKE1NYkGSIuyYDPR4tEWye7d/P3eXlGzcnEieaf297UxDhNYSGzv7Kz6fa66iprT1MsLmZhZFER+4NFamtvBYkSF4DWxbx5wEcfAXfeCdx4o7Xi2bqfWf/+vJ5t9TOLh9JSw6rZupVZab16GbU1EydKbU0SIOLiRrTbQZ9zomtO+vcPrznp18+a9y8upqBs2GD0+Zo5k/51q62HrVvZxr57dxZGjhxp7fuFkkhxAej2e+YZdlf+1rfoLrM6QB5rP7NY0bU1OjFgzx7+Ljc3vLYmEZ+3YCoiLm6huppnvOtAvK45ueQSo+Yk3hqR9tB9vgoLgV27aJnk59NqGDDAmvdszZo1PIhr/HiewdLZgs14SbS4aP7+d2DxYiYsLF2amHTfSP3MsrP5Y2Vcq66OcSddW1NaSkFtXVsjiQGOR8TFqeiaEx032baNO7rhw40jfSdPNqfyvD10n6916/jvESNopUyenLg000CAi+t//zfwH//BTCo7UlztEheAQfF77+X7rljR8Zk1ZtLYaKQzBwKd72cWK6G1NZs3837QtTU6MUBqaxyLiIuTOHrUcHXpmpOMjPCak3jahESLUswwKygw+nxdcQVFxYrspfaorWWwfsMGHrh166327VrtFBeAVsQ99/B7snAhr0kiidTPLCeHlqtZdTnt0dLCeKJ2oe3bx98PG2b0QbvgAqmtcQgiLnbS0MCKdd1eZd8++rVDa04uvDBxWTQNDRxHQQHjKoMGUVCmTbPeQorEoUMsjCwpAX79a47DTuwWF4BWxIIFvE533MHEBjvEtraW1kx5Od+/b9/Y+pnFQ1WVkRiwZQv/u1s33jOTJjEWKLU1tiHikki0RRBac+Lz8QbQYmJmzUm0HDzIupR336U77uKLKSqjR9t3Y27aBNx2G10vL72UWDdQWzhBXABaEM8/z4LRmTOZVWZ2JmC06H5mJSXsBtCzJ62Zfv0S23MsGAw/t2bXLqO2JvTcmlg6UAsxIeJiNeXl4TUnJ04w0+myywxXlx0BypYWprkWFjJbJyuLwfn8fPvPXn/5ZeChh5is8NxzFBgn4BRx0axbBzzyCONwy5ZZlx0YDUox6aS01Lx+ZvHQ1MTD4bQLTdfWjBoVXlsjTTctQ8TFbHw+fqG1dbJrF3+fl2dYJ5Mm2bfTLCsz+nzV1LDP18yZvOHsLmLz+4GHHwZ+9zs2oFywwFn+c6eJCwB88QXjMAAD/SNG2DsewOhnVlrKa5qRQWsmM9M+S7it2poJEwyxSVTWYxdBxCVe9DGyoTUnDQ3cRYbWnHTmWF4rxvjpp7RStm2zts9XrJw8CcyeTWvqsceAm2+2e0Sn40RxAWgN33sva58WLACmT7d7RMTqfmaxEggY59Zs3kyB1rU1OjFg7FhnXWMXIuISCydPMptLWydHjvBmufhiQ1BGjbLf5K6poeC9+SZ3blb2+YqH/fsZmK6qYv3KlCl2jygyThUXgNbCo4/yWs+axR+7v3+hROpnlpNjbSeHaKmt5REC2oWma2vGjDHSnaW2ptOIuERDIHD6OSeBAIPMuuZkyhR7MqpaoxSzznSfL6XY5+vaa63r8xUPGzYw6yknhxX3iU517gxOFheA1/r3vweefZbfyYcfdt6pkYnqZxYrSnGzqK2azz7jdc/KMs6tmTBBamuiQMSlLYqLDTF57z0GK3v3Dq85GTzY7lEaNDfTmiooYPbXgAF0e02f7ozdYWuUopXyyCMU6GefdX6XXKeLi+bdd4H58/n9XLHCuX26Tp6klZCIfmax0tLCuKm2avbv5+/PO8+I1Ywe7azYoEMQcdE0NNDfr9urhNacTJtm1JwkolisMxQX0xWyfj13guPH00oZN85ZbpFQWlp4+uIrr9BqmTv
XGbvWjnCLuABcBOfM4aZj2TJrDk8zi0j9zHJyaC047TtcWRl+bo2urRk/3nChDRrkPA+BDXRdcVGKTfK0dfLxx/xiDxoUXnPilDTYUAIBfrELC9kWpHdvWijXXOPcXaqmooLxgO3buejdcIPdI4oeN4kLwIXv/vtZW/XQQ8B119k9ovZRiot3aSmtmrQ0I53Zyn5msRIMGufW6Noav5/CGFpb4yRLLIF0LXEpL2e9ia45KSujz/fSSw1X17Bhzt11VFWxtuGtt3gTnn8+04inTHGHWf7FF6y4b2wEXniBvms34TZxAbhhWroUeOMNZuD95CfOswYiEamfWU6OMzd7msZGo7Zm61ajtmb0aMOq6UK1NcktLj4fdxXaOtm5k78fPdo45+Tii+1Ni+wIXdVfWMiK/tRUo8/XOefYPbroeestLmxDhjDo7JQU6M7gRnEB+B364x+Bp57iRmTRIvfspoNBJgCUljLjrFs3I53ZaS7q1pSUhNfW1NfT5Rd6bk0S19Ykl7goxWC2jpt88EF4zcm0afxxwwXVfb4KC9mocOBAo8+XWxYGgNfkmWfYyXjmTC5wbhp/KG4VF82HH9I9lp3NQL/bBD5SP7OcHHe0dAkEaLlrF9o//kHhHDLEEJokq61xv7jU1Bg1J+vXM40wLS285mT0aPeYoocOUVDee4+Wl+7zdcEFznXXtUVzM4PKf/4zcNdd/LdbrkMk3C4uAL9fd9/NhXrpUgai3UZLi5EAYGc/s3ioqWHcUYtNWVl4bc2kSeYeRW4D7hOXQIB+zY0bKSbbtvF3Q4eG15y4aXfc0sKEgoIC7mgyMxmcv/pq+/t8xUppKdvjf/45sGoV8M//bPeI4icZxAVgsPzBB1m7dd99wL/+q90jio3W/cxSU5nOnJPjruuja2s2bzZqa5qbee9roZkwwdnxpgi4Q1yKiw0xaV1zoq0TJ9WcRMuJE4xFrFvHnUxeHq2USZPckZrbFjt3MnCvFAsjx461e0TmkCziAjCraeVK4LXXmLF3113u/s41NRkJAE7pZxYrLS28h3S8JrS2RrvQXFBb40xxaWwMrznZu5fm7vjxxpG+Tqw5iQaluDMpLGSQr3t3iuOMGYk5CMxq3ngDuPNOZsX87nfOT43uDMkkLprXXgOWL+eCtXix8wtZOyIYZLp7aamz+pnFQ0VFeG1NdTXXDX1uzcSJjJ85TESdJy4nT1KVm5v5gYXWnCRDy4VlyyicZ5/NYsfLL0+ehWrtWtaw/Mu/cFecLPPSJKO4AFy47r8fOOssZvK5JW7REfX1zNgqL+emLi/PXe7ySOjaGu1C272bltr3v8/zjxyE6eKilELN0aPxvUhzM62SOMz03oMGwWOBkiul0HDiROwvUFvLHUacGS49+vc3fX5KKdQUF8fzArQ64+yx1nvgQEvmFvT743uR5mbg2DEuwjEW9aWkplr2vawrKYn9BXw+/sT5veyVk2PJtQv4fLG/QCDATWuc8UvvGWdYMre60tLYXyAY5PEBaWlx9ZHrlZ1t+txM9yu1tLRg/+uvo09JCQ97ssFUqz5wAGNuvx1pFpjBfr8fxR9+iF45OfG9UFlZbH/3xReo69MH515/PVJN9rm2tLTgqzfeQIYdzSODQWDLFpwcMAB5s2ZZcu18tbXwArH7qpVi+mswSCumkwRaWtDNIuvb7/ej6O238Q07Yo/BILBjB2r79sWI733P9O+lUgqNFRVI1a1hYllT0tNpycSIv6kJPS0QTr/fjyMbN+IbdqSFBwLA0aOo9Xhw3g03mH7dLAla9M/NRe6SJQy633NPws3sw598Yunr9xk5EgPOP9/S9zgNpYBXXwUKC1F2552WvU3/iRMxONGV88Eg6y5eeglHVq+27G1SzzgDaceP091qQwfrlsZGS18/a+xYnDVmjKXvcRrBIPDrXwOvvYZjjzxi2dukpaej25EjFJbc3IRvWptqay177cy8PJw5erRlr98mmzcDf/gDjs+bZ8n
LW7PqZ2ez1cTy5fwJBi15my6DUsDrr1NcrrrKdSmJ7aKFZfVq4KabrC1w1W7Wykrr3iMSSvEn2QgEgKefZpzmu9+lVWcVaWnM/jp+nHGvZPw8E4kubu7Rw7Ku6daIS0oKTxO89VaKy7JlIjCxohQD5S+/zMSGWbMclxUSM8Egvx+rVwO33MJTFK22cnv1YveDRC5OwSDrMJJpQQwEuCl45RUe9HbnndZeO22xDBzIIP2RI8n1eSaaigqKtIXriXW5vF4vT8YD+CUEeBRrsmSiJAKleNb9Cy/w9Mg77nB3LUIo2mJ56ikKy8MPJ2ZuWVlMqvD5Etdpt76eGUvJkGoOMDvp8ceZdj5rFjeRibivPR5+hoEAEyvS02npJstmK1EoBTz/PO+3q69m4bYFWFsoIgITO0qxrc1vfsM89p/9LPmERVssiRIWwKh1qKqimyUR1NRwAUyGRdDvBxYuNBqR3nxzYufl8TAV3OdjH8Hu3WMP8ndVfD7gnXeA/HxLawWtr0JsLTBKseWECEzbKMV2ME8/zUO/5sxxZ8FoJLQrLNEWi8bjoZ+5tpaxQasXJZ2enQx1MT4fT7jcuBH4xS+Af/93exZ1j4dHY+zaBXz5JTtAuLVA0g7+8hduEm6/3dLrl5gVXgvMrbeyuE5iMG2jFPulrVwJjBxJS8/hbR6ixm5h0WRl8XOOt+4lGnQw3wlHTccTo2huBh54gMJy3332CYsmJYX3B8Buw7KeRIffz5ZM48ZZXpSeuO2wuMg6RinuxpYt485s7lxnnsAXC04RFsAoNqustL49TXMzH3v2ZM8oO1CKO/zCQtaeXXxx54QhGGRJwdatwLx5wD/9kzPcUGlp7Lf1xRfs9nzOOc4Yl5P5299Yo3XffZZ/Von1tYiLrG2UYmBt8WJmxcyfnxyuFICL07JldPPdeiuFxe5r3rMnYyFWB4R1fYTXa5+4VFYyGSQ1FVizhvfelCnR//3WrTyobuFC9sBzygLu8dAiHDiQzW379HFno8pE4fczkJ+XB5x5puVvl/g7vLWL7IknxKRVip1PFy1ikHnBgrhaOTgKJwqLx8OzP3Q8xCqUMlpz2LXgKcXvVUsL09nPPZfWR7TtVJTiuS/9+ztLWDQ6g6xXL95Ddgm4G/j73/l9v//+hFxHe+5yERgDpZhvvmABYwELF7q/uZ5Gu8K0sMyfb7+waM44g9/DsjJr6yX8fns7DdfXs1HqD35AF+ATT/B3r7wS3d8fPMiakjlznCcsGo8HOP98Pkr8JTLNzcw8HT06YSeQ2nena4GZNYsCs3Rp1/tSKEVzfv58LkCPPuqMwK8ZaIvlqacMi8VJqdTaevH5rAvs6120ndf0r3/l4003cc5nncXU9hde6HjeSrGepUcPYOpUy4caFzr+0tgIHD4sBZahKAX89reMtcybl7BNgr3bSK+XJvusWcCTT3Y9gSkt5cVOT2dHg2Rp6+JEV1gkevfmjWaV9aLjLXZl+ykF/PGPrAvRvdQ8HsY5m5qAd99t/+/Ly4EdO3h/Omlj0Ba9ezOWUFLCeJoIDKms5Lk9M2cmrrYLdosLcLrAdBUX2YkTFBbdKsfKvkyJxC3CAnChzcykmygQMPe1leICZ2d9UlMTv2etCx0HD6YFs3p12wuwUgz8p6bypEo3oFvE9OjBAwYTkWrudJTi+pKWBvz0pwl1bTrjrg8VmK7gIqusBB56iF/+Rx9lsDQZCBWWWbOcLSyarCzecKWl5u90W1oM68gONm3i45VXhv/e42ERZEkJ0NbZS1VVrGm54QZ31Vl5PDwFVWdfdmXrRSlap9u3s/dbgpOEnHPndxUXWVUVLZbGRgbvzzzTuYHSztDaYnFS8L49UlJoNdbXR59BFQ1OiLe8+iqzqCIliFx2GZMaVq48fQHWGWapqTzd0G3fz7Q0YPhwXtOuHH+pq2PMbNQousQSfB2ddffrIP9ttyWnwJSXszCypgZ45BG6J9x240YiGKQ7000WSyiZmfz
uHTtm3kJUXc1HO+Mtn38OXHpp5O9Yairwve8xk+zkyfD/d/AgrZ7/+i93psR7PIxf6hb9ydaROhqU4uY1EOBGwYb70XkrQEoKP4xkEhil6IKYO5ft3hctSp5qYi0sv/oVhcUtFksoHg8tyJYWikK8C5FSDObb2e+quprz+fa3237Oj35EUX3sMWPONTXAgw/S4rrxRvd+R3WDS13/ojsldAW0O2zLFuDnP4/7eOdYceYqoAUmGVxkOt147lwjxnL22e69aUNJBmHRdO/OhejEifgL8ZTijrFPH/uu8wcf8HHUqLaf06MHMHs2F6IdOxh3uv127vYXL3Z/s1Qdf0lJYf1LVwnwV1TQHXbBBba26nHutyclhQuxx0OBUYqN89y0eIUWSKal0UzNyRFhcSIeD6/NwYPcDAwZEvt89FntdhZP/u1vFMv2rCePh/Uva9dyh6v7nz39NJDo45KtIjWVDS5372aAf9Qod39PO8LvZ6q5tkhtnKtzxQUwLBgAWLWKj24RGN3SRVfcL1qUXFlhySQsmpQUpugeOcLal1ha8ivFbMCUFPs+E6W4U28r3hJKaiort199FdizhyKTLC5bTY8eDPDv3QscOAAMHZpc89MoxY3BoUPsjGFzQbazxQVwp8DoNMhHH2WweOFC2/yeppOswqLp1o2V++Xl/HdGRucXouZmexsoNjQw8+3666N7fp8+jHEqlTyHmoWi65lyc5k9lp7OfmTJNE+lgHXraLF+//vAhRfaPj/niwtgCIzHQ4FRikFHJy5qSgE7d9LnqZtQZmTYPSpzSHZhAYyFqLGR1kt6OkUm2hu1oYGPdnZb2LmTjxdeGP3fJKOohKKTNnw+uj29XmeVAXzxBfujxXo/7d/PUoAJE5io4YB5uWdlSEmhBTB7NiuLH3/ceUF+pdiefPFi7owWLRJhcSN6IUpL40IUbfW+UrR4vF57g+Fr13LsydIA1Sx0Bln//rRgzEw9j4fychY5vv12bOOprGScpW9fljg4pFWPOywXjRYYwHkuMqVYM7BqFQ/6mjfP6OfkdrqSsGhSUrhBKCpiFXtubsdzDgbpEuvXz94W+598krxxhXjxeHjsAMDYWn09r216un2fV9++bLq5ciXP2enVK/q/bWxkx+rmZuCZZxy15rhvhQi1YFatApYssd+C0XnlTz7JbJT58x11keMiGGQq+K9+Rb98VxAWTWoqA/w+X8ftYZRiCihgr0ssEGCNy4wZ9o3B6WiByc1lXc+uXbRQ/X57LBmPhxZHIBBec9QRPh9LHI4epScnAQeAdQZ3rhKtXWR2CoxSwFtvcfEdN46xoGQ6QXLpUu6IbrsN+OUvu46wALzpu3enJVJb236BpVL8/7162WsxFBfzcdo0+8bgBrTrMy+PMbbiYqYrV1baIzBZWWws+dFHQEFBx2NobmYnjB072KcwL89xlqq73GKhhLrIVq/mY6KD/EoB//d/wIsv8lzyu+5yV5O/9tCusK4qLBod4G9uZoFlaurpAqIULRsgtvRlM3nnHaNmR2gfj4fusKFDedz14cPAvn2Mk+ruyom6lh4P8M1vUlxWrqRLti3BqK1lotBnnwH33suzdhwmLIBbLReNFpjbb6fALF6cOAsmGAT+9CcKyxVXAHffnVzCEuoK66rCotGLdffurF6vrTV2lrrVS20tfed2f07r1jljHG7C42FNyKhRrPFpaGBftqIiFpUmypJJSaE1kpvLAP1HH4WvZ8Egx/WTn9DKeughWxpSRot7LRdNSgr9lQBPPQSA6dOtfU+lgP/9X4pLfj7dcw7J0DCFruwKawuPh8fDHjvGPnH19czGamqiO6xHD6N9v50cOsQdsN3jcCMpKbRgMjN5ncvK6CYbNChxru7u3Rm7nTuX997llwOTJ1NYPv4Y+PBDjnHlSoqhg6+z+8UFCBeYZ5+lOWkllZXAX/7Cvj0//GFyCcvx46zYFmE5nZQUCkxlJTsJ19Xx5s7IYHqr3Td6fT2
/i9dea+843IzHw5Y5OmW5qIhV/eedl7gxZGRQPNasAQoLacF4PLRIb7oJ+M537I/tRYEl4lKxfTuCrdt4J4IrrgACAVQdPw4r8yZOVlQgOGMG+y/t2GHhO0WmtrgYWcOGWfLaFSUlCN58M1uHbNxoyXu0R9X+/cgZO9aS1w40NUGZ4TZNT2eQ3+/nYu71MiW0A4I+H1It3AFXHzgANXs2M6A+/NCy92mLk4cOYUB7jTLjwF9XhyY7Gk/m5AA9esDf3Ayrelyf3LMHqq7u9P9x3nm0mnQWYv/+tKD27DH1/WsOH0b/ESNMfU0A8ChlrkNRKYWq/fvtVVWlkDlsGDwWjEEphbriYtvn12vgQNPnp5RC1Vdf2T63zKFDLZlbMN5uxyaQkpZm2ffyZFGR7dcuY8gQS66dv6nJ1NeMhdRu3SyZW83hw7Zft965uabPzXRxMQXdstzrdbzp12kCAfrpu3VLLncawOvm9zOjKtmuG0C/d0sLEzeSzV2oC0DT02VubiMYNNYUB83POSMJZdcu5qDv2mX3SMynqAi4+WY+Jhu7d9NXvXu33SOxBp+PAXMzj0N2Cnv3MqV17167R2I+jY1M243CdelK9u9nnG3/frtHEoYzxUUQBEFwNSIugiAIgumIuAiCIAimI+IiCIIgmI6IiyAIgmA6Ii6CIAiC6Yi4CIIgCKYj4iIIgiCYjoiLIAiCYDoiLoIgCILpiLgIgiAIpiPiIgiCIJiOiIsgCIJgOiIugiAIgumIuAiCIAimI+IiCIIgmI6IiyAIgmA6Ii6CIAiC6Yi4CIIgCKYj4iIIgiCYjoiLIAiCYDoiLoIgCILpiLgIgiAIpiPiIgiCIJiOiIsgCIJgOiIugiAIgumIuAiCIAimI+IiCIIgmI6IiyAIgmA6Ii6CIAiC6Yi4CIIgCKYj4iIIgiCYjoiLIAiCYDoiLoIgCILpiLgIgiAIpiPiIgiCIJiOiIsgCIJgOiIugiAIgumIuAiCIAim40hxqa6uRiAQQHV1td1DMZ2qqirU1dWhqqrK7qGYTnV1NfxJet0AoKKiAkVFRaioqLB7KKZTVV2NpuZmVCXhtauoqMDRo0eT8roBXFMaGxudt6YoB7F27Vo1ffp0lQeoMkDlAWr69OmqoKDA7qHFjZ7bOYD6E6DOScK5XQCoYkBdkERzU8qYXzqgzgNUehLNT8/tfEB9Aqjzk3BuPQB1CaB6JNHclDLmNxxQGwE13GHzc4y4LF68WAFQXq83TFy8Xq8CoJYsWWL3EGMmdG6h4pJscwsVl2SYm1Lh8wsVl2SYX+jcQsUl2eYWKi7JMDelwucXKi5Omp8jxGXt2rUKwNc/oeIS+nunKHJnaD23UHFJtrmFiovb56bU6fMLFRe3z6/13ELFJdnmFioubp+bUqfPL1RcnDQ/R8Rcli9fDq/X2+5zvF4vVqxYkaARmYfMzZ1zA5J7fjI3d84NcNH8bJU2pVR5eXmY2qIdywWAKi8vt3vIURNpbm1ZLskwt7YsF7fNTanI82vLcnHb/CLNrS3LJRnm1pbl4ra5KRV5fm1ZLnbPz3bLpbMZDo7LiGgHmVvsz7ebZJ6fzC3259uNm+Znu7hkZmae9rt9AKafeozm+U4l0liLAdxz6jGa5zuVSGPdD2DGqcdonu9kIo3XB6Do1GM0z3cqkcZ6CMAPTj1G83ynEmmsjQB2nnqM5vlOJtJ4DwO47dRjNM9PFLaLS9++fTF9+vQwH2ITgF2nHjVerxf5+fno27dvoocYM5Hm5gNwEOELVLLMrQnAbrj/ugGR56cANJ961LhxfpHm1gzgy1OPmmSZmwLQAPdfN6Dta7cPDrx2tjnkQigoKDjNVxjpx+7sh1iQublzbkol9/xkbu6cm1LumZ8jxEWp8Lzt0A/ISXnbsSJzcy/JPD+Zm3txw/wcIy5KUZHz8/PDPqz8/HzbFdgMZG7uJZnnJ3NzL06fn0cpFeqKdAQVFRWoqqpCZmam63yiHSFzcy/
JPD+Zm3tx6vwcKS6CIAiCu7E9W0wQBEFIPkRcBEEQBNMRcREEQRBMR8RFEARBMB0RF0EQBMF0RFwEQRAE0xFxEQRBEExHxEUQBEEwHREXQRAEwXREXARBEATTEXERBEEQTEfERRAEQTAdERdBEATBdERcBEEQBNMRcREEQRBM5/8BkwoOtfBkw+UAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "from coxkan.datasets import gbsg\n",
    "\n",
    "# load dataset\n",
    "# (GBSG-specific names, so the synthetic example's df_train / df_test / ckan stay intact)\n",
    "gbsg_train_raw, gbsg_test_raw = gbsg.load(split=True)\n",
    "dataset_name, duration_col, event_col, covariates = gbsg.metadata()\n",
    "\n",
    "# init CoxKAN: one input per covariate, a single output node\n",
    "ckan_gbsg = CoxKAN(width=[len(covariates), 1], seed=42)\n",
    "\n",
    "# pre-process and register data (raw frames are kept under their own names)\n",
    "gbsg_train, gbsg_test = ckan_gbsg.process_data(\n",
    "    gbsg_train_raw, gbsg_test_raw, duration_col, event_col, normalization='standard'\n",
    ")\n",
    "\n",
    "# train CoxKAN\n",
    "_ = ckan_gbsg.train(\n",
    "    gbsg_train,\n",
    "    gbsg_test,\n",
    "    duration_col=duration_col,\n",
    "    event_col=event_col,\n",
    "    opt='Adam',\n",
    "    lr=0.01,\n",
    "    steps=100)\n",
    "\n",
    "print(\"\\nCoxKAN C-Index: \", ckan_gbsg.cindex(gbsg_test))\n",
    "\n",
    "# Auto symbolic fitting\n",
    "# NOTE(review): fit_success is not checked below; presumably it reports whether\n",
    "# the symbolic fit succeeded - confirm against the CoxKAN documentation.\n",
    "fit_success = ckan_gbsg.auto_symbolic(verbose=False)\n",
    "display(ckan_gbsg.symbolic_formula(floating_digit=2)[0][0])\n",
    "\n",
    "# Plot coxkan\n",
    "fig_gbsg = ckan_gbsg.plot(beta=20)"
   ]
  }
 ],
 "metadata": { "kernelspec": { "display_name": "wdk", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }