Note
Go to the end to download the full example code.
Kolmogorov-Arnold Networks
This script builds a model from Kolmogorov-Arnold network (KAN) layers and fits it to a noisy quadratic surface, z = x² + y² − 5, using only a few parameters (83 in total).
Model summary:
Input - Output: (2)
DenseKAN - (Input, Output): (2, 2) - Parameters: 40
Tanh - Output: (2)
DenseKAN - (Input, Output): (2, 2) - Parameters: 40
Dense - (Input, Output): (2, 1) - Parameters: 3
Total number of parameters: 83
Epoch: 1 - Metrics: {'loss': '18.0900'}
Epoch: 2 - Metrics: {'loss': '17.9226'}
Epoch: 3 - Metrics: {'loss': '17.7467'}
Epoch: 4 - Metrics: {'loss': '17.5625'}
Epoch: 5 - Metrics: {'loss': '17.3699'}
Epoch: 6 - Metrics: {'loss': '17.1690'}
Epoch: 7 - Metrics: {'loss': '16.9600'}
Epoch: 8 - Metrics: {'loss': '16.7431'}
Epoch: 9 - Metrics: {'loss': '16.5185'}
Epoch: 10 - Metrics: {'loss': '16.2866'}
Epoch: 11 - Metrics: {'loss': '16.0475'}
Epoch: 12 - Metrics: {'loss': '15.8014'}
Epoch: 13 - Metrics: {'loss': '15.5486'}
Epoch: 14 - Metrics: {'loss': '15.2890'}
Epoch: 15 - Metrics: {'loss': '15.0225'}
Epoch: 16 - Metrics: {'loss': '14.7489'}
Epoch: 17 - Metrics: {'loss': '14.4679'}
Epoch: 18 - Metrics: {'loss': '14.1788'}
Epoch: 19 - Metrics: {'loss': '13.8810'}
Epoch: 20 - Metrics: {'loss': '13.5738'}
Epoch: 21 - Metrics: {'loss': '13.2568'}
Epoch: 22 - Metrics: {'loss': '12.9298'}
Epoch: 23 - Metrics: {'loss': '12.5930'}
Epoch: 24 - Metrics: {'loss': '12.2468'}
Epoch: 25 - Metrics: {'loss': '11.8919'}
Epoch: 26 - Metrics: {'loss': '11.5288'}
Epoch: 27 - Metrics: {'loss': '11.1585'}
Epoch: 28 - Metrics: {'loss': '10.7816'}
Epoch: 29 - Metrics: {'loss': '10.3990'}
Epoch: 30 - Metrics: {'loss': '10.0113'}
Epoch: 31 - Metrics: {'loss': '9.6191'}
Epoch: 32 - Metrics: {'loss': '9.2229'}
Epoch: 33 - Metrics: {'loss': '8.8232'}
Epoch: 34 - Metrics: {'loss': '8.4205'}
Epoch: 35 - Metrics: {'loss': '8.0154'}
Epoch: 36 - Metrics: {'loss': '7.6088'}
Epoch: 37 - Metrics: {'loss': '7.2016'}
Epoch: 38 - Metrics: {'loss': '6.7948'}
Epoch: 39 - Metrics: {'loss': '6.3896'}
Epoch: 40 - Metrics: {'loss': '5.9873'}
Epoch: 41 - Metrics: {'loss': '5.5888'}
Epoch: 42 - Metrics: {'loss': '5.1956'}
Epoch: 43 - Metrics: {'loss': '4.8089'}
Epoch: 44 - Metrics: {'loss': '4.4301'}
Epoch: 45 - Metrics: {'loss': '4.0605'}
Epoch: 46 - Metrics: {'loss': '3.7015'}
Epoch: 47 - Metrics: {'loss': '3.3545'}
Epoch: 48 - Metrics: {'loss': '3.0210'}
Epoch: 49 - Metrics: {'loss': '2.7024'}
Epoch: 50 - Metrics: {'loss': '2.4002'}
Epoch: 51 - Metrics: {'loss': '2.1158'}
Epoch: 52 - Metrics: {'loss': '1.8505'}
Epoch: 53 - Metrics: {'loss': '1.6055'}
Epoch: 54 - Metrics: {'loss': '1.3817'}
Epoch: 55 - Metrics: {'loss': '1.1798'}
Epoch: 56 - Metrics: {'loss': '1.0001'}
Epoch: 57 - Metrics: {'loss': '0.8428'}
Epoch: 58 - Metrics: {'loss': '0.7076'}
Epoch: 59 - Metrics: {'loss': '0.5938'}
Epoch: 60 - Metrics: {'loss': '0.5006'}
Epoch: 61 - Metrics: {'loss': '0.4266'}
Epoch: 62 - Metrics: {'loss': '0.3703'}
Epoch: 63 - Metrics: {'loss': '0.3296'}
Epoch: 64 - Metrics: {'loss': '0.3025'}
Epoch: 65 - Metrics: {'loss': '0.2867'}
Epoch: 66 - Metrics: {'loss': '0.2798'}
Epoch: 67 - Metrics: {'loss': '0.2795'}
Epoch: 68 - Metrics: {'loss': '0.2835'}
Epoch: 69 - Metrics: {'loss': '0.2897'}
Epoch: 70 - Metrics: {'loss': '0.2964'}
Epoch: 71 - Metrics: {'loss': '0.3021'}
Epoch: 72 - Metrics: {'loss': '0.3057'}
Epoch: 73 - Metrics: {'loss': '0.3064'}
Epoch: 74 - Metrics: {'loss': '0.3037'}
Epoch: 75 - Metrics: {'loss': '0.2975'}
Epoch: 76 - Metrics: {'loss': '0.2880'}
Epoch: 77 - Metrics: {'loss': '0.2754'}
Epoch: 78 - Metrics: {'loss': '0.2605'}
Epoch: 79 - Metrics: {'loss': '0.2438'}
Epoch: 80 - Metrics: {'loss': '0.2261'}
Epoch: 81 - Metrics: {'loss': '0.2082'}
Epoch: 82 - Metrics: {'loss': '0.1909'}
Epoch: 83 - Metrics: {'loss': '0.1750'}
Epoch: 84 - Metrics: {'loss': '0.1612'}
Epoch: 85 - Metrics: {'loss': '0.1498'}
Epoch: 86 - Metrics: {'loss': '0.1412'}
Epoch: 87 - Metrics: {'loss': '0.1353'}
Epoch: 88 - Metrics: {'loss': '0.1318'}
Epoch: 89 - Metrics: {'loss': '0.1303'}
Epoch: 90 - Metrics: {'loss': '0.1302'}
Epoch: 91 - Metrics: {'loss': '0.1310'}
Epoch: 92 - Metrics: {'loss': '0.1321'}
Epoch: 93 - Metrics: {'loss': '0.1330'}
Epoch: 94 - Metrics: {'loss': '0.1335'}
Epoch: 95 - Metrics: {'loss': '0.1333'}
Epoch: 96 - Metrics: {'loss': '0.1323'}
Epoch: 97 - Metrics: {'loss': '0.1306'}
Epoch: 98 - Metrics: {'loss': '0.1284'}
Epoch: 99 - Metrics: {'loss': '0.1257'}
Epoch: 100 - Metrics: {'loss': '0.1228'}
Epoch: 101 - Metrics: {'loss': '0.1198'}
Epoch: 102 - Metrics: {'loss': '0.1169'}
Epoch: 103 - Metrics: {'loss': '0.1142'}
Epoch: 104 - Metrics: {'loss': '0.1118'}
Epoch: 105 - Metrics: {'loss': '0.1097'}
Epoch: 106 - Metrics: {'loss': '0.1079'}
Epoch: 107 - Metrics: {'loss': '0.1065'}
Epoch: 108 - Metrics: {'loss': '0.1053'}
Epoch: 109 - Metrics: {'loss': '0.1044'}
Epoch: 110 - Metrics: {'loss': '0.1036'}
Epoch: 111 - Metrics: {'loss': '0.1028'}
Epoch: 112 - Metrics: {'loss': '0.1021'}
Epoch: 113 - Metrics: {'loss': '0.1014'}
Epoch: 114 - Metrics: {'loss': '0.1006'}
Epoch: 115 - Metrics: {'loss': '0.0998'}
Epoch: 116 - Metrics: {'loss': '0.0990'}
Epoch: 117 - Metrics: {'loss': '0.0981'}
Epoch: 118 - Metrics: {'loss': '0.0971'}
Epoch: 119 - Metrics: {'loss': '0.0962'}
Epoch: 120 - Metrics: {'loss': '0.0953'}
Epoch: 121 - Metrics: {'loss': '0.0945'}
Epoch: 122 - Metrics: {'loss': '0.0937'}
Epoch: 123 - Metrics: {'loss': '0.0929'}
Epoch: 124 - Metrics: {'loss': '0.0923'}
Epoch: 125 - Metrics: {'loss': '0.0917'}
Epoch: 126 - Metrics: {'loss': '0.0912'}
Epoch: 127 - Metrics: {'loss': '0.0907'}
Epoch: 128 - Metrics: {'loss': '0.0903'}
Epoch: 129 - Metrics: {'loss': '0.0899'}
Epoch: 130 - Metrics: {'loss': '0.0894'}
Epoch: 131 - Metrics: {'loss': '0.0890'}
Epoch: 132 - Metrics: {'loss': '0.0886'}
Epoch: 133 - Metrics: {'loss': '0.0881'}
Epoch: 134 - Metrics: {'loss': '0.0877'}
Epoch: 135 - Metrics: {'loss': '0.0872'}
Epoch: 136 - Metrics: {'loss': '0.0867'}
Epoch: 137 - Metrics: {'loss': '0.0862'}
Epoch: 138 - Metrics: {'loss': '0.0857'}
Epoch: 139 - Metrics: {'loss': '0.0853'}
Epoch: 140 - Metrics: {'loss': '0.0848'}
Epoch: 141 - Metrics: {'loss': '0.0844'}
Epoch: 142 - Metrics: {'loss': '0.0839'}
Epoch: 143 - Metrics: {'loss': '0.0835'}
Epoch: 144 - Metrics: {'loss': '0.0831'}
Epoch: 145 - Metrics: {'loss': '0.0827'}
Epoch: 146 - Metrics: {'loss': '0.0822'}
Epoch: 147 - Metrics: {'loss': '0.0818'}
Epoch: 148 - Metrics: {'loss': '0.0814'}
Epoch: 149 - Metrics: {'loss': '0.0810'}
Epoch: 150 - Metrics: {'loss': '0.0806'}
Epoch: 151 - Metrics: {'loss': '0.0801'}
Epoch: 152 - Metrics: {'loss': '0.0797'}
Epoch: 153 - Metrics: {'loss': '0.0793'}
Epoch: 154 - Metrics: {'loss': '0.0788'}
Epoch: 155 - Metrics: {'loss': '0.0784'}
Epoch: 156 - Metrics: {'loss': '0.0780'}
Epoch: 157 - Metrics: {'loss': '0.0775'}
Epoch: 158 - Metrics: {'loss': '0.0771'}
Epoch: 159 - Metrics: {'loss': '0.0767'}
Epoch: 160 - Metrics: {'loss': '0.0762'}
Epoch: 161 - Metrics: {'loss': '0.0758'}
Epoch: 162 - Metrics: {'loss': '0.0753'}
Epoch: 163 - Metrics: {'loss': '0.0749'}
Epoch: 164 - Metrics: {'loss': '0.0744'}
Epoch: 165 - Metrics: {'loss': '0.0740'}
Epoch: 166 - Metrics: {'loss': '0.0735'}
Epoch: 167 - Metrics: {'loss': '0.0731'}
Epoch: 168 - Metrics: {'loss': '0.0726'}
Epoch: 169 - Metrics: {'loss': '0.0721'}
Epoch: 170 - Metrics: {'loss': '0.0717'}
Epoch: 171 - Metrics: {'loss': '0.0712'}
Epoch: 172 - Metrics: {'loss': '0.0707'}
Epoch: 173 - Metrics: {'loss': '0.0702'}
Epoch: 174 - Metrics: {'loss': '0.0697'}
Epoch: 175 - Metrics: {'loss': '0.0692'}
Epoch: 176 - Metrics: {'loss': '0.0687'}
Epoch: 177 - Metrics: {'loss': '0.0682'}
Epoch: 178 - Metrics: {'loss': '0.0677'}
Epoch: 179 - Metrics: {'loss': '0.0671'}
Epoch: 180 - Metrics: {'loss': '0.0666'}
Epoch: 181 - Metrics: {'loss': '0.0661'}
Epoch: 182 - Metrics: {'loss': '0.0655'}
Epoch: 183 - Metrics: {'loss': '0.0650'}
Epoch: 184 - Metrics: {'loss': '0.0644'}
Epoch: 185 - Metrics: {'loss': '0.0639'}
Epoch: 186 - Metrics: {'loss': '0.0633'}
Epoch: 187 - Metrics: {'loss': '0.0628'}
Epoch: 188 - Metrics: {'loss': '0.0622'}
Epoch: 189 - Metrics: {'loss': '0.0617'}
Epoch: 190 - Metrics: {'loss': '0.0611'}
Epoch: 191 - Metrics: {'loss': '0.0605'}
Epoch: 192 - Metrics: {'loss': '0.0600'}
Epoch: 193 - Metrics: {'loss': '0.0594'}
Epoch: 194 - Metrics: {'loss': '0.0589'}
Epoch: 195 - Metrics: {'loss': '0.0583'}
Epoch: 196 - Metrics: {'loss': '0.0577'}
Epoch: 197 - Metrics: {'loss': '0.0572'}
Epoch: 198 - Metrics: {'loss': '0.0567'}
Epoch: 199 - Metrics: {'loss': '0.0561'}
Epoch: 200 - Metrics: {'loss': '0.0556'}
Epoch: 201 - Metrics: {'loss': '0.0551'}
Epoch: 202 - Metrics: {'loss': '0.0545'}
Epoch: 203 - Metrics: {'loss': '0.0540'}
Epoch: 204 - Metrics: {'loss': '0.0535'}
Epoch: 205 - Metrics: {'loss': '0.0530'}
Epoch: 206 - Metrics: {'loss': '0.0525'}
Epoch: 207 - Metrics: {'loss': '0.0521'}
Epoch: 208 - Metrics: {'loss': '0.0516'}
Epoch: 209 - Metrics: {'loss': '0.0511'}
Epoch: 210 - Metrics: {'loss': '0.0507'}
Epoch: 211 - Metrics: {'loss': '0.0502'}
Epoch: 212 - Metrics: {'loss': '0.0498'}
Epoch: 213 - Metrics: {'loss': '0.0494'}
Epoch: 214 - Metrics: {'loss': '0.0490'}
Epoch: 215 - Metrics: {'loss': '0.0486'}
Epoch: 216 - Metrics: {'loss': '0.0482'}
Epoch: 217 - Metrics: {'loss': '0.0478'}
Epoch: 218 - Metrics: {'loss': '0.0474'}
Epoch: 219 - Metrics: {'loss': '0.0471'}
Epoch: 220 - Metrics: {'loss': '0.0467'}
Epoch: 221 - Metrics: {'loss': '0.0464'}
Epoch: 222 - Metrics: {'loss': '0.0460'}
Epoch: 223 - Metrics: {'loss': '0.0457'}
Epoch: 224 - Metrics: {'loss': '0.0454'}
Epoch: 225 - Metrics: {'loss': '0.0451'}
Epoch: 226 - Metrics: {'loss': '0.0447'}
Epoch: 227 - Metrics: {'loss': '0.0444'}
Epoch: 228 - Metrics: {'loss': '0.0441'}
Epoch: 229 - Metrics: {'loss': '0.0438'}
Epoch: 230 - Metrics: {'loss': '0.0436'}
Epoch: 231 - Metrics: {'loss': '0.0433'}
Epoch: 232 - Metrics: {'loss': '0.0430'}
Epoch: 233 - Metrics: {'loss': '0.0427'}
Epoch: 234 - Metrics: {'loss': '0.0425'}
Epoch: 235 - Metrics: {'loss': '0.0422'}
Epoch: 236 - Metrics: {'loss': '0.0419'}
Epoch: 237 - Metrics: {'loss': '0.0417'}
Epoch: 238 - Metrics: {'loss': '0.0414'}
Epoch: 239 - Metrics: {'loss': '0.0412'}
Epoch: 240 - Metrics: {'loss': '0.0410'}
Epoch: 241 - Metrics: {'loss': '0.0407'}
Epoch: 242 - Metrics: {'loss': '0.0405'}
Epoch: 243 - Metrics: {'loss': '0.0402'}
Epoch: 244 - Metrics: {'loss': '0.0400'}
Epoch: 245 - Metrics: {'loss': '0.0398'}
Epoch: 246 - Metrics: {'loss': '0.0396'}
Epoch: 247 - Metrics: {'loss': '0.0393'}
Epoch: 248 - Metrics: {'loss': '0.0391'}
Epoch: 249 - Metrics: {'loss': '0.0389'}
Epoch: 250 - Metrics: {'loss': '0.0387'}
Epoch: 251 - Metrics: {'loss': '0.0385'}
Epoch: 252 - Metrics: {'loss': '0.0383'}
Epoch: 253 - Metrics: {'loss': '0.0380'}
Epoch: 254 - Metrics: {'loss': '0.0378'}
Epoch: 255 - Metrics: {'loss': '0.0376'}
Epoch: 256 - Metrics: {'loss': '0.0374'}
Epoch: 257 - Metrics: {'loss': '0.0372'}
Epoch: 258 - Metrics: {'loss': '0.0370'}
Epoch: 259 - Metrics: {'loss': '0.0368'}
Epoch: 260 - Metrics: {'loss': '0.0366'}
Epoch: 261 - Metrics: {'loss': '0.0364'}
Epoch: 262 - Metrics: {'loss': '0.0362'}
Epoch: 263 - Metrics: {'loss': '0.0360'}
Epoch: 264 - Metrics: {'loss': '0.0358'}
Epoch: 265 - Metrics: {'loss': '0.0356'}
Epoch: 266 - Metrics: {'loss': '0.0354'}
Epoch: 267 - Metrics: {'loss': '0.0352'}
Epoch: 268 - Metrics: {'loss': '0.0351'}
Epoch: 269 - Metrics: {'loss': '0.0349'}
Epoch: 270 - Metrics: {'loss': '0.0347'}
Epoch: 271 - Metrics: {'loss': '0.0345'}
Epoch: 272 - Metrics: {'loss': '0.0343'}
Epoch: 273 - Metrics: {'loss': '0.0341'}
Epoch: 274 - Metrics: {'loss': '0.0339'}
Epoch: 275 - Metrics: {'loss': '0.0338'}
Epoch: 276 - Metrics: {'loss': '0.0336'}
Epoch: 277 - Metrics: {'loss': '0.0334'}
Epoch: 278 - Metrics: {'loss': '0.0332'}
Epoch: 279 - Metrics: {'loss': '0.0330'}
Epoch: 280 - Metrics: {'loss': '0.0329'}
Epoch: 281 - Metrics: {'loss': '0.0327'}
Epoch: 282 - Metrics: {'loss': '0.0325'}
Epoch: 283 - Metrics: {'loss': '0.0324'}
Epoch: 284 - Metrics: {'loss': '0.0322'}
Epoch: 285 - Metrics: {'loss': '0.0320'}
Epoch: 286 - Metrics: {'loss': '0.0318'}
Epoch: 287 - Metrics: {'loss': '0.0317'}
Epoch: 288 - Metrics: {'loss': '0.0315'}
Epoch: 289 - Metrics: {'loss': '0.0313'}
Epoch: 290 - Metrics: {'loss': '0.0312'}
Epoch: 291 - Metrics: {'loss': '0.0310'}
Epoch: 292 - Metrics: {'loss': '0.0309'}
Epoch: 293 - Metrics: {'loss': '0.0307'}
Epoch: 294 - Metrics: {'loss': '0.0305'}
Epoch: 295 - Metrics: {'loss': '0.0304'}
Epoch: 296 - Metrics: {'loss': '0.0302'}
Epoch: 297 - Metrics: {'loss': '0.0301'}
Epoch: 298 - Metrics: {'loss': '0.0299'}
Epoch: 299 - Metrics: {'loss': '0.0298'}
Epoch: 300 - Metrics: {'loss': '0.0296'}
import torch
import matplotlib.pyplot as plt
from DLL.DeepLearning.Layers._DenseKAN import _get_basis_functions, _NeuronKAN
from DLL.DeepLearning.Layers import DenseKAN, Dense
from DLL.DeepLearning.Layers.Activations import Tanh
from DLL.DeepLearning.Model import Model
from DLL.DeepLearning.Optimisers import ADAM
from DLL.DeepLearning.Losses import MSE
from DLL.DeepLearning.Initialisers import Xavier_Normal
from DLL.Data.Preprocessing import data_split
# Fix the global torch seed so the Gaussian noise below is identical on every
# run. NOTE(review): this presumably also makes data_split's shuffling and the
# Xavier_Normal weight initialisation reproducible, if they draw from torch's
# global RNG — confirm against the DLL implementation.
torch.manual_seed(0)

