目录

SNN论文阅读spikformer

SNN论文阅读——spikformer

Spikformer: When Spiking Neural Network Meets Transformer

  • 逻辑与运算和求和代替原本的矩阵乘法,并取缔softmax归一化注意力分数操作

使用LIF模型公式

H[t]=V[t−1]+1τ(X[t]−(V[t−1]−Vreset ))S[t]=Θ(H[t]−Vth)V[t]=H +Vreset S[t] \begin{align*} & H[t]=V[t-1]+\frac{1}{\tau}\left(X[t]-\left(V[t-1]-V_{\text {reset }}\right)\right) \tag{1}\ & S[t]=\Theta\left(H[t]-V_{t h}\right) \tag{2}\ & V[t]=H +V_{\text {reset }} S[t] \tag{3} \end{align*} ​H[t]=V[t−1]+τ1​(X[t]−(V[t−1]−Vreset ​))S[t]=Θ(H[t]−Vth​)V[t]=H +Vreset ​S[t]​(1)(2)(3)​

  • X[t]X[t]X[t]:时刻 ttt 的输入电流(通常来自前一层 spike 的加权和,可能已做脉冲滤波);
  • V[t]V[t]V[t]:时刻 ttt 计算并带入下一步的膜电位(已考虑 reset);
  • H[t]H[t]H[t]:触发判断前的候选膜电位(apply leak + add input 后但还未 threshold/reset);
  • S[t]∈{0,1}S[t]\in{0,1}S[t]∈{0,1}:是否发 spike(Heaviside);
  • τ\tauτ:膜时间常数(离散化后表现为 leak 强度);
  • VthV_{th}Vth​:阈值;VresetV_{\text{reset}}Vreset​:重置电位

常见的带基线 VresetV_{\text{reset}}Vreset​ 的 LIF 连续形式可以写为(忽略突触核形状,仅写简单的电流项):
τdV(t)dt=−(V(t)−Vreset)+X(t).(C) \tau \frac{dV(t)}{dt} = -\big(V(t) - V_{\text{reset}}\big) + X(t). \tag{C} τdtdV(t)​=−(V(t)−Vreset​)+X(t).(C)
这是与论文公式一致的物理模型:膜电位相对于基线 VresetV_{\text{reset}}Vreset​ 指数衰减,受到输入电流 X(t)X(t)X(t) 驱动。等式右边第一个项是“泄漏回到 VresetV_{\text{reset}}Vreset​”的项。

注:等价地写成 τV˙=−V+Vreset+X(t)\tau \dot V = -V + V_{\text{reset}} + X(t)τV˙=−V+Vreset​+X(t)。

用 Euler 向前差分离散化:

以步长 Δt\Delta tΔt 对 © 用前向(或准显式)欧拉:
V(t+Δt)≈V(t)+Δt 1τ(−(V(t)−Vreset)+X(t+Δt)). V(t + \Delta t) \approx V(t) + \Delta t,\frac{1}{\tau}\big( - (V(t)-V_{\text{reset}}) + X(t+\Delta t) \big). V(t+Δt)≈V(t)+Δtτ1​(−(V(t)−Vreset​)+X(t+Δt)).
若我们把时间索引写为整数步(把 Δt\Delta tΔt 设为 1 个“离散单位”或内含在 τ\tauτ 中),上式变为(把 t→t−1t \to t-1t→t−1 以匹配给定公式的索引):
H[t]=V[t−1]+Δtτ(X[t]−(V[t−1]−Vreset)). H[t] = V[t-1] + \frac{\Delta t}{\tau}\big( X[t] - (V[t-1] - V_{\text{reset}})\big). H[t]=V[t−1]+τΔt​(X[t]−(V[t−1]−Vreset​)).
论文公式里直接写成 1τ\frac{1}{\tau}τ1​(等于把 Δt\Delta tΔt 规范化为 1 或把 τ\tauτ 替换为 τ/Δt\tau/\Delta tτ/Δt 的缘故)。

