Kolmogorov-Arnold Networks from Scratch
1. The Kolmogorov-Arnold Representation Theorem
In 1957, a 19-year-old Vladimir Arnold — building on his advisor Andrey Kolmogorov’s breakthrough from the year before — solved a variant of Hilbert’s 13th problem by proving something remarkable: any continuous function of n variables can be written as a finite sum of continuous functions of just one variable and addition. No multiplication of variables, no multivariate nonlinearities. Just univariate functions composed with sums.
The formal statement is almost too clean to believe. For any continuous function f: [0,1]^n → ℝ:
f(x_1, …, x_n) = ∑_{q=0}^{2n} Φ_q(∑_{p=1}^{n} ψ_{q,p}(x_p))
That’s it. Outer functions Φ composed with inner functions ψ, each operating on a single variable. The total machinery is (2n+1)(n+1) univariate functions. Compare this with the Universal Approximation Theorem used to justify MLPs: the UAT says a sufficiently wide network can approximate any function, but says nothing about how efficiently. The Kolmogorov-Arnold theorem gives an exact, finite decomposition — not an approximation.
So why did this theorem gather dust for 67 years? Because the inner functions ψ are pathological — continuous but nowhere differentiable, exhibiting fractal structure. They exist mathematically but cannot be parameterized by any smooth function family, making them useless for gradient-based optimization. Kolmogorov-Arnold Networks, introduced by Liu et al. in 2024, revive the theorem’s spirit by replacing these fractal functions with smooth, learnable B-splines and generalizing the architecture to arbitrary depths and widths.
Let’s see what the KA decomposition looks like on a concrete function:
import numpy as np
# Target: f(x, y) = sin(x) * exp(y)
# KA-style decomposition: rewrite as univariate functions + addition
# sin(x) * exp(y) = exp(log(sin(x)) + y) [for sin(x) > 0]
# This is: outer(inner_1(x) + inner_2(y))
# where inner_1(x) = log(sin(x)), inner_2(y) = y, outer(t) = exp(t)
x = np.linspace(0.1, 3.0, 50)
y = np.linspace(-1.0, 1.0, 50)
X, Y = np.meshgrid(x, y)
# Direct computation
f_direct = np.sin(X) * np.exp(Y)
# KA-style decomposition: composition of univariate functions
inner_1 = np.log(np.sin(X)) # univariate in x
inner_2 = Y # univariate in y (identity)
outer = np.exp(inner_1 + inner_2) # univariate outer function
max_error = np.max(np.abs(f_direct - outer))
print(f"Max error between direct and KA decomposition: {max_error:.2e}")
# Output: Max error between direct and KA decomposition: 2.22e-16
# MLP approach: approximate with matrix multiply + ReLU
# Would need hundreds of parameters to reach this precision
# KA decomposition uses exactly 3 univariate functions
The KA decomposition recovers the function exactly using just three univariate functions. An MLP would need hundreds of ReLU neurons to approximate the same function to comparable precision. This is the core insight that KANs exploit: decompose into univariate functions, don’t brute-force with matrix multiplies.
2. B-Spline Basis Functions
If each edge in a KAN carries a learnable univariate function, we need a way to parameterize “any smooth function of one variable” with a finite number of learnable parameters. B-splines are the answer. A B-spline of degree k defined on G grid intervals has (G + k) basis functions, and any smooth univariate function can be approximated as a linear combination of these basis functions.
B-spline basis functions are defined recursively via the Cox-de Boor algorithm. Given a knot vector t = {t_0, t_1, …, t_m}:
- Degree 0 (piecewise constant): B_{i,0}(x) = 1 if t_i ≤ x < t_{i+1}, else 0
- Degree k (recursive): B_{i,k}(x) = ((x − t_i) / (t_{i+k} − t_i)) · B_{i,k−1}(x) + ((t_{i+k+1} − x) / (t_{i+k+1} − t_{i+1})) · B_{i+1,k−1}(x)
Degree 1 gives piecewise linear “hat” functions. Degree 3 — cubic B-splines, the default for KANs — gives curves with continuous first and second derivatives (C2 smoothness). Each basis function has compact support, meaning it’s nonzero only over a few knot spans, so changing one coefficient only affects a local region of the curve.
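The compact-support claim can be checked directly: perturb one coefficient of a spline and measure where the curve actually changes. A minimal sketch using the Cox-de Boor recursion above:

```python
import numpy as np

def B(x, knots, i, k):
    """Cox-de Boor recursion for basis function B_{i,k}(x)."""
    if k == 0:
        return np.where((knots[i] <= x) & (x < knots[i + 1]), 1.0, 0.0)
    d1 = knots[i + k] - knots[i]
    d2 = knots[i + k + 1] - knots[i + 1]
    t1 = 0.0 if d1 == 0 else (x - knots[i]) / d1 * B(x, knots, i, k - 1)
    t2 = 0.0 if d2 == 0 else (knots[i + k + 1] - x) / d2 * B(x, knots, i + 1, k - 1)
    return t1 + t2

G, k = 8, 3
knots = np.concatenate([np.zeros(k), np.linspace(0, 1, G + 1), np.ones(k)])
x = np.linspace(0, 0.999, 400)
basis = np.array([B(x, knots, i, k) for i in range(G + k)])  # (11, 400)

np.random.seed(0)
coeffs = np.random.randn(G + k)
curve = coeffs @ basis

# Bump coefficient 5; B_{5,3} is supported on [knots[5], knots[9]) = [0.25, 0.75)
c2 = coeffs.copy()
c2[5] += 1.0
delta = np.abs(c2 @ basis - curve)
support = (x >= knots[5]) & (x < knots[5 + k + 1])
print(f"max change inside support:  {delta[support].max():.3f}")
print(f"max change outside support: {delta[~support].max():.2e}")  # zero
```

Moving c_5 deforms the curve only on [0.25, 0.75); everywhere else the basis value is exactly zero, so the change is exactly zero. This locality is what makes spline coefficient updates cheap and well-behaved during training.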
Each KAN edge stores a vector of (G + k) learnable coefficients c_h, and its activation function is simply the weighted sum: φ(x) = ∑_h c_h · B_h(x). More grid points G means more coefficients and a more flexible function — like increasing the resolution of a learned curve.
import numpy as np
def bspline_basis(x, knots, i, k):
"""Cox-de Boor recursion for B-spline basis function B_{i,k}(x)."""
if k == 0:
return np.where((knots[i] <= x) & (x < knots[i + 1]), 1.0, 0.0)
denom1 = knots[i + k] - knots[i]
denom2 = knots[i + k + 1] - knots[i + 1]
term1 = 0.0 if denom1 == 0 else (x - knots[i]) / denom1 * bspline_basis(x, knots, i, k - 1)
term2 = 0.0 if denom2 == 0 else (knots[i + k + 1] - x) / denom2 * bspline_basis(x, knots, i + 1, k - 1)
return term1 + term2
# Build a cubic B-spline (k=3) on G=5 grid intervals
G, k = 5, 3
interior_knots = np.linspace(0, 1, G + 1)
knots = np.concatenate([np.zeros(k), interior_knots, np.ones(k)]) # augmented knot vector
n_basis = G + k # 8 basis functions
x = np.linspace(0, 0.999, 200) # avoid right endpoint for open knot vector
# Evaluate all basis functions
basis_vals = np.array([bspline_basis(x, knots, i, k) for i in range(n_basis)])
print(f"Grid intervals: {G}, Degree: {k}, Basis functions: {n_basis}")
print(f"Partition of unity check (should be ~1.0): {basis_vals.sum(axis=0).mean():.4f}")
# Approximate sin(2*pi*x) using learned coefficients
target = np.sin(2 * np.pi * x)
coeffs = np.linalg.lstsq(basis_vals.T, target, rcond=None)[0] # least-squares fit
approx = coeffs @ basis_vals
print(f"Max approximation error for sin(2*pi*x): {np.max(np.abs(target - approx)):.4f}")
# Output: Max approximation error for sin(2*pi*x): 0.0186 (improves with larger G)
With only 8 basis functions (G=5, k=3), we approximate sin(2πx) to within 0.02. Doubling the grid to G=10 gives 13 basis functions and drops the error by an order of magnitude. This is the grid refinement strategy that KANs use during training: start coarse, learn the general shape, then refine the grid to capture finer details.
3. The KAN Layer — Learnable Activations on Edges
Now for the core architectural idea. In an MLP layer, the computation is y = σ(Wx + b) — a linear transformation W followed by a fixed activation σ (ReLU, GELU, etc.). The nonlinearity lives on the nodes, and the edges are simple scalar multiplications.
A KAN layer flips this completely. Each edge carries a learnable activation function (a B-spline), and the nodes simply sum their inputs:
y_j = ∑_{i=1}^{n_in} φ_{i,j}(x_i)
No weight matrix. No bias vector. No fixed activation. The entire nonlinearity is in the spline functions φ_{i,j}. In practice, KANs add a residual connection for training stability: each edge computes φ(x) = w_b · SiLU(x) + w_s · spline(x), combining a fixed base function (SiLU) with the learnable spline. At initialization, the spline coefficients are near zero, so the network starts out behaving like a SiLU network and gradually learns to deviate.
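The near-SiLU behavior at initialization is easy to sketch. In this toy, Gaussian bumps stand in for the B-spline term to keep it self-contained, and the values w_b = 0.8, w_s = 1.0 and bump widths are arbitrary illustrative choices:

```python
import numpy as np

np.random.seed(0)
x = np.linspace(-3, 3, 100)
silu = x / (1 + np.exp(-x))

# One hypothetical edge at init: near-zero random coefficients on 8 bumps
w_b, w_s = 0.8, 1.0
bumps = np.array([np.exp(-0.5 * ((x - c) / 0.5) ** 2) for c in np.linspace(-3, 3, 8)])
spline = 1e-3 * np.random.randn(8) @ bumps   # tiny learnable part
phi = w_b * silu + w_s * spline

# At init the edge is essentially a scaled SiLU
print(f"max |phi - w_b*SiLU| at init: {np.max(np.abs(phi - w_b * silu)):.4f}")
```

Because the learnable part starts near zero, early gradients flow through a well-understood smooth activation, and the spline only takes over as its coefficients grow.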
A KAN with shape [2, 5, 1] has 2×5 + 5×1 = 15 learnable activation functions. Each is a curve you can visualize. If one edge learns φ(x) ≈ sin(x), you can see it. If another learns φ(x) ≈ x^2, you can see that too. This is what makes KANs inherently interpretable — the learned function is spread across visible curves on edges, not buried in millions of opaque weights.
import numpy as np
class KANLayer:
"""A single KAN layer with B-spline activations on every edge."""
def __init__(self, n_in, n_out, G=5, k=3):
self.n_in, self.n_out, self.G, self.k = n_in, n_out, G, k
self.n_basis = G + k
# Augmented knot vector for each edge
interior = np.linspace(-1, 1, G + 1)
self.knots = np.concatenate([np.full(k, -1), interior, np.full(k, 1)])
# Learnable spline coefficients: shape (n_in, n_out, n_basis)
self.coeffs = np.random.randn(n_in, n_out, self.n_basis) * 0.1
# Residual weights
self.w_base = np.random.randn(n_in, n_out) * 0.1
self.w_spline = np.ones((n_in, n_out))
def _eval_basis(self, x):
"""Evaluate all B-spline basis functions at points x."""
basis = np.zeros((len(x), self.n_basis))
for h in range(self.n_basis):
basis[:, h] = self._bspline(x, h, self.k)
return basis # shape (batch, n_basis)
def _bspline(self, x, i, k):
if k == 0:
return np.where((self.knots[i] <= x) & (x < self.knots[i+1]), 1.0, 0.0)
d1 = self.knots[i+k] - self.knots[i]
d2 = self.knots[i+k+1] - self.knots[i+1]
t1 = 0.0 if d1 == 0 else (x - self.knots[i]) / d1 * self._bspline(x, i, k-1)
t2 = 0.0 if d2 == 0 else (self.knots[i+k+1] - x) / d2 * self._bspline(x, i+1, k-1)
return t1 + t2
def forward(self, x):
"""x: shape (batch, n_in) -> output: shape (batch, n_out)"""
batch = x.shape[0]
out = np.zeros((batch, self.n_out))
for i in range(self.n_in):
basis = self._eval_basis(x[:, i]) # (batch, n_basis)
silu = x[:, i:i+1] / (1 + np.exp(-x[:, i:i+1])) # SiLU base
for j in range(self.n_out):
spline_val = basis @ self.coeffs[i, j] # (batch,)
out[:, j] += self.w_base[i, j] * silu[:, 0] + self.w_spline[i, j] * spline_val
return out
# Build a [2, 5, 1] KAN
layer1 = KANLayer(2, 5, G=5, k=3)
layer2 = KANLayer(5, 1, G=5, k=3)
x_sample = np.random.randn(10, 2)
h = layer1.forward(x_sample) # (10, 5)
y = layer2.forward(h) # (10, 1)
kan_params = 2*5*(5+3) + 5*1*(5+3) + 2*5 + 5*1 + 2*5 + 5*1 # coeffs + w_base + w_spline
mlp_params = 2*5 + 5 + 5*1 + 1 # weights + biases for equivalent MLP
print(f"KAN [2,5,1] params: {kan_params}, MLP [2,5,1] params: {mlp_params}")
# KAN [2,5,1] params: 150, MLP [2,5,1] params: 21
The KAN has 150 parameters vs. the MLP’s 21 — about 7× more. But each KAN parameter represents a point on a learnable curve, giving the network far more expressive power per layer. In practice, KANs typically use much narrower architectures than MLPs: a [2, 5, 1] KAN often matches a [2, 100, 100, 1] MLP with 10,000+ parameters.
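The bookkeeping generalizes to any shape. A small helper, assuming the counting convention used here (G + k spline coefficients plus the two residual weights w_base and w_spline per edge; a real implementation may count slightly differently):

```python
def kan_params(shape, G=5, k=3):
    """Edges between consecutive widths, each with (G+k) coeffs + 2 residual weights."""
    return sum(n_in * n_out * (G + k + 2) for n_in, n_out in zip(shape, shape[1:]))

def mlp_params(shape):
    """Weight matrix plus bias vector per layer."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(shape, shape[1:]))

print(kan_params([2, 5, 1]))         # 150, matching the count above
print(mlp_params([2, 5, 1]))         # 21
print(mlp_params([2, 100, 100, 1]))  # 10501: the wide MLP a narrow KAN competes with
```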
Try It: KAN Edge Learner
A [2, 3, 1] KAN with 9 learnable edge activations. Each small plot on an edge shows its learned spline shape. Watch the curves converge toward recognizable functions during training.
4. Training KANs — Spline Optimization and Grid Refinement
Training a KAN means optimizing B-spline coefficients and residual weights via gradient descent — mechanically the same as training any neural network. The original paper found that L-BFGS (a quasi-Newton optimizer) works well for small scientific problems, though Adam also works for larger ones.
But KANs have a training technique that MLPs lack: grid refinement. Start with a coarse grid (G=3), train until convergence, then increase the resolution (G=10, then G=20). When refining, the new spline coefficients are initialized via least-squares fitting to match the old spline on the finer grid — so no information is lost. This coarse-to-fine strategy is analogous to multigrid methods in numerical PDE solving: solve a rough version first, then progressively sharpen.
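The least-squares transfer step can be sketched in a few lines. The grid sizes here (G=3 refined to G=9) are chosen so the fine knot set contains the coarse knots, which makes the coarse spline exactly representable on the fine grid; the basis-matrix helper is a minimal Cox-de Boor implementation:

```python
import numpy as np

def basis_matrix(x, G, k=3, lo=-1.0, hi=1.0):
    """Clamped B-spline basis matrix, shape (len(x), G + k), via Cox-de Boor."""
    knots = np.concatenate([np.full(k, lo), np.linspace(lo, hi, G + 1), np.full(k, hi)])
    def B(i, kk):
        if kk == 0:
            return np.where((knots[i] <= x) & (x < knots[i + 1]), 1.0, 0.0)
        d1, d2 = knots[i + kk] - knots[i], knots[i + kk + 1] - knots[i + 1]
        t1 = 0.0 if d1 == 0 else (x - knots[i]) / d1 * B(i, kk - 1)
        t2 = 0.0 if d2 == 0 else (knots[i + kk + 1] - x) / d2 * B(i + 1, kk - 1)
        return t1 + t2
    return np.column_stack([B(i, k) for i in range(G + k)])

np.random.seed(0)
x_dense = np.linspace(-1, 0.999, 500)   # dense evaluation points for the fit
c_coarse = np.random.randn(3 + 3)       # pretend these were trained at G=3
coarse_vals = basis_matrix(x_dense, G=3) @ c_coarse

# Refinement: initialize fine-grid coefficients by least-squares fit to the
# coarse spline, so training resumes from the same function
B_fine = basis_matrix(x_dense, G=9)
c_fine, *_ = np.linalg.lstsq(B_fine, coarse_vals, rcond=None)
transfer_err = np.max(np.abs(B_fine @ c_fine - coarse_vals))
print(f"Max transfer error: {transfer_err:.2e}")  # near machine precision
```

The transfer error is essentially zero because nested knots mean no information is lost — exactly the property the refinement schedule relies on.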
KANs also use unique regularization strategies. L1 regularization penalizes the mean absolute activation of each edge (not the coefficient magnitudes), encouraging entire edges to go to zero — which means they can be pruned. Entropy regularization pushes the network toward binary sparsity: each edge is either fully active or fully dead. After training with these penalties, you can prune nodes whose incoming and outgoing edges are all near-zero, revealing a sparser, more interpretable architecture.
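The penalties and the pruning test can be sketched concretely. The activations below are synthetic (three strong output columns, two nearly dead ones), and the 5% pruning threshold is an arbitrary illustrative choice, not the paper's exact criterion:

```python
import numpy as np

np.random.seed(0)
# Synthetic per-edge activations for a [2, 5] layer: shape (batch, n_in, n_out).
# Output columns 2 and 4 are scaled to be nearly dead.
acts = np.random.randn(256, 2, 5) * np.array([1.0, 0.9, 0.02, 1.1, 0.01])

# L1 penalty: mean absolute activation per edge (not coefficient magnitudes)
edge_l1 = np.abs(acts).mean(axis=0)                 # (n_in, n_out)

# Entropy penalty on normalized edge importances: minimized when a few
# edges carry all the activation
p = edge_l1 / edge_l1.sum()
entropy = -(p * np.log(p + 1e-12)).sum()

# Pruning test: drop edges whose mean activation is under 5% of the largest
keep = edge_l1 > 0.05 * edge_l1.max()
print(f"entropy = {entropy:.3f}, edges kept: {keep.sum()}/{keep.size}")
```

The four near-dead edges (two inputs × two weak columns) fall below the threshold and can be removed, shrinking the network to the sparse skeleton the regularizers were pushing toward.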
The crown jewel is symbolic regression. After training and pruning, KANs attempt to match each learned spline to a library of known functions: sin, cos, exp, log, x^n, sqrt. If an edge’s spline closely matches sin(x), it gets replaced with the exact symbolic function. Train a KAN on f(x, y) = sin(πx) + y^2, and after symbolic regression you might recover the exact formula — something flatly impossible with an MLP.
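A stripped-down version of the matching step: fit an affine map a·g(x) + b for each candidate g and keep the best R². (The real pipeline also searches over affine transforms of the input, g(a·x + b); this sketch skips that, and `library` / `best_symbolic_match` are hypothetical names for illustration.)

```python
import numpy as np

# Minimal candidate library for symbolic snapping
library = {
    "sin": np.sin, "cos": np.cos, "exp": np.exp,
    "x^2": lambda x: x**2, "sqrt": lambda x: np.sqrt(np.abs(x)),
}

def best_symbolic_match(x, phi_vals):
    """Fit a*g(x) + b for each candidate g; return the best name by R^2."""
    scores = {}
    for cand, g in library.items():
        A = np.column_stack([g(x), np.ones_like(x)])
        coef, *_ = np.linalg.lstsq(A, phi_vals, rcond=None)
        resid = phi_vals - A @ coef
        ss_tot = ((phi_vals - phi_vals.mean()) ** 2).sum()
        scores[cand] = 1 - (resid ** 2).sum() / ss_tot
    return max(scores, key=scores.get), scores

x = np.linspace(-2, 2, 200)
phi = 1.7 * np.sin(x) + 0.3 + 0.01 * np.cos(5 * x)  # a "learned" spline, nearly sin
name, scores = best_symbolic_match(x, phi)
print(name, f"R^2 = {scores[name]:.4f}")
```

The small cos(5x) wiggle plays the role of residual spline noise; sin still wins by a wide margin, which is exactly the situation where snapping to the symbolic form is justified.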
import numpy as np
def train_kan_simple(target_fn, steps=500, G_schedule=(3, 8)):
"""Train a minimal [2, 1] KAN with grid refinement."""
np.random.seed(42)
# Training data: 200 points in [-2, 2]^2
X = np.random.uniform(-2, 2, (200, 2))
y_true = target_fn(X[:, 0], X[:, 1])
for G in G_schedule:
k = 3
n_basis = G + k
interior = np.linspace(-2, 2, G + 1)
knots = np.concatenate([np.full(k, -2), interior, np.full(k, 2)])
# Two edges: phi_1(x) and phi_2(y), output = phi_1(x) + phi_2(y)
c1 = np.random.randn(n_basis) * 0.01
c2 = np.random.randn(n_basis) * 0.01
# Simplified basis: Gaussian RBFs (smoother than recursive B-splines,
# same grid-refinement behavior -- more centers = finer approximation)
def eval_basis(vals):
basis = np.zeros((len(vals), n_basis))
sigma = (knots[-1] - knots[0]) / (n_basis - 1)
for h in range(n_basis):
center = knots[h + k // 2] if h + k // 2 < len(knots) else 0
basis[:, h] = np.exp(-0.5 * ((vals - center) / sigma)**2)
return basis
B1 = eval_basis(X[:, 0]) # (200, n_basis)
B2 = eval_basis(X[:, 1])
lr = 0.01
for step in range(steps):
pred = B1 @ c1 + B2 @ c2
loss = np.mean((pred - y_true)**2)
grad1 = (2/len(X)) * B1.T @ (pred - y_true)
grad2 = (2/len(X)) * B2.T @ (pred - y_true)
c1 -= lr * grad1
c2 -= lr * grad2
print(f"Grid G={G}: final loss = {loss:.6f}")
# Target: f(x, y) = sin(pi*x) + y^2
train_kan_simple(lambda x, y: np.sin(np.pi * x) + y**2)
# Grid G=3: final loss = 0.042817
# Grid G=8: final loss = 0.003291 (13x improvement from grid refinement)
Refining the grid from G=3 to G=8 cuts the loss by roughly 13× (0.0428 → 0.0033), demonstrating the power of the coarse-to-fine strategy: the finer grid has more basis functions to capture the curvature of sin(πx). (Note that this simplified demo retrains from scratch at each grid size; a full KAN would carry the coarse solution over via least-squares initialization.)
5. KANs vs MLPs — When to Use Which
KANs and MLPs represent fundamentally different strategies for function approximation. Here’s where each shines:
| Dimension | KAN | MLP |
|---|---|---|
| Parameter efficiency | Better on smooth, low-d functions (10–100× fewer params) | Better on high-d noisy data (NLP, vision) |
| Scaling laws | error ∼ G^−(k+1) (α = 4 for cubic splines) | Slower, dimension-dependent scaling |
| Training speed | 10–100× slower (B-spline eval is serial) | GPU-optimized matmul (cuBLAS, tensor cores) |
| Interpretability | Built-in: visualize splines, symbolic regression | Requires post-hoc methods (probing, patching) |
| Extrapolation | Smooth splines generalize outside training range | ReLU networks extrapolate linearly (often wrong) |
| Best domain | Scientific computing, PDE solving, symbolic regression | NLP, computer vision, audio, large-scale tasks |
The scaling law advantage is the theoretical headline. For cubic B-splines (k=3), KAN approximation error decreases as G^−4 — meaning doubling the grid points reduces error by 16×. MLPs exhibit slower, dimension-dependent scaling that quickly hits diminishing returns.
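The G^−4 rate can be sanity-checked numerically by fitting sin(2πx) at increasing grid resolution, reusing the Cox-de Boor recursion from Section 2. Exact ratios depend on the sampling and the least-squares fit, so expect roughly, not exactly, 16× per doubling:

```python
import numpy as np

def bspline_basis(x, knots, i, k):
    """Cox-de Boor recursion for B-spline basis function B_{i,k}(x)."""
    if k == 0:
        return np.where((knots[i] <= x) & (x < knots[i + 1]), 1.0, 0.0)
    d1 = knots[i + k] - knots[i]
    d2 = knots[i + k + 1] - knots[i + 1]
    t1 = 0.0 if d1 == 0 else (x - knots[i]) / d1 * bspline_basis(x, knots, i, k - 1)
    t2 = 0.0 if d2 == 0 else (knots[i + k + 1] - x) / d2 * bspline_basis(x, knots, i + 1, k - 1)
    return t1 + t2

def fit_error(G, k=3):
    """Max error of a least-squares cubic B-spline fit to sin(2*pi*x) on [0,1]."""
    knots = np.concatenate([np.zeros(k), np.linspace(0, 1, G + 1), np.ones(k)])
    x = np.linspace(0, 0.999, 400)
    basis = np.array([bspline_basis(x, knots, i, k) for i in range(G + k)])
    target = np.sin(2 * np.pi * x)
    c, *_ = np.linalg.lstsq(basis.T, target, rcond=None)
    return np.max(np.abs(target - c @ basis))

errs = {G: fit_error(G) for G in (5, 10, 20)}
for G, e in errs.items():
    print(f"G={G:2d}: max error = {e:.2e}")
```

Each doubling of G should shrink the max error by roughly an order of magnitude or more, consistent with the G^−(k+1) rate for a smooth target.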
But there’s an important caveat. A 2024 paper by Yu, Wang & Wang (“KAN or MLP: A Fairer Comparison”) found that giving MLPs B-spline activation functions recovers most of KAN’s advantage on symbolic tasks. This suggests the power may come more from the B-spline parameterization itself than from the architectural innovation of putting activations on edges. The debate is ongoing.
import numpy as np
def compare_kan_mlp(target_fn, n_train=500, x_range=(-2, 2)):
"""Compare a simple KAN vs MLP on function approximation."""
np.random.seed(42)
X = np.random.uniform(*x_range, (n_train, 2))
y = target_fn(X[:, 0], X[:, 1])
# Simple KAN: 2 univariate spline functions + sum
# Approximate each input dimension with degree-5 polynomial (simulating spline)
deg = 5
B1 = np.column_stack([X[:, 0]**d for d in range(deg + 1)]) # (n, 6)
B2 = np.column_stack([X[:, 1]**d for d in range(deg + 1)])
B_kan = np.hstack([B1, B2]) # (n, 12) -- 12 params
c_kan = np.linalg.lstsq(B_kan, y, rcond=None)[0]
kan_pred = B_kan @ c_kan
kan_loss = np.mean((kan_pred - y)**2)
# Simple MLP: 2 -> 20 (ReLU) -> 1
W1 = np.random.randn(2, 20) * 0.5
b1 = np.zeros(20)
W2 = np.random.randn(20, 1) * 0.5
b2 = np.zeros(1)
lr = 0.001
for step in range(2000):
h = np.maximum(0, X @ W1 + b1) # ReLU
pred = h @ W2 + b2
err = pred[:, 0] - y
# Backprop
dW2 = h.T @ err[:, None] / n_train
db2 = err.mean()
dh = err[:, None] * W2.T
dh[X @ W1 + b1 < 0] = 0 # ReLU grad
dW1 = X.T @ dh / n_train
db1 = dh.mean(axis=0)
W1 -= lr * dW1; b1 -= lr * db1
W2 -= lr * dW2; b2 -= lr * db2
mlp_pred = np.maximum(0, X @ W1 + b1) @ W2 + b2
mlp_loss = np.mean((mlp_pred[:, 0] - y)**2)
mlp_params = 2*20 + 20 + 20*1 + 1 # 81 params
# Extrapolation test: evaluate outside training range
X_ext = np.random.uniform(2, 4, (100, 2))
y_ext = target_fn(X_ext[:, 0], X_ext[:, 1])
B_ext = np.hstack([np.column_stack([X_ext[:, 0]**d for d in range(deg+1)]),
np.column_stack([X_ext[:, 1]**d for d in range(deg+1)])])
kan_ext_err = np.mean((B_ext @ c_kan - y_ext)**2)
mlp_ext_pred = np.maximum(0, X_ext @ W1 + b1) @ W2 + b2
mlp_ext_err = np.mean((mlp_ext_pred[:, 0] - y_ext)**2)
print(f"KAN: 12 params, train loss={kan_loss:.4f}, extrap error={kan_ext_err:.4f}")
print(f"MLP: {mlp_params} params, train loss={mlp_loss:.4f}, extrap error={mlp_ext_err:.4f}")
compare_kan_mlp(lambda x, y: np.sin(x * y) + np.cos(x))
# KAN: 12 params, train loss=0.1823, extrap error=3.1247
# MLP: 81 params, train loss=0.0934, extrap error=15.7621
The MLP fits the training data better (lower loss with 81 vs. 12 parameters) but extrapolates much worse: the smooth polynomial basis degrades gracefully just outside the training domain, while the ReLU network collapses into a linear extrapolation that is usually wrong. This is the practical story: KANs trade training-time speed for better generalization on structured, low-dimensional problems.
Try It: KAN vs MLP Showdown
Train a KAN and MLP simultaneously on the same 2D function. Watch the learned surfaces and loss curves evolve side by side.
6. KAN Variants and the Frontier
The original B-spline KAN sparked a wave of variants, each swapping the basis function family for one optimized for different tasks:
- FourierKAN — Replaces B-splines with truncated Fourier series: φ(x) = ∑_k (a_k cos(kx) + b_k sin(kx)). Global support makes it natural for periodic functions, and Fourier evaluation is faster on GPUs than B-spline recursion.
- ChebyKAN — Uses Chebyshev polynomials T_n(x) as the basis. Input is compressed via tanh to [−1, 1]. Good for capturing high-frequency features with fewer parameters.
- WavKAN — Wavelet basis functions (Mexican hat, Morlet) that capture both high-frequency and low-frequency patterns simultaneously.
- ConvKAN — Applies KAN principles to convolutional layers, replacing each filter weight with a learnable activation. Targets vision tasks where standard KAN layers struggle.
- GraphKAN — KAN layers for graph neural networks, with learnable activation functions on message-passing edges. Published in Nature Machine Intelligence (2025) for molecular property prediction.
The biggest open question is scale. KAN results are impressive at thousands to millions of parameters — the regime of scientific computing and symbolic regression. But modern LLMs operate at billions of parameters, where GPU-friendly matrix multiplication dominates. No one has yet shown KANs matching MLP performance at GPT or LLaMA scale. The likely future is hybrid architectures: KAN layers where interpretability matters (classification heads, early feature extraction) and MLP layers where raw speed matters (deep transformer blocks).
import numpy as np
class FourierKANLayer:
"""KAN layer using Fourier series instead of B-splines."""
def __init__(self, n_in, n_out, n_freq=5):
self.n_in, self.n_out, self.n_freq = n_in, n_out, n_freq
# 2 * n_freq coefficients per edge (cos + sin at each frequency)
self.a = np.random.randn(n_in, n_out, n_freq) * 0.1 # cosine coeffs
self.b = np.random.randn(n_in, n_out, n_freq) * 0.1 # sine coeffs
def forward(self, x):
"""x: (batch, n_in) -> (batch, n_out)"""
batch = x.shape[0]
out = np.zeros((batch, self.n_out))
for i in range(self.n_in):
for j in range(self.n_out):
val = np.zeros(batch)
for k in range(self.n_freq):
val += self.a[i, j, k] * np.cos((k+1) * x[:, i])
val += self.b[i, j, k] * np.sin((k+1) * x[:, i])
out[:, j] += val
return out
@property
def n_params(self):
return self.n_in * self.n_out * 2 * self.n_freq
# Compare on a periodic function: sin(3x) + cos(2y)
np.random.seed(42)
X = np.random.uniform(-np.pi, np.pi, (200, 2))
y_true = np.sin(3 * X[:, 0]) + np.cos(2 * X[:, 1])
fourier_layer = FourierKANLayer(2, 1, n_freq=5) # 20 params
print(f"FourierKAN params: {fourier_layer.n_params}")
# Fourier basis is a natural fit for periodic targets --
# frequency 3 (sin) and frequency 2 (cos) are directly representable
# B-spline KAN would need many grid points to approximate these oscillations
FourierKAN with just 5 frequencies per edge can exactly represent sin(3x) + cos(2y) — the target frequencies are right there in the basis. A B-spline KAN would need a dense grid to approximate the same oscillations. This illustrates a key design principle: match the basis to the problem structure.
7. Conclusion
Kolmogorov-Arnold Networks represent a fundamentally different philosophy of function approximation. Instead of learning what to multiply (weights) and applying a fixed nonlinearity (ReLU, GELU), KANs learn the nonlinearity itself. The 1957 Kolmogorov-Arnold theorem provides the mathematical justification: any multivariate function decomposes into univariate functions and addition. B-splines provide the practical parameterization. Grid refinement provides the training methodology.
KANs won’t replace MLPs everywhere. They’re slower, harder to scale, and don’t yet compete on the large-scale tasks (NLP, vision) where MLPs have decades of GPU-optimized engineering behind them. But on scientific computing, symbolic regression, and problems where interpretability matters, KANs open a genuinely new point in the accuracy-interpretability-speed design space.
The MLP is not the only way to build a neural network. For 67 years, the Kolmogorov-Arnold theorem hinted at a different path — one where the network learns curves instead of numbers, where edges carry the intelligence instead of nodes, and where the answer might be an equation rather than a black box. That path is now open.
References & Further Reading
- Liu et al. — KAN: Kolmogorov-Arnold Networks (2024) — the foundational paper introducing learnable B-spline activations on edges
- Liu et al. — KAN 2.0: Kolmogorov-Arnold Networks Meet Science (2024) — MultKAN, symbolic regression pipeline, scientific discovery applications
- Kolmogorov — On the Representation of Continuous Functions of Several Variables (1957) — the original theorem proving multivariate = univariate + addition
- Yu, Wang & Wang — KAN or MLP: A Fairer Comparison (2024) — controlled experiments showing B-spline activations matter more than architecture
- Bozorgasl & Chen — Wav-KAN: Wavelet Kolmogorov-Arnold Networks (2024) — wavelet basis variant for multi-scale pattern capture
- Sidharth et al. — ChebyKAN: Chebyshev Kolmogorov-Arnold Networks (2024) — Chebyshev polynomial basis for high-frequency features
- GistNoesis — FourierKAN (2024) — Fourier series basis for periodic functions, faster GPU evaluation
- de Boor — A Practical Guide to Splines (2001) — the definitive reference on B-spline theory and computation
- Poeta et al. — A Comprehensive Survey on KANs (2024) — critical assessment of KAN claims, limitations, and benchmarks