# Sample an n x n grid on [-1, 1]^2 and build the regression target
# z = x^2 + y^2 - 5 plus Gaussian noise with standard deviation 0.1.
n = 30
grid_x, grid_y = torch.meshgrid(torch.linspace(-1, 1, n), torch.linspace(-1, 1, n), indexing="xy")
x = torch.stack((grid_x.flatten(), grid_y.flatten()), dim=1)  # (n * n, 2) inputs
y = grid_x.flatten() ** 2 + grid_y.flatten() ** 2 + 0.1 * torch.randn(grid_y.numel()) - 5

# The grid tensors have their own names so that X below unambiguously means the
# training inputs. train_split=0.8 with no validation set presumably leaves the
# remaining 20 % as the test set.
X, y, _, _, X_test, y_test = data_split(x, y, train_split=0.8, validation_split=0.0)
# Architecture: 2 inputs -> DenseKAN(2, tanh) -> DenseKAN(2) -> Dense read-out.
# NOTE(review): Dense(0) is kept as in the original example; the printed model
# summary reports this layer as (2, 1), presumably a scalar output matching
# the 1-D target y — confirm against DLL's Dense.
hidden_width = 2
learning_rate = 0.01
n_epochs = 300

model = Model(2)
model.add(DenseKAN(hidden_width, activation=Tanh(), initialiser=Xavier_Normal()))
model.add(DenseKAN(hidden_width, initialiser=Xavier_Normal()))
model.add(Dense(0, initialiser=Xavier_Normal()))