公式 (3):
V[t]={H[t],if S[t]=0 (no spike)Vreset,if S[t]=1 (spike; hard reset) V[t] = \begin{cases} H[t], & \text{if } S[t]=0 \ (\text{no spike})\[4pt] V_{\text{reset}}, & \text{if } S[t]=1 \ (\text{spike; hard reset}) \end{cases} V[t]={H[t],Vreset​,​if S[t]=0 (no spike)if S[t]=1 (spike; hard reset)​
这是 hard reset to a fixed potential(常见做法)。

核心处理

输入:SPS(Spiking Patch Splitting)——从图像序列到“patch 序列”

输入:一个时序图像张量

I∈RT×C×H×W I \in \mathbb{R}^{T\times C\times H\times W} I∈RT×C×H×W

  • TTT :时间步数(帧数或仿真步);
  • CCC:通道数(例如 RGB);
  • H,WH,WH,W:空间高和宽。

Patch 切分:把每一帧按 patch 大小 (Ph,Pw)(P_h,P_w)(Ph​,Pw​) 切成网格,令

N=HPh⋅WPw N = \frac{H}{P_h}\cdot\frac{W}{P_w} N=Ph​H​⋅Pw​W​

为每帧的 patch 数(假设能整除)。对第 ttt 帧,第 nnn 个 patch 的像素张量记为 Patcht,n∈RC×Ph×Pw\mathrm{Patch}_{t,n}\in\mathbb{R}^{C\times P_h\times P_w}Patcht,n​∈RC×Ph​×Pw​。

flatten + 线性投影(这是 SPS 的关键):对每个 patch 展平并做线性映射到维度 DDD:

vect,n=vec⁡(Patcht,n)∈RC⋅Ph⋅Pw,xt,n=Wproj  vect,n+bproj∈RD, \begin{aligned} \mathrm{vec}{t,n} &= \operatorname{vec}(\mathrm{Patch}{t,n})\in\mathbb{R}^{C\cdot P_h\cdot P_w},\ x_{t,n} &= W_{\mathrm{proj}};\mathrm{vec}{t,n} + b{\mathrm{proj}}\in\mathbb{R}^{D}, \end{aligned} vect,n​xt,n​​=vec(Patcht,n​)∈RC⋅Ph​⋅Pw​,=Wproj​vect,n​+bproj​∈RD,​

其中 Wproj∈RD×(CPhPw)W_{\mathrm{proj}}\in\mathbb{R}^{D\times (C P_h P_w)}Wproj​∈RD×(CPh​Pw​)。把所有时刻和 patch 拼好得到

x∈RT×N×D. x\in\mathbb{R}^{T\times N\times D}. x∈RT×N×D.

直观:SPS 就是把输入帧按照 patch 做 token 化(像 ViT),并且在每个时间步保留时序维度 TTT,因此得到“spike-form feature sequence” 的形状。这里的 xxx 在 SNN 场景下通常代表“脉冲驱动的输入电流序列”或在时间上可被转成 spike 的实值驱动(后面会用 SN 转为真正 binary spikes)。

注意:这个过程可以用卷积一步到位完成。如果直接写代码,会用循环切 patch + flatten,很低效。其实卷积本质上就是“滑动窗口 + 权重加权”。所以只要让卷积核 大小等于 Patch 大小步幅 stride 等于 Patch 大小,就能保证卷积核正好对应每个 patch 不重叠扫描一次。

  • 卷积核大小:16,1616, 1616,16
  • 输入通道数:3(RGB)
  • 输出通道数:768(希望每个 patch 变成 768 维)
  • 步距 stride:16(保证每次卷积跨一个 patch,不重叠)

卷积的计算:
每个卷积核对应一个 patch 的线性组合,结果就是一个数;有 768 个卷积核,就得到长度为 768 的向量。

条件位置嵌入(Conditional Position Embedding,生成相对位置嵌入 RPE)

论文指出“浮点点位(fixed float)的位置编码不能直接用在 SNN 中”,于是使用一个条件位置嵌入生成器(Conditional PE generator),其模块结构为:Conv2d (k=3) → BN → SN(其中 SN 是 spike neuron 层)。

Transformer 里 patch 序列本来是展平的,为了做卷积(Conv2d)需要重新拼成空间形式
x[t]∈RN×D⟶x~[t]∈RD×Hp×Wp x[t] \in \mathbb{R}^{N \times D} \quad \longrightarrow \quad \tilde{x}[t] \in \mathbb{R}^{D \times H_p \times W_p} x[t]∈RN×D⟶x~[t]∈RD×Hp​×Wp​
其中:

  • Hp=H/PhH_p = H / P_hHp​=H/Ph​,
  • Wp=W/PwW_p = W / P_wWp​=W/Pw​,
  • N=Hp×WpN = H_p \times W_pN=Hp​×Wp​。

注意:

  • 我们把 D 当作 channel 维度(相当于特征通道数),
  • 而 Hp,WpH_p, W_pHp​,Wp​ 是 patch 在空间上的排布。

这就类似于 把 ViT patch embedding 重排回 feature map,再让 CNN 处理。

卷积操作:

对每个时间步单独地做卷积(3×3 kernel,padding=1 以保持大小),输出还是同样的形状:
rt=Conv2d⁡(x[t])∈RD×Hp×Wp. \tilde{r}_t = \operatorname{Conv2d}(\tilde{x}[t]) \in \mathbb{R}^{D \times H_p \times W_p}. rt​=Conv2d(x[t])∈RD×Hp​×Wp​.
这样卷积就能编码局部相邻 patch 的信息(相对位置)。

BN 和 Spike Neuron 激活:

卷积输出之后,先做归一化,再通过 LIF Spike Neuron 激活:
rt′=BN⁡(rt),RPEt=SN(rt′). \tilde{r}_t’ = \operatorname{BN}(\tilde{r}_t), \qquad \mathrm{RPE}_t = \mathcal{SN}(\tilde{r}_t’). rt′​=BN(rt​),RPEt​=SN(rt′​).
这里:

  • BN(BatchNorm/LayerNorm)保持数值分布稳定,避免膜电位漂移;
  • SN\mathcal{SN}SN 把连续值变成 spike 形式(即二值脉冲或 rate-based 表示)。

这样保证位置编码本身就是 spike-form,可以被 SNN 下游继续处理。

再展平成序列:

最后,把结果 flatten 回 patch 序列形式:
RPEt∈RN×D,RPE∈RT×N×D. \mathrm{RPE}_t \in \mathbb{R}^{N \times D}, \quad \mathrm{RPE} \in \mathbb{R}^{T \times N \times D}. RPEt​∈RN×D,RPE∈RT×N×D.
这和输入 xxx 的形状保持一致,方便做相加:
X0=x+RPE. X_0 = x + \mathrm{RPE}. X0​=x+RPE.
这一块可以写作:
RPE⁡=SN(BN⁡(Conv2d⁡(x))),RPE∈RT×N×D. \operatorname{RPE} = \mathcal{SN}\Big(\operatorname{BN}\big(\operatorname{Conv2d}(x)\big)\Big), \quad \mathrm{RPE} \in \mathbb{R}^{T \times N \times D}. RPE=SN(BN(Conv2d(x))),RPE∈RT×N×D.

Spiking Self-Attention(SSA)

Spikformer 的做法是:

  • 先把 Q,K,V 都用 BN + SN 二值化(公式 14):

    Q=SNQ(BN(XWQ)),K=SNK(BN(XWK)),V=SNV(BN(XWV)), Q=\mathcal{SN}_Q(\mathrm{BN}(XW_Q)),\quad K=\mathcal{SN}_K(\mathrm{BN}(XW_K)),\quad V=\mathcal{SN}_V(\mathrm{BN}(XW_V)), Q=SNQ​(BN(XWQ​)),K=SNK​(BN(XWK​)),V=SNV​(BN(XWV​)),

    —— 这里 Q,K,V∈{0,1}T×N×DQ,K,V\in{0,1}^{T\times N\times D}Q,K,V∈{0,1}T×N×D(或者 {0,1}T×N×d{0,1}^{T\times N\times d}{0,1}T×N×d head 维度),每元素为 0/1(脉冲)。
    技术备注:SN 可能是逐时间步产生脉冲的 LIF 层,但论文把输出视为二值序列以便后续逻辑运算。

代数含义:当 Q,KQ,KQ,K 是二值时,二者的点积 ⟨qi,kj⟩=∑u=1dqi,ukj,u\langle q_i,k_j\rangle=\sum_{u=1}^d q_{i,u} k_{j,u}⟨qi​,kj​⟩=∑u=1d​qi,u​kj,u​ 等价于“对于 q_iq_iq_i 中为 1 的维度,统计 k_jk_jk_j 在这些维度上为 1 的数量”。把点积看作一种 按位 AND 然后计数

给出单头 SSA 的中间式:

SSA⁡′(Q,K,V)=SN((QK⊤) V∗s) \operatorname{SSA}’(Q,K,V)=\mathcal{SN}\big( (QK^\top),V * s \big) SSA′(Q,K,V)=SN((QK⊤)V∗s)

然后:

SSA⁡(Q,K,V)=SN(BN⁡(Linear⁡(SSA⁡′(Q,K,V)))). \operatorname{SSA}(Q,K,V)=\mathcal{SN}\big( \operatorname{BN}( \operatorname{Linear}(\operatorname{SSA}’(Q,K,V)) )\big). SSA(Q,K,V)=SN(BN(Linear(SSA′(Q,K,V)))).

  • 在 VSA 里有 1d\frac{1}{\sqrt d}d​1​ 缩放以避免内积随 ddd 变大导致 softmax 梯度消失/不稳定。SSA 中虽然不使用 softmax,但点积规模仍会随 ddd 线性增长:

    • 若 Q,KQ,KQ,K 每维为 Bernoulli(ppp),则内积期望约为 p2dp^2 dp2d。当 ddd 很大,计数可能很大,直接进入 SN 会导致大部分都过阈值(导致饱和/信息丢失)。
  • 因此乘以 s≈1/ds\approx 1/ds≈1/d 或其他合适尺度可以把加权和控制在 SN 的敏感区间,从而保证后续 BN + SN 能分辨不同相关度等级。一般可取 s=1/ds = 1/ds=1/d 或 s=c/E[dot]s = c/\mathbb E[\text{dot}]s=c/E[dot] 的经验值。

  • SSA′(Q,K,V) 是脉冲神经网络(SNN)中的中间输出,通常它的值是稀疏的,并且是一个非线性稀疏表示。SSA’ 的输出虽然给出了注意力信息,但它可能没有经过适当的规范化,也没有调整到目标空间。Linear(SSA′(Q,K,V)) 对 SSA′⁡\operatorname{SSA’}SSA′ 输出的结果进行 线性变换。这通常是通过一个线性层(即一个全连接层)来实现,目的是将 SSA’ 的输出映射到目标空间,可能是不同的维度,或者是为了匹配后续层的要求。