PyTorch生成式人工智能30扩散模型Diffusion-Model
PyTorch生成式人工智能(30)——扩散模型(Diffusion Model)
0. 前言
文本生成图像 (text-to-image
) Transformer
模型,如 OpenAI
的 DALL-E 2
、Stability AI
的 Stable Diffusion
,能够根据文本描述生成高质量的图像。这些文本生成图像的模型包含三个核心组成部分:一个文本编码器,将文本压缩成潜表示;一个将文本信息融入图像生成过程的方法;以及一个扩散机制,逐步优化图像以产生逼真的输出。理解扩散机制对于理解文本生成图像 Transformer
尤其重要,因为扩散模型构成了所有主流文本生成图像 Transformer
的基础。因此,本节将构建并训练一个扩散模型来生成花卉图像,以深入理解正向扩散过程,其中噪声逐步添加到图像中,直到图像变成随机噪声。随后,将训练一个模型来逆向扩散过程,通过逐步去除噪声,直到模型能够从随机噪声中生成一张清晰的图像。
扩散模型已成为生成高分辨率图像的首选方法。扩散模型的成功在于其能够模拟并逆转复杂的噪声添加过程,这一过程模仿了如何从抽象模式中构建图像。这种方法不仅确保了生成高质量图像,还在生成图像的多样性和准确性之间保持了平衡。
1. 去噪扩散模型简介
假设目标是使用扩散模型生成高分辨率的花卉图像。为了实现这一目标,首先需要收集一组高质量的花卉图像进行训练。然后,模型逐步向这些图像中添加少量随机噪声,这一过程称为正向扩散 (forward diffusion
)。经过多次加入噪声的步骤后,训练图像最终变成了随机噪声。接下来,训练模型逆转这一过程,从纯噪声图像开始,逐步减少噪声,直到图像与原始训练集中的图像无法区分。
在本节中,首先介绍扩散模型的数学基础。然后,深入了解 U-Net
的架构,用于去噪图像并生成高分辨率花卉图像。最后,将了解扩散模型的训练过程以及训练后的模型生成图像的过程。
1.1 正向扩散过程
本节,我们使用花卉图像作为一个具体的例子来解释去噪扩散模型的思想,下图展示了正向扩散过程的工作原理。
假设花卉图像 x 0 x_0 x0 (如上图左侧所示)遵循分布 q ( x ) q(x) q(x)。在正向扩散过程中,我们将在 T
1000
T = 1000
T=1000 步中逐步向图像中添加少量噪声。噪声张量服从正态分布,并且具有与花卉图像相同的形状。在扩散模型中,时间步指的是逐步向数据中添加噪声并随后反转这一过程以生成样本时的离散阶段。扩散模型的正向过程在多个时间步中逐渐向数据中添加噪声,将数据从原始的干净状态转变为噪声分布。在逆向过程阶段,模型以相反的顺序操作,系统地从数据中去除噪声,以重建原始数据或生成新的高保真样本。逆向过程中的每个时间步都涉及预测在相应的正向步骤中添加的噪声并将其减去,从而逐渐去噪,直到数据恢复到干净状态。
在第 1
个时间步,向图像
x
0
x_0
x0 添加噪声
ε
0
\varepsilon_0
ε0,从而得到一个带噪声的图像
x
1
x_1
x1:
x 1
1 − β 1 ⋅ x 0 + β 1 ⋅ ε 0 x_1 = \sqrt{1 - \beta_1} \cdot x_0 + \sqrt{\beta_1} \cdot \varepsilon_0 x1=1−β1
⋅x0+β1
⋅ε0
也就是说,
x
1
x_1
x1 是
x
0
x_0
x0 和
ε
0
\varepsilon_0
ε0 的加权和,其中
β
1
\beta_1
β1 表示噪声的权重。
β
\beta
β 的值在不同的时间步中会有所变化,因此下标表示不同的时间步。如果我们假设
x
0
x_0
x0 和
ε
0
\varepsilon_0
ε0 相互独立并且服从标准正态分布(即均值为 0
,方差为 1
),那么噪声图像
x
1
x_1
x1 也将服从标准正态分布。
我们可以继续在接下来的
T
−
1
T-1
T−1 个时间步中向图像添加噪声,使得:
x t + 1
1 − β t + 1 ⋅ x t + β t + 1 ⋅ ε t x_{t+1} = \sqrt{1 - \beta_{t+1}} \cdot x_t + \sqrt{\beta_{t+1}} \cdot \varepsilon_t xt+1=1−βt+1
⋅xt+βt+1
⋅εt
我们可以使用重参数化技巧,定义
α
t
1 − β t \alpha_t = 1 - \beta_t αt=1−βt,并有:
α ˉ t
∏ k
1
t
α
k
\bar {\alpha}t=\prod{k=1}^t\alpha_k
αˉt=k=1∏tαk
这使得我们可以在任意时间步
t
t
t 处采样
x
t
x_t
xt,其中
t
t
t 可以取值于
[
1
,
2
,
…
,
T
−
1
,
T
]
[1, 2, \dots, T-1, T]
[1,2,…,T−1,T]。然后有:
x t
α t ⋅ x 0 + 1 − α t ⋅ ε t x_t = \sqrt{\alpha_t} \cdot x_0 + \sqrt{1 - \alpha_t} \cdot \varepsilon_t xt=αt
⋅x0+1−αt
⋅εt
其中,
ε
\varepsilon
ε 是
ε
0
\varepsilon_0
ε0,
ε
1
\varepsilon_1
ε1, …,
ε
t
−
1
\varepsilon_{t-1}
εt−1 的组合,使用了可以将两个正态分布相加来得到一个新的正态分布的性质。
上图最左侧展示了训练集中的一张干净花卉图像
x
0
x_0
x0。在第一个时间步,我们向其注入噪声
ε
0
\varepsilon_0
ε0 形成带噪声的图像
x
1
x_1
x1 (上图中的第二张图像)。我们重复这一过程 1000
个时间步,直到图像变成随机噪声(最右侧的图像)。
1.2 逆向扩散过程
我们已经了解了前向扩散过程,接下来我们介绍逆向扩散过程(即去噪过程)。如果我们能够训练一个模型来逆转前向扩散过程,我们就可以将随机噪声输入模型,并让模型生成一张噪声较大的花卉图像。接着,可以将这张噪声图像再次输入训练好的模型,生成一张更清晰但仍然带有噪声的图像。反复进行这一过程,经过多个时间步,直到得到一张与训练集中图像无法区分的干净图像。逆向扩散过程中使用多个推理步骤,对于从噪声分布中逐步重建高质量数据至关重要,使得数据的生成更加可控、稳定且高质量。
为此,我们将创建一个去噪 U-Net
模型。U-Net
架构最初是为生物医学图像分割设计的,其特点是具有对称形状,包括一个收缩路径(编码器)和一个扩张路径(解码器),两者通过瓶颈层连接。在去噪任务中,U-Net
模型被调整为从图像中去除噪声的同时保留重要的细节。由于其能够高效捕捉图像的局部和全局特征,U-Net
在去噪任务中优于简单的卷积网络。下图展示了本节中我们使用的去噪 U-Net
的结构。
该模型以噪声图像及其所在的时间步(即公式中的
x
t
x_t
xt 和
t
t
t )为输入,预测图像中的噪声(即
ε
\varepsilon
ε)。由于噪声图像是原始干净图像和噪声的加权和,得到噪声后,我们可以推断并重建原始图像。
收缩路径(即编码器,下图左侧部分)由多个卷积层和池化层组成。它逐步对图像进行下采样,提取并编码不同抽象层次的特征。网络的这一部分学习识别与去噪相关的模式和特征。
瓶颈层(下图底部)连接编码器和解码器路径。它由卷积层组成,负责捕捉图像的最抽象表示。
扩张路径(即解码器,下图右侧部分)由上采样层和卷积层组成。它逐步上采样特征图,同时通过跳跃连接结合编码器的特征来重建图像。
跳跃连接(下图中由虚线表示)在 U-Net
模型中至关重要,因为它允许模型通过结合低级和高级特征来保留输入图像中的细粒度细节。接下来,我们简要解释跳跃连接的工作原理。
在 U-Net
模型中,跳跃连接通过将编码器路径中的特征图与解码器路径中相应的特征图进行拼接来实现。这些特征图通常具有相同的空间维度,但由于它们各自经过了不同的路径处理,可能已经有所不同。在编码过程中,输入图像会逐步下采样,导致一些空间信息(如边缘和纹理)丢失。跳跃连接有助于通过将编码器中的特征图直接传递到解码器,绕过信息瓶颈,从而保留这些信息。
通过将解码器中的高层次抽象特征与编码器中的低层次细节特征相结合,跳跃连接使得模型能够更好地重建去噪图像中的细粒度细节。这在去噪任务中尤为重要,因为保留细微的图像细节是至关重要的。
在去噪 U-Net
模型中,
机制在收缩路径和扩张路径的最后一个块中实现,并伴随有层归一化和残差连接(如上图中标注的 Attn/Norm/Add
部分)。
跳跃连接和模型的规模导致去噪 U-Net
中存在冗余的特征提取,确保在去噪过程中不会丢失任何重要特征。但模型规模的庞大也使得相关特征的识别变得更加复杂,注意力机制使得模型能够强调重要特征,同时忽略不相关的特征,从而增强了学习过程的有效性。
1.3 训练去噪 U-Net 模型流程
去噪 U-Net
的输出是注入到噪声图像中的噪声。模型的训练目标是最小化输出(预测噪声)与实际噪声(真实噪声)之间的差异。
去噪 U-Net
模型利用 U-Net
架构捕捉局部和全局上下文的能力,使其在去除噪声的同时保留重要细节(如边缘和纹理)。去噪 U-Net
模型广泛应用于各种任务,包括医学图像去噪、摄影图像修复等。下图展示了去噪 U-Net 模型的训练过程。
我们使用 Oxford 102 Flower
数据集作为训练集。我们将所有图像调整为固定分辨率 64 × 64
像素,并将像素值归一化到 [-1, 1]
范围。为了进行去噪,我们需要一对干净图像和噪声图像。将噪声添加到干净的花卉图像中,从而创建噪声图像。
接下来,构建一个去噪 U-Net
模型。在每个训练 epoch
中,按批次遍历数据集。将噪声添加到花卉图像中,并将噪声图像及其所在的时间步
t
t
t 输入 U-Net
模型预测噪声。
将预测的噪声与实际噪声进行比较,并计算像素级的 L1
损失(即平均绝对误差)。然后,调整模型参数以最小化 L1
损失。多次重复这一过程,直到模型收敛。
2. 数据处理
我们将使用 Oxford 102
花卉数据集作为训练数据,该数据集包含大约 8,000
张花卉图像,可以通过 datasets
库直接下载。本节将大部分辅助函数和类放在模块 util.py
和 unet_util.py
中,以专注于扩散模型,可以直接在
。
在本节中,需要使用 datasets
、einops
、diffusers
和 openai
库。首先,使用 pip
命令进行安装:
$ pip install datasets einops diffusers openai
2.1 使用花卉图像作为训练数据
(1) 使用 datasets
库中的 load_dataset()
方法下载 Oxford 102
花卉数据集。然后,可视化数据集中的一些花卉图像,以便了解训练数据集中的图像:
from datasets import load_dataset
from util import transforms
# 下载数据集
dataset = load_dataset("huggan/flowers-102-categories",
split="train",)
dataset.set_transform(transforms)
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
# 可视化数据样本
grid = make_grid(dataset[:16]["input"], 8, 2)
plt.figure(figsize=(8,2),dpi=300)
plt.imshow(grid.numpy().transpose((1,2,0)))
plt.axis("off")
plt.show()
(2) 将数据集按每批次 4
张图像进行分组,用于来训练去噪 U-Net
模型:
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
resolution=64
batch_size=4
train_dataloader=torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=True)
2.2 可视化前向扩散过程
使用模块 util.py
中定义的 DDIMScheduler()
类,使用它向图像添加噪声。之后,还将使用该类与训练好的去噪 U-Net
模型一起生成干净的图像。DDIMScheduler()
类管理去噪时间步和顺序,从而实现确定性推理,能够通过去噪过程生成高质量的样本。
(1) 首先从训练集中选择四张干净图像,并生成与这些图像形状相同的噪声张量:
# 获取四张干净图像
clean_images=next(iter(train_dataloader))["input"]*2-1
print(clean_images.shape)
nums=clean_images.shape[0]
# 生成张量 noise,其形状与干净图像相同;noise 中的每个值都遵循独立的标准正态分布
noise=torch.randn(clean_images.shape)
print(noise.shape)
输出结果如下所示:
torch.Size([2, 3, 64, 64])
torch.Size([2, 3, 64, 64])
(2) 生成并可视化一些过渡噪声图像:
from util import DDIMScheduler
# 实例化 DDIMScheduler() 类,并设置 1,000 个时间步
noise_scheduler=DDIMScheduler(num_train_timesteps=1000)
allimgs=clean_images
# 查看时间步 200、400、600、800 和 1,000
for step in range(200,1001,200):
# 创建噪声图像
timesteps=torch.tensor([step-1]*4).long()
noisy_images=noise_scheduler.add_noise(clean_images,
noise, timesteps)
# 连接噪声图像与干净图像
allimgs=torch.cat((allimgs,noisy_images))
# 可视化图像
import torchvision
imgs=torchvision.utils.make_grid(allimgs,4,6)
fig = plt.figure(dpi=300)
plt.imshow((imgs.permute(2,1,0)+1)/2)
plt.axis("off")
plt.show()
DDIMScheduler()
类中的 add_noise()
方法有三个参数:clean_images
、noise
和 timesteps
,用于生成干净图像和噪声的加权和,即噪声图像。此外,权重是时间步
t
t
t 的函数。随着时间步
t
t
t 从 0
变化到 1000
,干净图像的权重减少,而噪声的权重增加。运行以上代码,生成结果如下所示。
第一列包含四张没有噪声的干净图像,随着我们逐渐向图像中添加越来越多的噪声,最后一列成为纯随机噪声。
3. 构建去噪 U-Net 模型
我们已经介绍了去噪 U-Net
模型的架构,在本节中,我们将使用 PyTorch
实现去噪 U-Net
模型。U-Net
模型设计用于通过下采样和上采样输入图像的过程,捕捉图像中的局部和全局特征。模型使用多个卷积层,通过跳跃连接将不同层次的特征结合起来。这种架构有助于保持空间信息,从而促进更有效的学习。
由于去噪 U-Net
模型的规模庞大,并且特征提取存在冗余,因此采用缩放点积注意力机制 (Self-Distilled Pixel-wise Attention
, SDPA
),使模型能够集中处理输入图像中最相关的部分。为了计算 SDPA
注意力,我们将图像展平并将其像素视为一个序列。然后,使用 SDPA
来学习图像中不同像素之间的依赖关系,类似于
中学习文本中不同词元之间的依赖关系。
3.1 去噪 U-Net 模型中的注意力机制
(1) 为了实现注意力机制,使用在模块 util.py
中定义的 Attention()
类:
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
# 将输入通过三个线性层传递,以获取查询(query)、键(key)和值(value)
qkv = self.to_qkv(x).chunk(3, dim=1)
# 将查询(query)、键(key)和值(value)拆分成四个头(heads)
q, k, v = map(
lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads),
qkv)
q = q * self.scale
sim = einsum('b h d i, b h d j -> b h i j', q, k)
# 计算注意力权重
attn = sim.softmax(dim=-1)
# 计算注意力向量
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
# 合并注意力向量
return self.to_out(out)
attn=Attention(128)
x=torch.rand(1,128,64,64)
out=attn(x)
print(out.shape)
输出结果如下:
torch.Size([1, 128, 64, 64])
为了演示 SDPA
在本节中的操作,我们创建了一张假设的图像 x
,其尺寸为 (1, 128, 64, 64)
,表示批次中的一张图像,128
个特征通道,每个通道的大小为 64 × 64
像素。输入 x
然后通过注意力层进行处理。具体来说,图像中的每个特征通道被展平为一个 64 × 64 = 4096
像素的序列。这个序列通过三个不同的神经网络层生成查询 Q
、键 K
和值 V
,随后将它们分成 4
个注意力头,每个注意力头中的注意力向量计算方式如下:
Attention ( Q , K , V )
softmax ( Q ⋅ K T d k ) ⋅ V \text{Attention} (Q,K,V)= \text{softmax}(\frac{Q \cdot K^T}{\sqrt{d_k}})\cdot V Attention(Q,K,V)=softmax(dk
Q⋅KT)⋅V
其中,
d
k
d_k
dk 表示键向量
K
K
K 的维度。四个头的注意力向量被拼接回一个单一的注意力向量。
3.2 去噪 U-Net 模型
(1) 在模块 unet_util.py
中,我们定义了 UNet()
类来表示去噪 U-Net
模型:
class UNet(nn.Module):
...
def forward(self, sample, timesteps):
# 模型接收一批噪声图像和时间步长作为输入
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps],
dtype=torch.long,
device=sample.device)
timesteps = torch.flatten(timesteps)
timesteps = timesteps.broadcast_to(sample.shape[0])
t_emb = sinusoidal_embedding(timesteps, self.hidden_dims[0])
# 嵌入的时间步长会在不同阶段作为输入加到图像中
t_emb = self.time_embedding(t_emb)
x = self.init_conv(sample)
r = x.clone()
skips = []
# 将输入传递通过收缩路径
for block1, block2, attn, downsample in self.down_blocks:
x = block1(x, t_emb)
skips.append(x)
x = block2(x, t_emb)
x = attn(x)
skips.append(x)
x = downsample(x)
# 将输入传递通过瓶颈路径
x = self.mid_block1(x, t_emb)
x = self.mid_attn(x)
x = self.mid_block2(x, t_emb)
# 将输入传递通过扩张路径,并带有跳跃连接
for block1, block2, attn, upsample in self.up_blocks:
x = torch.cat((x, skips.pop()), dim=1)
x = block1(x, t_emb)
x = torch.cat((x, skips.pop()), dim=1)
x = block2(x, t_emb)
x = attn(x)
x = upsample(x)
x = self.out_block(torch.cat((x, r), dim=1), t_emb)
out = self.conv_out(x)
return {"sample": out}
去噪 U-Net
的任务是根据输入图像所处的时间步预测图像中的噪声。在时间步
t
t
t 下,噪声图像
x
t
x_t
xt 可以表示为干净图像
x
0
x_0
x0 和标准正态分布的随机噪声
ϵ
\epsilon
ϵ 的加权和。随着时间步
t
t
t 从 0
增加到
T
T
T,干净图像的权重逐渐减小,而随机噪声的权重逐渐增加。因此,为了推断噪声图像中的噪声,去噪 U-Net
需要知道噪声图像处于哪个时间步。
时间步是通过正弦和余弦函数进行嵌入的,类似于 Transformer
中的位置编码,生成一个 128
维的向量。然后,这些嵌入被扩展以匹配模型中各层图像特征的维度。例如,在第一个下采样块中,时间嵌入会广播成 (128, 64, 64)
的形状,然后与维度同样为(128, 64, 64)的图像特征相加。
(2) 通过实例化 UNet()
类创建去噪 U-Net
模型:
from unet_util import UNet
device="cuda" if torch.cuda.is_available() else "cpu"
resolution=64
model=UNet(3,hidden_dims=[128,256,512,1024],
image_size=resolution).to(device)
num=sum(p.numel() for p in model.parameters())
print("number of parameters: %.2fM" % (num/1e6,))
print(model)
输出结果如下所示,模型有超过 1.33
亿个参数:
4. 训练并使用去噪 U-Net 模型
在每个训练 epoch
中,遍历训练数据中的所有批次。对于每张图像,我们会随机选择一个时间步,并根据这个时间步值向训练数据中的干净图像添加噪声,生成噪声图像。然后,将这些噪声图像及其对应的时间步值输入到去噪 U-Net
模型中,以预测每张图像中的噪声。我们将预测噪声与真实噪声(即实际添加到图像中的噪声)进行比较,并调整模型参数,以最小化预测噪声与实际噪声之间的平均绝对误差。
训练完成后,我们将使用训练好的模型生成花卉图像。生成过程将分为 50
个推理步骤(即时间步设置为 980
, 960
, …, 20
, 0
)。从随机噪声开始,将其输入到训练好的模型中,得到一个噪声图像。然后将这个噪声图像再次输入到训练好的模型中进行去噪。重复这个过程 50
次,最终得到的图像与训练集中的花卉图像相似。
4.1 训练去噪 U-Net 模型
(1) 定义优化器和学习率调度器,用于训练过程:
from diffusers.optimization import get_scheduler
# 训练 100 个 epoch
num_epochs=100
# 使用 AdamW 优化器
optimizer=torch.optim.AdamW(model.parameters(),lr=0.0001,
betas=(0.95,0.999),weight_decay=0.00001,eps=1e-8)
# 使用 diffusers 库中的学习率调度器来控制学习率
lr_scheduler=get_scheduler(
"cosine",
optimizer=optimizer,
num_warmup_steps=300,
num_training_steps=(len(train_dataloader) * num_epochs))
AdamW
优化器是 Adam
优化器的一个变体,它将权重衰减与优化步骤解耦。与直接将权重衰减应用到梯度不同,AdamW
将权重衰减直接应用到优化步骤后的参数(权重)上。这一修改有助于通过防止衰减率与学习率一起调整,从而提高模型的泛化能力。使用 diffusers
库中的学习率调度器在训练过程中调整学习率。最初使用较高的学习率可以帮助模型跳出局部最小值,而在训练的后期逐渐降低学习率则有助于模型更稳定、更准确地收敛到全局最小值。get_scheduler()
函数在前 300
个训练步骤中,学习率从 0
线性增加到 0.0001
(在 AdamW
优化器中设置的学习率)。在 300
步后,学习率会根据余弦函数的值在 0.0001
到 0
之间递减。
(2) 训练模型 100
个 epoch
:
for epoch in range(num_epochs):
model.train()
tloss = 0
print(f"start epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"].to(device)*2-1
nums = clean_images.shape[0]
noise = torch.randn(clean_images.shape).to(device)
timesteps = torch.randint(0,
noise_scheduler.num_train_timesteps,
(nums, ),
device=device).long()
# 在训练集中的干净图像上添加噪声
noisy_images = noise_scheduler.add_noise(clean_images,
noise, timesteps)
# 使用去噪 U-Net 预测噪声
noise_pred = model(noisy_images, timesteps)["sample"]
# 将预测的噪声与实际噪声进行比较,计算损失
loss = torch.nn.functional.l1_loss(noise_pred, noise)
loss.backward()
# 调整参数
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
tloss += loss.detach().item()
if step%100==0:
print(f"step {step}, average loss {tloss/(step+1)}")
torch.save(model.state_dict(),'files/diffusion.pth')
4.2 使用训练好的模型生成花卉图像
(1) 为了生成花卉图像,我们将使用 50
个推理步骤。这意味着我们将在
t
0 t = 0 t=0 和 t
T t = T t=T 之间(在本节中, T
1000
T = 1000
T=1000 )选择 50
个等间隔的时间步。我们从纯随机噪声开始,这对应于
t
1000
t = 1000
t=1000 时的图像;然后我们使用训练好的去噪 U-Net
模型对其进行去噪,并生成
t
980 t = 980 t=980 时的噪声图像;接下来,我们将 t
980 t = 980 t=980 时的噪声图像输入到训练好的模型中进行去噪,并得到 t
960 时的噪声图像 t = 960 时的噪声图像 t=960时的噪声图像。多次重复以上过程,直到获得 t
0
t = 0
t=0 时的图像,即一张干净图像。这一过程是通过模块 util.py
中 DDIMScheduler()
类的 generate()
方法实现:
@torch.no_grad()
def generate(self,model,device,batch_size=1,generator=None,
eta=1.0,use_clipped_model_output=True,num_inference_steps=50):
imgs=[]
# 使用随机噪声作为起始点
image=torch.randn((batch_size,model.in_channels,model.sample_size,
model.sample_size),generator=generator).to(device)
self.set_timesteps(num_inference_steps)
# 使用 50 个推理时间步长
for t in tqdm(self.timesteps):
# 使用训练好的去噪 U-Net 模型来预测噪声
model_output = model(image, t)["sample"]
# 基于预测的噪声创建图像
image = self.step(model_output,t,image,eta,
use_clipped_model_output=use_clipped_model_output)
img = unnormalize_to_zero_to_one(image)
img = img.cpu().permute(0, 2, 3, 1).numpy()
# 将中间图像保存在列表 imgs 中
imgs.append(img)
image = unnormalize_to_zero_to_one(image)
image = image.cpu().permute(0, 2, 3, 1).numpy()
return {"sample": image}, imgs
(2) 使用 generate()
方法来生成 10
张干净的图像:
sd=torch.load('files/diffusion.pth', weights_only=False)
model.load_state_dict(sd)
with torch.no_grad():
# 设置随机种子
generator = torch.manual_seed()
# 生成 10 张干净图像
generated_images,imgs = noise_scheduler.generate(
model,device,
num_inference_steps=50,
generator=generator,
eta=1.0,
use_clipped_model_output=True,
batch_size=10)
imgnp=generated_images["sample"]
import matplotlib.pyplot as plt
plt.figure(figsize=(10,4),dpi=300)
# 可视化结果
for i in range(10):
ax = plt.subplot(2,5, i + 1)
plt.imshow(imgnp[i])
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.show()
(3) generate()
方法还返回列表 imgs
,该列表包含 50
个推理步骤中的所有图像,我们可以使用这些图像来可视化去噪过程:
# 保留时间步长 800、600、400、200 和 0
steps=imgs[9::10]
# 从 10 张图像中选择 4 组花卉
imgs20=[]
for j in [1,3,6,9]:
for i in range(5):
imgs20.append(steps[i][j])
# 可视化结果
plt.figure(figsize=(10,8),dpi=300)
for i in range(20):
k=i%5
ax = plt.subplot(4,5, i + 1)
plt.imshow(imgs20[i])
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.title(f't={800-200*k}',fontsize=15,c="r")
plt.show()
选择 5
个时间步 (t = 800, 600, 400, 200, 0)
来展示四种不同的花卉,第一列展示了在 t = 800
时的四张花卉图像。它们接近随机噪声。随着推理步骤的进行,图像变得越来越清晰。最右列显示了在 t = 0
时的四张清晰花卉图像。
小结
- 在前向扩散中,逐渐向干净图像添加少量随机噪声,直到它们转变为纯噪声。相反,在逆向扩散中,从随机噪声开始,使用去噪模型逐步消除图像中的噪声,将噪声转变回干净图像
U-Net
架构最初设计用于生物医学图像分割,具有对称形状,包含一个收缩的编码器路径和一个扩张的解码器路径,通过瓶颈层连接。在去噪中,U-Net
被调整为去除噪声的同时保留细节。跳跃连接将编码器和解码器的特征图链接在一起,有助于保留在编码过程中可能丢失的空间信息(如边缘和纹理)- 将注意力机制集成到去噪
U-Net
模型中,可以使其专注于重要的特征,忽略不相关的特征。通过将图像像素视为序列,注意力机制学习像素之间的依赖关系,类似于在自然语言处理中学习词元之间的依赖关系,增强了模型有效识别相关特征的能力