model.compile(ADAM(learning_rate), MSE())
model.summary()

# fit returns a history mapping; its "loss" entry is plotted further down.
history = model.fit(X, y, epochs=n_epochs, verbose=True)
# Compare the model's predictions with the held-out test targets on the surface.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X_test[:, 0], X_test[:, 1], y_test, label="True test points")
ax.scatter(X_test[:, 0], X_test[:, 1], model.predict(X_test), label="Predictions")
ax.set_title("Test targets and model predictions")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.legend()  # attach the legend to the 3-D axes explicitly, not the implicit current axes

# Training curve: the per-epoch loss recorded by model.fit.
plt.figure(figsize=(8, 8))
plt.plot(history["loss"])
plt.title("Loss function as a function of epoch")
plt.ylabel("MSE")
plt.xlabel("Epoch")
plt.show()
# Visualise the B-spline basis produced by the helper in DenseKAN's module:
# n_fun basis functions of degree 5 on [-1, 1], alongside their derivatives.
n_fun = 10
basis_funcs, basis_func_derivatives = _get_basis_functions(n_fun, degree=5, bounds=(-1, 1))
x = torch.linspace(-1, 1, 100).unsqueeze(1)  # column vector of evaluation points, shape (100, 1)
x_plot = x.squeeze(1)  # 1-D view used as the horizontal axis of both panels

fig, ax = plt.subplots(1, 2, figsize=(8, 6))
plt.subplots_adjust(wspace=0.3)
for i in range(n_fun):
    basis_values = basis_funcs[i](x)
    derivative_values = basis_func_derivatives[i](x)
    ax[0].plot(x_plot, basis_values)
    ax[1].plot(x_plot, derivative_values)

# Left panel: the basis functions; right panel: their derivatives.
for axis, title, ylabel in zip(
    ax,
    ("B-spline Basis Functions", "B-spline Basis Function Derivatives"),
    ("y", "dy/dx"),
):
    axis.set_title(title)
    axis.set_xlabel("x")
    axis.set_ylabel(ylabel)
    axis.grid()
plt.show()
Total running time of the script: (0 minutes 5.562 seconds)