SNN论文阅读spikformer
SNN论文阅读——spikformer
Spikformer: When Spiking Neural Network Meets Transformer
- 逻辑与运算和求和代替原本的矩阵乘法,并取缔softmax归一化注意力分数操作
使用LIF模型公式
H[t]=V[t−1]+1τ(X[t]−(V[t−1]−Vreset ))S[t]=Θ(H[t]−Vth)V[t]=H +Vreset S[t] \begin{align*} & H[t]=V[t-1]+\frac{1}{\tau}\left(X[t]-\left(V[t-1]-V_{\text {reset }}\right)\right) \tag{1}\ & S[t]=\Theta\left(H[t]-V_{t h}\right) \tag{2}\ & V[t]=H +V_{\text {reset }} S[t] \tag{3} \end{align*} H[t]=V[t−1]+τ1(X[t]−(V[t−1]−Vreset ))S[t]=Θ(H[t]−Vth)V[t]=H +Vreset S[t](1)(2)(3)
- X[t]X[t]X[t]:时刻 ttt 的输入电流(通常来自前一层 spike 的加权和,可能已做脉冲滤波);
- V[t]V[t]V[t]:时刻 ttt 计算并带入下一步的膜电位(已考虑 reset);
- H[t]H[t]H[t]:触发判断前的候选膜电位(apply leak + add input 后但还未 threshold/reset);
- S[t]∈{0,1}S[t]\in{0,1}S[t]∈{0,1}:是否发 spike(Heaviside);
- τ\tauτ:膜时间常数(离散化后表现为 leak 强度);
- VthV_{th}Vth:阈值;VresetV_{\text{reset}}Vreset:重置电位
常见的带基线 VresetV_{\text{reset}}Vreset 的 LIF 连续形式可以写为(忽略突触核形状,仅写简单的电流项):
τdV(t)dt=−(V(t)−Vreset)+X(t).(C)
\tau \frac{dV(t)}{dt} = -\big(V(t) - V_{\text{reset}}\big) + X(t).
\tag{C}
τdtdV(t)=−(V(t)−Vreset)+X(t).(C)
这是与论文公式一致的物理模型:膜电位相对于基线 VresetV_{\text{reset}}Vreset 指数衰减,受到输入电流 X(t)X(t)X(t) 驱动。等式右边第一个项是“泄漏回到 VresetV_{\text{reset}}Vreset”的项。
注:等价地写成 τV˙=−V+Vreset+X(t)\tau \dot V = -V + V_{\text{reset}} + X(t)τV˙=−V+Vreset+X(t)。
用 Euler 向前差分离散化:
以步长 Δt\Delta tΔt 对 © 用前向(或准显式)欧拉:
V(t+Δt)≈V(t)+Δt 1τ(−(V(t)−Vreset)+X(t+Δt)).
V(t + \Delta t) \approx V(t) + \Delta t,\frac{1}{\tau}\big( - (V(t)-V_{\text{reset}}) + X(t+\Delta t) \big).
V(t+Δt)≈V(t)+Δtτ1(−(V(t)−Vreset)+X(t+Δt)).
若我们把时间索引写为整数步(把 Δt\Delta tΔt 设为 1 个“离散单位”或内含在 τ\tauτ 中),上式变为(把 t→t−1t \to t-1t→t−1 以匹配给定公式的索引):
H[t]=V[t−1]+Δtτ(X[t]−(V[t−1]−Vreset)).
H[t] = V[t-1] + \frac{\Delta t}{\tau}\big( X[t] - (V[t-1] - V_{\text{reset}})\big).
H[t]=V[t−1]+τΔt(X[t]−(V[t−1]−Vreset)).
论文公式里直接写成 1τ\frac{1}{\tau}τ1(等于把 Δt\Delta tΔt 规范化为 1 或把 τ\tauτ 替换为 τ/Δt\tau/\Delta tτ/Δt 的缘故)。
公式 (3):
V[t]={H[t],if S[t]=0 (no spike)Vreset,if S[t]=1 (spike; hard reset)
V[t] = \begin{cases}
H[t], & \text{if } S[t]=0 \ (\text{no spike})\[4pt]
V_{\text{reset}}, & \text{if } S[t]=1 \ (\text{spike; hard reset})
\end{cases}
V[t]={H[t],Vreset,if S[t]=0 (no spike)if S[t]=1 (spike; hard reset)
这是 hard reset to a fixed potential(常见做法)。
核心处理
输入:SPS(Spiking Patch Splitting)——从图像序列到“patch 序列”
输入:一个时序图像张量
I∈RT×C×H×W I \in \mathbb{R}^{T\times C\times H\times W} I∈RT×C×H×W
- TTT :时间步数(帧数或仿真步);
- CCC:通道数(例如 RGB);
- H,WH,WH,W:空间高和宽。
Patch 切分:把每一帧按 patch 大小 (Ph,Pw)(P_h,P_w)(Ph,Pw) 切成网格,令
N=HPh⋅WPw N = \frac{H}{P_h}\cdot\frac{W}{P_w} N=PhH⋅PwW
为每帧的 patch 数(假设能整除)。对第 ttt 帧,第 nnn 个 patch 的像素张量记为 Patcht,n∈RC×Ph×Pw\mathrm{Patch}_{t,n}\in\mathbb{R}^{C\times P_h\times P_w}Patcht,n∈RC×Ph×Pw。
flatten + 线性投影(这是 SPS 的关键):对每个 patch 展平并做线性映射到维度 DDD:
vect,n=vec(Patcht,n)∈RC⋅Ph⋅Pw,xt,n=Wproj vect,n+bproj∈RD, \begin{aligned} \mathrm{vec}{t,n} &= \operatorname{vec}(\mathrm{Patch}{t,n})\in\mathbb{R}^{C\cdot P_h\cdot P_w},\ x_{t,n} &= W_{\mathrm{proj}};\mathrm{vec}{t,n} + b{\mathrm{proj}}\in\mathbb{R}^{D}, \end{aligned} vect,nxt,n=vec(Patcht,n)∈RC⋅Ph⋅Pw,=Wprojvect,n+bproj∈RD,
其中 Wproj∈RD×(CPhPw)W_{\mathrm{proj}}\in\mathbb{R}^{D\times (C P_h P_w)}Wproj∈RD×(CPhPw)。把所有时刻和 patch 拼好得到
x∈RT×N×D. x\in\mathbb{R}^{T\times N\times D}. x∈RT×N×D.
直观:SPS 就是把输入帧按照 patch 做 token 化(像 ViT),并且在每个时间步保留时序维度 TTT,因此得到“spike-form feature sequence” 的形状。这里的 xxx 在 SNN 场景下通常代表“脉冲驱动的输入电流序列”或在时间上可被转成 spike 的实值驱动(后面会用 SN 转为真正 binary spikes)。
注意:这个过程可以用卷积一步到位完成。如果直接写代码,会用循环切 patch + flatten,很低效。其实卷积本质上就是“滑动窗口 + 权重加权”。所以只要让卷积核 大小等于 Patch 大小,步幅 stride 等于 Patch 大小,就能保证卷积核正好对应每个 patch 不重叠扫描一次。
- 卷积核大小:16,1616, 1616,16
- 输入通道数:3(RGB)
- 输出通道数:768(希望每个 patch 变成 768 维)
- 步距 stride:16(保证每次卷积跨一个 patch,不重叠)
卷积的计算:
每个卷积核对应一个 patch 的线性组合,结果就是一个数;有 768 个卷积核,就得到长度为 768 的向量。
条件位置嵌入(Conditional Position Embedding,生成相对位置嵌入 RPE)
论文指出“浮点点位(fixed float)的位置编码不能直接用在 SNN 中”,于是使用一个条件位置嵌入生成器(Conditional PE generator),其模块结构为:Conv2d (k=3) → BN → SN
(其中 SN 是 spike neuron 层)。
Transformer 里 patch 序列本来是展平的,为了做卷积(Conv2d)需要重新拼成空间形式:
x[t]∈RN×D⟶x~[t]∈RD×Hp×Wp
x[t] \in \mathbb{R}^{N \times D} \quad \longrightarrow \quad \tilde{x}[t] \in \mathbb{R}^{D \times H_p \times W_p}
x[t]∈RN×D⟶x~[t]∈RD×Hp×Wp
其中:
- Hp=H/PhH_p = H / P_hHp=H/Ph,
- Wp=W/PwW_p = W / P_wWp=W/Pw,
- N=Hp×WpN = H_p \times W_pN=Hp×Wp。
注意:
- 我们把 D 当作 channel 维度(相当于特征通道数),
- 而 Hp,WpH_p, W_pHp,Wp 是 patch 在空间上的排布。
这就类似于 把 ViT patch embedding 重排回 feature map,再让 CNN 处理。
卷积操作:
对每个时间步单独地做卷积(3×3 kernel,padding=1 以保持大小),输出还是同样的形状:
rt=Conv2d(x[t])∈RD×Hp×Wp.
\tilde{r}_t = \operatorname{Conv2d}(\tilde{x}[t]) \in \mathbb{R}^{D \times H_p \times W_p}.
rt=Conv2d(x[t])∈RD×Hp×Wp.
这样卷积就能编码局部相邻 patch 的信息(相对位置)。
BN 和 Spike Neuron 激活:
卷积输出之后,先做归一化,再通过 LIF Spike Neuron 激活:
rt′=BN(rt),RPEt=SN(rt′).
\tilde{r}_t’ = \operatorname{BN}(\tilde{r}_t), \qquad \mathrm{RPE}_t = \mathcal{SN}(\tilde{r}_t’).
rt′=BN(rt),RPEt=SN(rt′).
这里:
- BN(BatchNorm/LayerNorm)保持数值分布稳定,避免膜电位漂移;
- SN\mathcal{SN}SN 把连续值变成 spike 形式(即二值脉冲或 rate-based 表示)。
这样保证位置编码本身就是 spike-form,可以被 SNN 下游继续处理。
再展平成序列:
最后,把结果 flatten 回 patch 序列形式:
RPEt∈RN×D,RPE∈RT×N×D.
\mathrm{RPE}_t \in \mathbb{R}^{N \times D}, \quad \mathrm{RPE} \in \mathbb{R}^{T \times N \times D}.
RPEt∈RN×D,RPE∈RT×N×D.
这和输入 xxx 的形状保持一致,方便做相加:
X0=x+RPE.
X_0 = x + \mathrm{RPE}.
X0=x+RPE.
这一块可以写作:
RPE=SN(BN(Conv2d(x))),RPE∈RT×N×D.
\operatorname{RPE} = \mathcal{SN}\Big(\operatorname{BN}\big(\operatorname{Conv2d}(x)\big)\Big), \quad \mathrm{RPE} \in \mathbb{R}^{T \times N \times D}.
RPE=SN(BN(Conv2d(x))),RPE∈RT×N×D.
Spiking Self-Attention(SSA)
Spikformer 的做法是:
先把 Q,K,V 都用 BN + SN 二值化(公式 14):
Q=SNQ(BN(XWQ)),K=SNK(BN(XWK)),V=SNV(BN(XWV)), Q=\mathcal{SN}_Q(\mathrm{BN}(XW_Q)),\quad K=\mathcal{SN}_K(\mathrm{BN}(XW_K)),\quad V=\mathcal{SN}_V(\mathrm{BN}(XW_V)), Q=SNQ(BN(XWQ)),K=SNK(BN(XWK)),V=SNV(BN(XWV)),
—— 这里 Q,K,V∈{0,1}T×N×DQ,K,V\in{0,1}^{T\times N\times D}Q,K,V∈{0,1}T×N×D(或者 {0,1}T×N×d{0,1}^{T\times N\times d}{0,1}T×N×d head 维度),每元素为 0/1(脉冲)。
技术备注:SN 可能是逐时间步产生脉冲的 LIF 层,但论文把输出视为二值序列以便后续逻辑运算。
代数含义:当 Q,KQ,KQ,K 是二值时,二者的点积 ⟨qi,kj⟩=∑u=1dqi,ukj,u\langle q_i,k_j\rangle=\sum_{u=1}^d q_{i,u} k_{j,u}⟨qi,kj⟩=∑u=1dqi,ukj,u 等价于“对于 q_iq_iq_i 中为 1 的维度,统计 k_jk_jk_j 在这些维度上为 1 的数量”。把点积看作一种 按位 AND 然后计数。
给出单头 SSA 的中间式:
SSA′(Q,K,V)=SN((QK⊤) V∗s) \operatorname{SSA}’(Q,K,V)=\mathcal{SN}\big( (QK^\top),V * s \big) SSA′(Q,K,V)=SN((QK⊤)V∗s)
然后:
SSA(Q,K,V)=SN(BN(Linear(SSA′(Q,K,V)))). \operatorname{SSA}(Q,K,V)=\mathcal{SN}\big( \operatorname{BN}( \operatorname{Linear}(\operatorname{SSA}’(Q,K,V)) )\big). SSA(Q,K,V)=SN(BN(Linear(SSA′(Q,K,V)))).
在 VSA 里有 1d\frac{1}{\sqrt d}d1 缩放以避免内积随 ddd 变大导致 softmax 梯度消失/不稳定。SSA 中虽然不使用 softmax,但点积规模仍会随 ddd 线性增长:
- 若 Q,KQ,KQ,K 每维为 Bernoulli(ppp),则内积期望约为 p2dp^2 dp2d。当 ddd 很大,计数可能很大,直接进入 SN 会导致大部分都过阈值(导致饱和/信息丢失)。
因此乘以 s≈1/ds\approx 1/ds≈1/d 或其他合适尺度可以把加权和控制在 SN 的敏感区间,从而保证后续 BN + SN 能分辨不同相关度等级。一般可取 s=1/ds = 1/ds=1/d 或 s=c/E[dot]s = c/\mathbb E[\text{dot}]s=c/E[dot] 的经验值。
SSA′(Q,K,V) 是脉冲神经网络(SNN)中的中间输出,通常它的值是稀疏的,并且是一个非线性稀疏表示。SSA’ 的输出虽然给出了注意力信息,但它可能没有经过适当的规范化,也没有调整到目标空间。Linear(SSA′(Q,K,V)) 对 SSA′\operatorname{SSA’}SSA′ 输出的结果进行 线性变换。这通常是通过一个线性层(即一个全连接层)来实现,目的是将 SSA’ 的输出映射到目标空间,可能是不同的维度,或者是为了匹配后续层的要求。