Spiking Neural Networks from Scratch
1. Why Spikes?
Your brain runs on roughly 20 watts. GPT-4's training is estimated to have consumed on the order of 50 gigawatt-hours — and for many everyday tasks, the brain remains orders of magnitude more energy-efficient. The difference? Your neurons don't pass floating-point numbers to each other — they communicate through precisely timed electrical pulses called action potentials, or spikes. A neuron either fires or it doesn't, and the information is encoded in when it fires, not how much.
Every neural network we've built in this series — feedforward nets, CNNs, transformers, graph neural networks — uses continuous-valued activations. A neuron computes a weighted sum, passes it through ReLU or softmax, and outputs a real number like 0.73. These are sometimes called "second-generation" neural networks. They're powerful, but they bear almost no resemblance to biological computation.
Spiking Neural Networks (SNNs) are the "third generation." Instead of continuous activations, each neuron maintains a membrane potential that rises with incoming signals and decays over time. When the potential crosses a threshold, the neuron emits a single binary spike and resets. Information travels as spike trains — sequences of precisely timed all-or-nothing events.
This matters for three reasons: energy efficiency (Intel's Loihi 2 neuromorphic chip runs on less than 1 watt, vs. 400 watts for a GPU), temporal computation (spike timing naturally encodes when events happen, not just what they are), and biological plausibility (understanding the brain's actual algorithm). Let's build one from scratch.
2. The Leaky Integrate-and-Fire Neuron
The workhorse of spiking neural networks is the Leaky Integrate-and-Fire (LIF) neuron, first described by Louis Lapicque in 1907. It models the neuron's membrane as an RC circuit: incoming current charges a capacitor (the membrane potential rises), while a resistor continuously leaks charge (the potential decays toward rest).
The governing differential equation is:
τ_m · dV/dt = −(V − V_rest) + R · I(t)
Where V is the membrane potential, V_rest is the resting potential (typically −65 mV, or 0 in normalized form), τ_m is the membrane time constant (10–20 ms), R is the membrane resistance, and I(t) is the input current. The leak term −(V − V_rest) pulls the potential back toward rest — without input, the neuron relaxes.
The magic happens at the threshold: when V reaches V_th, the neuron fires a spike, resets to V_reset, and enters a brief refractory period (~2 ms) during which it cannot fire again. This creates the characteristic sawtooth pattern: charge → threshold → spike → reset → charge again.
We discretize with Euler integration (timestep dt):
import numpy as np

def simulate_lif(current, dt=1.0, tau_m=20.0, v_rest=0.0,
                 v_thresh=1.0, v_reset=0.0, r=1.0, t_ref=4):
    """Simulate a Leaky Integrate-and-Fire neuron."""
    n_steps = len(current)
    voltage = np.zeros(n_steps)
    spikes = np.zeros(n_steps)
    v = v_rest
    ref_counter = 0
    for t in range(n_steps):
        if ref_counter > 0:          # in refractory period: clamp to reset
            ref_counter -= 1
            v = v_reset
        else:
            dv = (-(v - v_rest) + r * current[t]) * dt / tau_m
            v += dv
            if v >= v_thresh:        # threshold crossed: spike and reset
                spikes[t] = 1.0
                v = v_reset
                ref_counter = t_ref  # refractory period in timesteps
        voltage[t] = v
    return voltage, spikes

# Constant current above threshold
current = np.ones(200) * 1.8
voltage, spikes = simulate_lif(current)
spike_times = np.where(spikes)[0]
print(f"Spike count: {int(spikes.sum())}")
print(f"Firing rate: {spikes.sum() / len(current) * 1000:.0f} Hz")  # dt = 1 ms
print(f"Spike times: {spike_times[:8]}...")
With current at 1.8× threshold, the neuron fires regular spikes. Reduce the current below threshold and it charges partway, then leaks back down — no spike. Increase the current and spikes come faster. This is the f-I curve: firing frequency as a function of input current.
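The f-I curve can be measured directly by sweeping the input current and counting spikes. Here is a self-contained sketch using the same normalized parameters as above (R = 1, threshold 1.0, dt = 1 ms):

```python
import numpy as np

def lif_rate(i_const, n_steps=1000, dt=1.0, tau_m=20.0,
             v_thresh=1.0, v_reset=0.0, t_ref=4):
    """Firing rate (Hz) of an LIF neuron driven by a constant current."""
    v, ref, n_spikes = 0.0, 0, 0
    for _ in range(n_steps):
        if ref > 0:              # refractory: hold at reset
            ref -= 1
            v = v_reset
            continue
        v += (-v + i_const) * dt / tau_m
        if v >= v_thresh:        # spike and reset
            n_spikes += 1
            v = v_reset
            ref = t_ref
    return n_spikes / (n_steps * dt / 1000.0)  # spikes per second

for i in [0.8, 1.2, 1.8, 3.0]:
    print(f"I = {i:.1f}: {lif_rate(i):.0f} Hz")
```

Below threshold (I = 0.8) the rate is zero; above it, the rate grows with current but saturates as the refractory period starts dominating the inter-spike interval.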
3. Encoding Information as Spikes
Standard neural networks eat real numbers for breakfast — a pixel value of 0.73, a temperature of 22.5°C. But spiking networks consume spike trains: sequences of binary events over time. How do you convert 0.73 into spikes?
Rate coding is the simplest approach: the input value becomes the probability of spiking at each timestep. A value of 0.7 means a 70% chance of firing per step. Over many timesteps, the spike count is proportional to the input. This is equivalent to a Poisson process with rate proportional to the input intensity.
Temporal coding is more elegant: stronger inputs fire earlier. On a ten-step window, say, a value of 0.9 spikes at timestep 1 while 0.1 spikes at timestep 9. Each neuron fires exactly once — maximally sparse and energy-efficient. Latency coding of this kind is thought to play a role in the retina, where brighter stimuli evoke earlier responses.
Population coding uses a bank of neurons with overlapping tuning curves. Each neuron responds maximally to a different input value, like orientation-selective neurons in your visual cortex. The population activity pattern uniquely encodes any input — robust to noise because many neurons vote.
import numpy as np

def rate_encode(values, n_steps=50, seed=42):
    """Rate coding: value = spike probability per timestep."""
    rng = np.random.default_rng(seed)
    n_neurons = len(values)
    spikes = np.zeros((n_steps, n_neurons))
    for t in range(n_steps):
        spikes[t] = (rng.random(n_neurons) < values).astype(float)
    return spikes

def temporal_encode(values, n_steps=50):
    """Temporal coding: stronger input fires earlier."""
    n_neurons = len(values)
    spikes = np.zeros((n_steps, n_neurons))
    for i, v in enumerate(values):
        if v > 0.01:                 # silent below a small cutoff
            spike_time = int((1 - v) * (n_steps - 1))
            spikes[spike_time, i] = 1.0
    return spikes

inputs = np.array([0.2, 0.5, 0.9])
rate_spikes = rate_encode(inputs, n_steps=50)
temp_spikes = temporal_encode(inputs, n_steps=50)
for i, val in enumerate(inputs):
    r_count = rate_spikes[:, i].sum()
    t_time = np.where(temp_spikes[:, i])[0]
    print(f"Input {val:.1f}: rate={int(r_count)} spikes/50 steps, "
          f"temporal=spike at t={t_time[0] if len(t_time) else 'none'}")
Rate coding is robust but slow (needs many timesteps to represent values accurately). Temporal coding is fast and sparse but sensitive to noise — one jittered spike changes the entire message. In practice, most SNN research uses rate coding for its simplicity, but temporal coding is where the real energy savings live.
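Population coding, the third scheme above, can be sketched with Gaussian tuning curves tiled over the input range. This is an illustrative construction — the centers, widths, and decoder here are all choices, not a standard API:

```python
import numpy as np

def population_encode(value, n_neurons=8, sigma=0.15, n_steps=50, seed=0):
    """Population coding: Gaussian tuning curves tiled over [0, 1]."""
    rng = np.random.default_rng(seed)
    preferred = np.linspace(0.0, 1.0, n_neurons)  # each neuron's favorite value
    rates = np.exp(-((value - preferred) ** 2) / (2 * sigma ** 2))
    # Rate-code each neuron at its tuning-curve activation
    return (rng.random((n_steps, n_neurons)) < rates).astype(float)

spikes = population_encode(0.6)
counts = spikes.sum(axis=0)
print("Spike counts per neuron:", counts.astype(int))
# Decode with a population vector: activity-weighted mean of preferred values
preferred = np.linspace(0.0, 1.0, 8)
print(f"Decoded value: {counts @ preferred / counts.sum():.2f}")
```

The neurons whose centers sit near 0.6 fire most, and the activity-weighted average recovers the input to within the noise — even if a few neurons drop spikes, the decode barely moves. That redundancy is the point.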
4. Building a Spiking Network
A single LIF neuron is interesting, but the real power comes from connecting them. In a spiking network, when a pre-synaptic neuron fires, it sends a weighted current pulse to all post-synaptic neurons it connects to. A positive weight (w > 0) is excitatory — it pushes the receiving neuron toward threshold. A negative weight (w < 0) is inhibitory — it pulls the receiver away from firing.
The key difference from standard neural networks: computation unfolds over T timesteps. At each step: (1) propagate current spikes through weight matrices, (2) integrate the resulting currents into membrane potentials, (3) check thresholds and emit new spikes, (4) apply leak decay. This is not a single matrix multiply — it's a dynamical system evolving over time.
For classification, we decode by counting: the output neuron with the most spikes over all T timesteps wins. This is rate-coded readout — the simplest and most robust decoding strategy.
import numpy as np

def spiking_forward(spike_input, weights, n_steps, beta=0.9,
                    v_thresh=1.0):
    """Two-layer spiking network: input spikes -> hidden -> output."""
    w1, w2 = weights
    n_hidden = w1.shape[1]
    n_output = w2.shape[1]
    v_hid = np.zeros(n_hidden)
    v_out = np.zeros(n_output)
    out_spikes = np.zeros(n_output)
    for t in range(n_steps):
        # Layer 1: input spikes drive hidden neurons
        i_hid = spike_input[t] @ w1
        v_hid = beta * v_hid + i_hid
        s_hid = (v_hid >= v_thresh).astype(float)
        v_hid = v_hid * (1 - s_hid)  # reset neurons that spiked
        # Layer 2: hidden spikes drive output neurons
        i_out = s_hid @ w2
        v_out = beta * v_out + i_out
        s_out = (v_out >= v_thresh).astype(float)
        v_out = v_out * (1 - s_out)
        out_spikes += s_out
    return out_spikes  # spike counts per output neuron

# 4-input, 8-hidden, 3-output network
rng = np.random.default_rng(42)
w1 = rng.normal(0, 0.5, (4, 8))
w2 = rng.normal(0, 0.5, (8, 3))

# Rate-encode some inputs
inputs = np.array([0.8, 0.2, 0.6, 0.4])
spikes_in = (rng.random((50, 4)) < inputs).astype(float)
counts = spiking_forward(spikes_in, [w1, w2], n_steps=50)
print(f"Output spike counts: {counts.astype(int)}")
print(f"Predicted class: {np.argmax(counts)}")
With random weights, the prediction is meaningless — but the architecture works. The network accumulates evidence over time: each input spike nudges the hidden neurons, which nudge the output neurons, which accumulate spike counts like a ballot box. The class with the most votes wins.
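The excitatory/inhibitory distinction is easy to see with a single discrete-time LIF neuron: the same drive that makes it fire falls silent once a modest inhibitory current is subtracted. A minimal sketch with illustrative parameter values:

```python
import numpy as np

def run_lif(inputs, beta=0.9, v_thresh=1.0):
    """Discrete LIF: leak, integrate the input current, spike, reset."""
    v, spikes = 0.0, 0
    for i in inputs:
        v = beta * v + i          # leak then integrate
        if v >= v_thresh:         # threshold: spike and reset
            spikes += 1
            v = 0.0
    return spikes

drive = np.full(50, 0.15)         # steady excitatory drive
print("excitation only:", run_lif(drive), "spikes")
print("with inhibition:", run_lif(drive - 0.08), "spikes")
```

With pure excitation the potential settles toward 0.15 / (1 − 0.9) = 1.5, crossing threshold repeatedly; subtract the inhibitory current and the steady state drops to 0.7 — below threshold, so the neuron never fires. Inhibitory weights do exactly this inside a network.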
5. STDP — Learning with Spike Timing
How does a spiking network learn? One answer comes straight from biology. In 1998, Guo-qiang Bi and Mu-ming Poo recorded from pairs of hippocampal neurons and discovered a remarkable rule: the direction of synaptic change depends on the relative timing of spikes.
If neuron A fires just before neuron B (A's spike might have caused B to fire), the A→B connection strengthens. This is Long-Term Potentiation (LTP) — "neurons that fire together wire together." If A fires just after B (A didn't cause B's spike — anti-causal), the connection weakens. This is Long-Term Depression (LTD).
The Spike-Timing-Dependent Plasticity (STDP) rule quantifies this with an exponential window:
ΔW = A+ · exp(−Δt / τ+)   if Δt > 0 (pre before post → strengthen)
ΔW = −A− · exp(Δt / τ−)   if Δt < 0 (post before pre → weaken)
Where Δt = t_post − t_pre, and typical parameters are τ+ = τ− = 20 ms, A+ ≈ 0.005, A− ≈ 0.0055. The effect decays exponentially: spikes 5 ms apart cause large changes, spikes 50 ms apart cause almost none. A− is chosen slightly larger than A+ to prevent runaway potentiation.
import numpy as np

def stdp_update(pre_times, post_times, w, a_plus=0.005,
                a_minus=0.0055, tau_plus=20.0, tau_minus=20.0,
                w_max=1.0):
    """Apply the STDP rule for all pre/post spike pairs."""
    for t_pre in pre_times:
        for t_post in post_times:
            dt = t_post - t_pre
            if dt > 0:    # pre before post: potentiate
                dw = a_plus * np.exp(-dt / tau_plus)
                w = min(w + dw, w_max)
            elif dt < 0:  # post before pre: depress
                dw = -a_minus * np.exp(dt / tau_minus)
                w = max(w + dw, 0.0)
    return w

# Causal pairing: pre fires 5 ms before post, repeated 20 times
pre_spikes = np.arange(0, 400, 20).tolist()   # every 20 ms
post_spikes = np.arange(5, 405, 20).tolist()  # 5 ms later each time
w = 0.3  # initial weight
weights_over_time = [w]
for i in range(len(pre_spikes)):
    w = stdp_update([pre_spikes[i]], [post_spikes[i]], w)
    weights_over_time.append(w)
print(f"Weight: {weights_over_time[0]:.3f} -> {weights_over_time[-1]:.3f}")

# Anti-causal: post fires 5 ms BEFORE pre
w_anti = 0.3
for i in range(len(pre_spikes)):
    w_anti = stdp_update([post_spikes[i]], [pre_spikes[i]], w_anti)
print(f"Anti-causal: {0.3:.3f} -> {w_anti:.3f}")
Causal pairing (pre before post) strengthens the synapse; anti-causal pairing weakens it. STDP is unsupervised — it discovers temporal correlations in spike patterns without any labels. It's how the brain wires itself during development, and it's a plausible mechanism for extracting structure from the world.
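The exponential window itself is worth tabulating. A quick sketch of ΔW for a single spike pair at various timing differences, using the parameter values from above:

```python
import numpy as np

def stdp_window(dt, a_plus=0.005, a_minus=0.0055,
                tau_plus=20.0, tau_minus=20.0):
    """Weight change for one pre/post pair with dt = t_post - t_pre (ms)."""
    if dt > 0:    # causal: potentiate, decaying with delay
        return a_plus * np.exp(-dt / tau_plus)
    elif dt < 0:  # anti-causal: depress, decaying with delay
        return -a_minus * np.exp(dt / tau_minus)
    return 0.0

for dt in [-50, -20, -5, 5, 20, 50]:
    print(f"dt = {dt:+3d} ms: dW = {stdp_window(dt):+.5f}")
```

The printout traces the classic asymmetric curve: large positive changes just after causal pairings, large negative changes just after anti-causal ones, and near-zero effect beyond ±50 ms. Because A− > A+, the depression side is slightly deeper than the potentiation side.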
6. Surrogate Gradients — Making Spikes Differentiable
STDP is biologically beautiful, but it's local and unsupervised — it can't optimize a global loss function across a deep network. We want to train SNNs with backpropagation, just like we train standard networks. There's one problem: the spike function is a Heaviside step — its gradient is zero everywhere except at the threshold, where it's infinite. Backpropagation through spikes produces exactly zero useful gradient.
The breakthrough came from Friedemann Zenke and Surya Ganguli in 2018 (SuperSpike), refined by Neftci, Mostafa, and Zenke in 2019. Their insight is elegant: use exact spikes in the forward pass, but substitute a smooth approximation in the backward pass. The network still fires real binary spikes during inference — we only "lie" about the gradient during training.
The most popular surrogate is the fast sigmoid: instead of the Heaviside derivative (a Dirac delta), we use 1 / (1 + k|u|)², where u = V − V_th and k controls sharpness. Other options include an arctangent-based surrogate of the form 1 / (1 + (πku)²) and a Gaussian bell curve centered at the threshold. The remarkable finding: the specific shape barely matters — training converges with any reasonable smooth surrogate.
import numpy as np

def fast_sigmoid_surrogate(v, v_thresh=1.0, k=25.0):
    """Surrogate gradient: smooth approximation of the spike derivative."""
    u = v - v_thresh
    return 1.0 / (1.0 + k * np.abs(u)) ** 2

def train_snn_step(x, target, w1, w2, beta=0.9, v_thresh=1.0,
                   lr=0.01, n_steps=25):
    """One training step: forward with spikes, backward with surrogates."""
    n_hid = w1.shape[1]
    n_out = w2.shape[1]
    v_hid, v_out = np.zeros(n_hid), np.zeros(n_out)
    out_counts = np.zeros(n_out)
    # Collect states for the backward pass
    hid_spikes_all, hid_v_all, out_v_all = [], [], []
    for t in range(n_steps):
        # Hidden layer
        i_hid = x[t] @ w1
        v_hid = beta * v_hid + i_hid
        s_hid = (v_hid >= v_thresh).astype(float)
        hid_spikes_all.append(s_hid)
        hid_v_all.append(v_hid.copy())
        v_hid = v_hid * (1 - s_hid)  # reset spiked neurons
        # Output layer
        i_out = s_hid @ w2
        v_out = beta * v_out + i_out
        s_out = (v_out >= v_thresh).astype(float)
        out_v_all.append(v_out.copy())
        v_out = v_out * (1 - s_out)
        out_counts += s_out
    # Loss: MSE on spike counts vs. target
    loss = 0.5 * np.sum((out_counts - target) ** 2)
    # Backward pass with surrogate gradients. Note: this ignores credit
    # flowing through the membrane state across timesteps -- a common
    # truncation that keeps the example short.
    d_counts = out_counts - target
    dw2 = np.zeros_like(w2)
    dw1 = np.zeros_like(w1)
    for t in range(n_steps):
        # Output surrogate gradient
        sg_out = fast_sigmoid_surrogate(out_v_all[t], v_thresh)
        d_out = d_counts * sg_out
        dw2 += np.outer(hid_spikes_all[t], d_out)
        # Hidden surrogate gradient
        sg_hid = fast_sigmoid_surrogate(hid_v_all[t], v_thresh)
        d_hid = (d_out @ w2.T) * sg_hid
        dw1 += np.outer(x[t], d_hid)
    w1 -= lr * dw1 / n_steps
    w2 -= lr * dw2 / n_steps
    return w1, w2, loss
The forward pass fires real binary spikes — (v_hid >= v_thresh).astype(float) is exactly 0 or 1. The backward pass uses fast_sigmoid_surrogate to compute a smooth gradient wherever a spike decision was made. This forward-exact, backward-approximate principle is what unlocked deep SNN training, and it's now the standard approach in libraries like snnTorch.
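For comparison, the alternative surrogate shapes mentioned earlier can be sketched side by side. The scaling constants here are illustrative choices — exact forms and normalizations vary across papers:

```python
import numpy as np

def fast_sigmoid_sg(u, k=25.0):
    """Fast-sigmoid surrogate: 1 / (1 + k|u|)^2."""
    return 1.0 / (1.0 + k * np.abs(u)) ** 2

def arctan_sg(u, k=2.0):
    """Arctangent surrogate: derivative of (1/pi) * arctan(pi*k*u)."""
    return k / (1.0 + (np.pi * k * u) ** 2)

def gaussian_sg(u, sigma=0.2):
    """Gaussian surrogate: bell curve centered on the threshold."""
    return np.exp(-u ** 2 / (2 * sigma ** 2))

u = np.array([-0.5, -0.1, 0.0, 0.1, 0.5])  # membrane potential minus threshold
for name, fn in [("fast sigmoid", fast_sigmoid_sg),
                 ("arctan", arctan_sg), ("gaussian", gaussian_sg)]:
    print(f"{name:>12}: {np.round(fn(u), 3)}")
```

All three peak at the threshold (u = 0), are symmetric, and fall off as the membrane potential moves away — which is exactly the property that matters; the tails differ, and training tends not to care.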
7. Neuromorphic Hardware
Training SNNs on GPUs is somewhat ironic — you're simulating event-driven computation on hardware optimized for dense matrix multiplication. Neuromorphic chips are purpose-built silicon that implements spiking dynamics directly in hardware, achieving dramatic energy savings.
| Hardware | Neurons | Synapses | Power | Year |
|---|---|---|---|---|
| NVIDIA A100 GPU | — | — | 400 W | 2020 |
| Intel Loihi 2 | 1M | 120M | < 1 W | 2021 |
| IBM TrueNorth | 1M | 256M | 65–70 mW | 2014 |
| BrainScaleS-2 | 512 | 130K | analog | 2020 |
The numbers are staggering: IBM's TrueNorth runs a million neurons on 65 milliwatts — less than a hearing aid. Intel's Hala Point system (2024) scales to 1,152 Loihi 2 chips with 1.15 billion neurons and achieves 15 TOPS/W without batching. BrainScaleS-2 takes a different approach entirely: analog circuits that emulate LIF dynamics at 1000× biological real-time.
Why so efficient? Three reasons: event-driven (no computation when no spikes arrive — most neurons are silent most of the time), co-located memory and compute (no energy-hungry data movement between CPU and RAM), and asynchronous (no global clock ticking at billions of Hz). An honest caveat: energy savings depend on spike sparsity. At high firing rates (>6% activation), the advantage over GPUs shrinks substantially.
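That sparsity caveat is easy to quantify with a back-of-envelope model: a conventional layer pays one multiply-accumulate (MAC) per weight, while a spiking layer pays one cheaper accumulate per weight per spike, so SNN cost scales with firing rate times timesteps. The energy figures below are illustrative assumptions, not vendor specifications:

```python
# Back-of-envelope: ANN cost = one MAC per synapse; SNN cost = one
# accumulate per synapse per *spike*, summed over timesteps.
E_MAC = 4.6e-12   # assumed energy per 32-bit MAC, joules (illustrative)
E_AC = 0.9e-12    # assumed energy per accumulate, joules (illustrative)
n_synapses = 1_000_000
T = 25            # SNN timesteps per inference

ann_energy = n_synapses * E_MAC
for sparsity in [0.01, 0.05, 0.10, 0.20]:
    snn_energy = n_synapses * T * sparsity * E_AC
    print(f"firing rate {sparsity:.0%}: SNN/ANN energy ratio = "
          f"{snn_energy / ann_energy:.2f}")
```

With these assumed numbers, a 1% firing rate gives a ~20× advantage, but by 20% the ratio approaches 1 — the multi-timestep accumulates eat the per-operation savings, which is the caveat above in arithmetic form.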
8. When Spiking Beats Conventional
SNNs aren't a universal replacement for transformers and CNNs. They excel in specific niches:
- Event-driven sensors: Dynamic Vision Sensors (DVS cameras) output spikes natively — each pixel fires independently when brightness changes. SNNs process this directly; conventional networks must first bin events into frames, losing temporal precision.
- Ultra-low-power edge: wearable health monitors, always-on keyword detection, IoT sensor hubs. When your power budget is measured in milliwatts, neuromorphic hardware is the only option.
- Temporal patterns: tasks where timing matters — gesture sequences, speech rhythm, robotic control with precise motor timing.
Current benchmark results: MNIST ~99%, CIFAR-10 ~94.5% (deep SNNs with surrogate gradients, T=4 timesteps), DVS gesture recognition ~94%, and keyword spotting ~90%+ on the Spiking Heidelberg Digits dataset.
When not to use SNNs: large-scale language modeling (transformers dominate), dense image classification on static images (CNNs win), or any task where training speed matters more than inference energy. The SNN ecosystem is also less mature — fewer pretrained models, smaller community, less tooling.
9. Full Pipeline
Let's put everything together: an end-to-end spiking neural network that encodes inputs as spike trains, processes them through LIF layers with surrogate gradient training, and classifies by counting output spikes.
import numpy as np

def full_snn_pipeline(x_train, y_train, x_test, y_test,
                      n_hidden=32, n_steps=30, n_epochs=40,
                      lr=0.005, beta=0.85, seed=42):
    """Complete SNN: encode, spike, learn, decode."""
    rng = np.random.default_rng(seed)
    n_in = x_train.shape[1]
    n_out = int(y_train.max()) + 1
    # Initialize weights
    w1 = rng.normal(0, 0.3, (n_in, n_hidden))
    w2 = rng.normal(0, 0.3, (n_hidden, n_out))
    for epoch in range(n_epochs):
        total_loss = 0
        for i in range(len(x_train)):
            # Step 1: Rate-encode the input
            inp = np.clip(x_train[i], 0, 1)
            spikes_in = (rng.random((n_steps, n_in)) < inp).astype(float)
            # Step 2: One-hot target as spike counts
            target = np.zeros(n_out)
            target[int(y_train[i])] = n_steps * 0.8
            # Step 3: Train with surrogate gradients
            w1, w2, loss = train_snn_step(
                spikes_in, target, w1, w2, beta=beta,
                lr=lr, n_steps=n_steps
            )
            total_loss += loss
        # Step 4: Evaluate every 10 epochs
        if (epoch + 1) % 10 == 0:
            correct = 0
            for i in range(len(x_test)):
                inp = np.clip(x_test[i], 0, 1)
                sp = (rng.random((n_steps, n_in)) < inp).astype(float)
                counts = spiking_forward(sp, [w1, w2], n_steps,
                                         beta=beta)
                if np.argmax(counts) == int(y_test[i]):
                    correct += 1
            acc = correct / len(x_test)
            avg_loss = total_loss / len(x_train)
            print(f"Epoch {epoch+1}: loss={avg_loss:.3f}, "
                  f"acc={acc:.1%}")
    return w1, w2

# Example: a simple 3-class classification problem
rng = np.random.default_rng(0)
x = rng.random((150, 4))
y = np.array([0] * 50 + [1] * 50 + [2] * 50)
x[:50] += np.array([0.3, 0, 0, 0])
x[50:100] += np.array([0, 0.3, 0, 0])
x[100:] += np.array([0, 0, 0.3, 0])
idx = rng.permutation(150)
x, y = x[idx], y[idx]
w1, w2 = full_snn_pipeline(x[:120], y[:120], x[120:], y[120:])
The pipeline follows the same train/evaluate loop as any neural network — the difference is that computation unfolds through time via spikes rather than a single forward pass. The surrogate gradient trick makes the non-differentiable spike function trainable, while the LIF dynamics provide the temporal computation that makes SNNs unique.
10. Conclusion
Spiking Neural Networks offer a fundamentally different computation paradigm. Where standard networks pass continuous numbers through static layers, SNNs propagate precisely timed binary pulses through dynamical neurons. From Lapicque's 1907 RC circuit model to Intel's Hala Point with 1.15 billion neurons on neuromorphic silicon, the field has come a long way.
The pieces are in place: the LIF neuron provides a simple but expressive computational primitive, STDP shows that biologically plausible learning emerges from spike timing, and surrogate gradients bridge the gap between biological spikes and modern optimization. Neuromorphic hardware turns the theoretical energy efficiency into real silicon running at milliwatts.
Your brain processes speech, recognizes faces, and controls 600+ muscles — all on 20 watts, all with spikes. We're still catching up.
References & Further Reading
- Maass (1997) — Networks of Spiking Neurons: The Third Generation of Neural Network Models — the foundational paper defining SNNs as the third generation of neural computation
- Gerstner & Kistler (2002) — Spiking Neuron Models — comprehensive textbook covering LIF, SRM, and biophysical neuron models
- Bi & Poo (1998) — Synaptic Modifications in Cultured Hippocampal Neurons — the original STDP discovery showing timing-dependent synaptic plasticity
- Zenke & Ganguli (2018) — SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks — the surrogate gradient breakthrough enabling deep SNN training
- Neftci, Mostafa & Zenke (2019) — Surrogate Gradient Learning in Spiking Neural Networks — comprehensive review synthesizing surrogate gradient methods
- Davies et al. (2018) — Loihi: A Neuromorphic Manycore Processor with On-Chip Learning — Intel's neuromorphic chip with 128K neurons and on-chip STDP
- Eshraghian et al. (2023) — Training Spiking Neural Networks Using Lessons From Deep Learning — the snnTorch tutorial paper bridging deep learning practices and SNN training
- Tavanaei et al. (2019) — Deep Learning in Spiking Neural Networks — survey of SNN architectures, training methods, and applications