A PyTorch Implementation of the U-KAN Model
1. The DRIVE Dataset
This implementation uses the DRIVE (Digital Retinal Images for Vessel Extraction) retinal vessel segmentation dataset: 40 color fundus images split into 20 training and 20 test images, each with a field-of-view (FOV) mask and manual vessel annotations.
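The training script in Section 7 assumes the standard DRIVE folder layout; the directory and example file names below follow the official release (adjust the absolute paths in train.py to your own setup):
DRIVE/
├── train/
│   ├── images/          # e.g. 21_training.tif
│   ├── mask/            # e.g. 21_training_mask.gif (field-of-view masks)
│   └── 1st_manual/      # e.g. 21_manual1.gif (manual vessel annotations)
└── test/
    ├── images/
    ├── mask/
    └── 1st_manual/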
2. The U-KAN Model
U-KAN keeps the familiar U-Net encoder-decoder shape but swaps the deepest stages for tokenized KAN blocks: three convolutional encoder stages are followed by two patch-embedded KAN stages, and the decoder mirrors this with two KAN stages before three convolutional up-sampling stages (see UKAN.py in Section 5).
3. Differences Between MLP and KAN
In an MLP the nonlinearities are fixed activation functions on the nodes, and the learnable parameters are the linear weights on the edges; in a KAN (Kolmogorov-Arnold Network) every edge carries its own learnable univariate activation function, parameterized here as a B-spline.
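As a sketch of the contrast (notation mine, matching the forward pass of KANLinear in Section 4), an MLP layer applies a fixed nonlinearity after a learnable linear map, while a KAN layer sums learnable univariate functions over the incoming edges; the efficient-KAN layer used here combines a base branch and a spline branch:

$$\text{MLP:}\; y = \sigma(Wx) \qquad\qquad \text{KAN:}\; y_j = \sum_i \phi_{j,i}(x_i)$$

$$\text{KANLinear:}\; y = W_{\text{base}}\,\mathrm{SiLU}(x) + W_{\text{spline}}\,B(x)$$

where $B(x)$ stacks the B-spline basis functions evaluated at each input coordinate.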
4. kan.py
import torch
import torch.nn.functional as F
import math
class KANLinear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
enable_standalone_scale_spline=True,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KANLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
with torch.no_grad():
noise = (
(
torch.rand(self.grid_size + 1, self.in_features, self.out_features)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
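        # 'bases' now holds the degree-0 (piecewise-constant) B-spline bases;
        # the loop below lifts them to degree `spline_order` via the Cox-de Boor recursion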
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)
A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)
def forward(self, x: torch.Tensor):
assert x.dim() == 2 and x.size(1) == self.in_features
base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
return base_output + spline_output
@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)
splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.concatenate(
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)
self.grid.copy_(grid.T)
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
"""
Compute the regularization loss.
        This is a rough stand-in for the original L1 regularization described in the
        paper: the exact version needs the absolute values and entropy of the expanded
        (batch, in_features, out_features) intermediate tensor, which is hidden behind
        F.linear in this memory-efficient implementation.
        The L1 term is therefore computed as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
class KAN(torch.nn.Module):
def __init__(
self,
layers_hidden,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KAN, self).__init__()
self.grid_size = grid_size
self.spline_order = spline_order
self.layers = torch.nn.ModuleList()
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
self.layers.append(
KANLinear(
in_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
)
def forward(self, x: torch.Tensor, update_grid=False):
for layer in self.layers:
if update_grid:
layer.update_grid(x)
x = layer(x)
return x
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
return sum(
layer.regularization_loss(regularize_activation, regularize_entropy)
for layer in self.layers
)
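A minimal smoke test (my own sketch, not part of the original file) to confirm the shapes that KANLinear and KAN produce:

import torch
from kan import KAN, KANLinear

layer = KANLinear(in_features=16, out_features=32)
x = torch.randn(8, 16)                    # (batch, in_features)
print(layer(x).shape)                     # torch.Size([8, 32])

net = KAN(layers_hidden=[16, 64, 4])      # two stacked KANLinear layers
print(net(torch.randn(8, 16)).shape)      # torch.Size([8, 4])
print(net.regularization_loss())          # scalar: L1 + entropy penalty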
5. UKAN.py
import math
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from kan import KANLinear, KAN
class KANLayer(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., no_kan=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.dim = in_features
grid_size = 5
spline_order = 3
scale_noise = 0.1
scale_base = 1.0
scale_spline = 1.0
base_activation = torch.nn.SiLU
grid_eps = 0.02
grid_range = [-1, 1]
if not no_kan:
self.fc1 = KANLinear(
in_features,
hidden_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
self.fc2 = KANLinear(
hidden_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
self.fc3 = KANLinear(
hidden_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
# # TODO
# self.fc4 = KANLinear(
# hidden_features,
# out_features,
# grid_size=grid_size,
# spline_order=spline_order,
# scale_noise=scale_noise,
# scale_base=scale_base,
# scale_spline=scale_spline,
# base_activation=base_activation,
# grid_eps=grid_eps,
# grid_range=grid_range,
# )
else:
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.fc3 = nn.Linear(hidden_features, out_features)
# TODO
# self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv_1 = DW_bn_relu(hidden_features)
self.dwconv_2 = DW_bn_relu(hidden_features)
self.dwconv_3 = DW_bn_relu(hidden_features)
# # TODO
# self.dwconv_4 = DW_bn_relu(hidden_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
# pdb.set_trace()
B, N, C = x.shape
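        # NOTE: these reshapes reuse C at every stage, which is only valid because
        # KANBlock constructs this layer with hidden_features == in_features (and
        # out_features defaults to in_features)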
x = self.fc1(x.reshape(B * N, C))
x = x.reshape(B, N, C).contiguous()
x = self.dwconv_1(x, H, W)
x = self.fc2(x.reshape(B * N, C))
x = x.reshape(B, N, C).contiguous()
x = self.dwconv_2(x, H, W)
x = self.fc3(x.reshape(B * N, C))
x = x.reshape(B, N, C).contiguous()
x = self.dwconv_3(x, H, W)
# # TODO
# x = x.reshape(B,N,C).contiguous()
# x = self.dwconv_4(x, H, W)
return x
class KANBlock(nn.Module):
def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim)
self.layer = KANLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
no_kan=no_kan)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.layer(self.norm2(x), H, W))
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class DW_bn_relu(nn.Module):
def __init__(self, dim=768):
super(DW_bn_relu, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
self.bn = nn.BatchNorm2d(dim)
self.relu = nn.ReLU()
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = self.bn(x)
x = self.relu(x)
x = x.flatten(2).transpose(1, 2)
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class ConvLayer(nn.Module):
def __init__(self, in_ch, out_ch):
super(ConvLayer, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, input):
return self.conv(input)
class D_ConvLayer(nn.Module):
def __init__(self, in_ch, out_ch):
super(D_ConvLayer, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, padding=1),
nn.BatchNorm2d(in_ch),
nn.ReLU(inplace=True),
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, input):
return self.conv(input)
class UKAN(nn.Module):
def __init__(self, num_classes, input_channels=3, deep_supervision=False, img_size=224, patch_size=16, in_chans=3,
embed_dims=[256, 320, 512], no_kan=False,
drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[1, 1, 1], **kwargs):
super().__init__()
kan_input_dim = embed_dims[0]
        self.encoder1 = ConvLayer(input_channels, kan_input_dim // 8)  # use input_channels rather than a hard-coded 3
self.encoder2 = ConvLayer(kan_input_dim // 8, kan_input_dim // 4)
self.encoder3 = ConvLayer(kan_input_dim // 4, kan_input_dim)
self.norm3 = norm_layer(embed_dims[1])
self.norm4 = norm_layer(embed_dims[2])
self.dnorm3 = norm_layer(embed_dims[1])
self.dnorm4 = norm_layer(embed_dims[0])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.block1 = nn.ModuleList([KANBlock(
dim=embed_dims[1],
drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer
)])
self.block2 = nn.ModuleList([KANBlock(
dim=embed_dims[2],
drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer
)])
self.dblock1 = nn.ModuleList([KANBlock(
dim=embed_dims[1],
drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer
)])
self.dblock2 = nn.ModuleList([KANBlock(
dim=embed_dims[0],
drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer
)])
self.patch_embed3 = PatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed4 = PatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.decoder1 = D_ConvLayer(embed_dims[2], embed_dims[1])
self.decoder2 = D_ConvLayer(embed_dims[1], embed_dims[0])
self.decoder3 = D_ConvLayer(embed_dims[0], embed_dims[0] // 4)
self.decoder4 = D_ConvLayer(embed_dims[0] // 4, embed_dims[0] // 8)
self.decoder5 = D_ConvLayer(embed_dims[0] // 8, embed_dims[0] // 8)
self.final = nn.Conv2d(embed_dims[0] // 8, num_classes, kernel_size=1)
self.soft = nn.Softmax(dim=1)
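        # NOTE: self.soft is defined but never applied in forward(); the network
        # returns raw logits, so training pairs it with BCEWithLogitsLoss and
        # inference applies sigmoid explicitly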
def forward(self, x):
B = x.shape[0]
### Encoder
### Conv Stage
### Stage 1
out = F.relu(F.max_pool2d(self.encoder1(x), 2, 2))
t1 = out
### Stage 2
out = F.relu(F.max_pool2d(self.encoder2(out), 2, 2))
t2 = out
### Stage 3
out = F.relu(F.max_pool2d(self.encoder3(out), 2, 2))
t3 = out
### Tokenized KAN Stage
### Stage 4
out, H, W = self.patch_embed3(out)
for i, blk in enumerate(self.block1):
out = blk(out, H, W)
out = self.norm3(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
t4 = out
### Bottleneck
out, H, W = self.patch_embed4(out)
for i, blk in enumerate(self.block2):
out = blk(out, H, W)
out = self.norm4(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
### Stage 4
out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2, 2), mode='bilinear'))
out = torch.add(out, t4)
_, _, H, W = out.shape
out = out.flatten(2).transpose(1, 2)
for i, blk in enumerate(self.dblock1):
out = blk(out, H, W)
### Stage 3
out = self.dnorm3(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
out = F.relu(F.interpolate(self.decoder2(out), scale_factor=(2, 2), mode='bilinear'))
out = torch.add(out, t3)
_, _, H, W = out.shape
out = out.flatten(2).transpose(1, 2)
for i, blk in enumerate(self.dblock2):
out = blk(out, H, W)
out = self.dnorm4(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
out = F.relu(F.interpolate(self.decoder3(out), scale_factor=(2, 2), mode='bilinear'))
out = torch.add(out, t2)
out = F.relu(F.interpolate(self.decoder4(out), scale_factor=(2, 2), mode='bilinear'))
out = torch.add(out, t1)
out = F.relu(F.interpolate(self.decoder5(out), scale_factor=(2, 2), mode='bilinear'))
return self.final(out)
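A quick shape check (my own sketch, not part of the original file). With img_size=512 the encoder halves the spatial resolution five times (three conv stages plus two patch embeddings, 512 down to 16), and the five decoder stages restore it:

import torch
model = UKAN(num_classes=1, input_channels=3, img_size=512)
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    print(model(x).shape)   # torch.Size([1, 1, 512, 512])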
6. utils.py
import argparse
import torch.nn as nn
class qkv_transform(nn.Conv1d):
"""Conv1d for qkv_transform"""
def str2bool(v):
    # compare against strings: v.lower() is a str and can never equal the integers 1 or 0
    if v.lower() in ['true', '1']:
        return True
    elif v.lower() in ['false', '0']:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
def count_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
7. train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import os
from PIL import Image
from torchvision import transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render CJK text in figures
plt.rcParams['axes.unicode_minus'] = False    # render minus signs correctly
from UKAN import UKAN
# Data loading and preprocessing
class DriveDataset(Dataset):
def __init__(self,image_dir,mask_dir,label_dir,transform=None,return_name=False):
self.image_dir=image_dir
self.mask_dir=mask_dir
self.label_dir=label_dir
self.transform=transform
self.images=os.listdir(image_dir)
self.return_name=return_name
def __len__(self):
return len(self.images)
def __getitem__(self,index):
img_path=os.path.join(self.image_dir,self.images[index])
mask_path=os.path.join(self.mask_dir,self.images[index].replace('.tif','_mask.gif'))
file_name=self.images[index].split('_')[0]
label_name=f"{file_name}_manual1.gif"
label_path=os.path.join(self.label_dir,label_name)
        image=Image.open(img_path).convert("RGB")   # source image as 3-channel RGB
        mask=Image.open(mask_path).convert("L")     # field-of-view mask as grayscale
        label=Image.open(label_path).convert("L")
        # to numpy arrays
        image=np.array(image)
        label=np.array(label)
        mask=np.array(mask)
        # keep only vessels inside the field of view
        label=label*(mask>0)
        # normalize
        image=image.astype(np.float32)/255.0
        label=(label>0).astype(np.float32)  # binarize
        # to PyTorch tensors
        image=torch.tensor(image).permute(2,0,1)
        label=torch.tensor(label).unsqueeze(0)
        if self.transform is not None:
            image=self.transform(image)
            # NOTE: Resize interpolates bilinearly by default, which softens the binary
            # label; nearest-neighbor interpolation would keep it strictly 0/1
            label=self.transform(label)
        if self.return_name:
            return image,label,self.images[index]
        else:
            return image,label
# Dice coefficient: Dice = 2|A∩B| / (|A| + |B|)
def dice_coeff(pred,target):
    # pred and target have shape (B, 1, H, W); sum over every non-batch dim
    intersection=(pred*target).sum(dim=(1,2,3))
    dice=(2.*intersection+1e-6)/(pred.sum(dim=(1,2,3))+target.sum(dim=(1,2,3))+1e-6)
    return dice.mean()
def train():
transform=transforms.Compose([
transforms.Resize((512,512)),
])
    # build the datasets (adjust these absolute paths to your own environment)
    train_image_path= "C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\DRIVE\\train\\images"
    train_mask_path= "C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\DRIVE\\train\\mask"
    train_label_path= "C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\DRIVE\\train\\1st_manual"
    test_image_path= "C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\DRIVE\\test\\images"
    test_mask_path= "C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\DRIVE\\test\\mask"
    test_label_path= "C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\DRIVE\\test\\1st_manual"
train_dataset=DriveDataset(train_image_path,train_mask_path,train_label_path,transform=transform)
test_dataset=DriveDataset(test_image_path,test_mask_path,test_label_path,transform=transform)
    # data loaders
batch_size=4
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=1,shuffle=False)
    # instantiate the U-KAN model
model=UKAN(num_classes=1,input_channels=3,img_size=512)
    # training hyperparameters
epochs=20
lr=0.0001
    # loss function and optimizer
    criterion=nn.BCEWithLogitsLoss()  # binary cross-entropy on raw logits
    optimizer=optim.Adam(model.parameters(),lr=lr)  # Adam optimizer
    # cosine-annealing learning-rate scheduler
scheduler=optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=epochs,
        eta_min=1e-5  # minimum learning rate
)
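    # CosineAnnealingLR: lr_t = eta_min + (lr - eta_min) * (1 + cos(pi * epoch / T_max)) / 2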
    # metric history
train_losses=[]
val_losses=[]
val_dices=[]
best_dice=0.0
print(f"训练样本:{len(train_dataset)},测试样本:{len(test_dataset)}")
print(f"批量大小:{batch_size},训练轮次:{epochs}")
for epoch in range(epochs):
model.train()
epoch_train_loss=0.0
        # training phase
for images,labels in train_loader:
            # forward pass
outputs=model(images)
loss=criterion(outputs,labels)
            # backward pass
            optimizer.zero_grad()
            loss.backward()  # loss is already averaged over the batch
optimizer.step()
epoch_train_loss+=loss.item()*images.size(0)
        # average training loss over the epoch
epoch_train_loss=epoch_train_loss/len(train_loader.dataset)
train_losses.append(epoch_train_loss)
        # validation phase
model.eval()
val_loss=0.0
total_dice=0.0
with torch.no_grad():
for images,labels in test_loader:
outputs=model(images)
loss=criterion(outputs,labels)
val_loss+=loss.item()*images.size(0)
                # Dice on thresholded predictions (outputs are logits: sigmoid first)
                preds=(torch.sigmoid(outputs)>0.5).float()
dice=dice_coeff(preds,labels)
total_dice+=dice.item()
        # average validation metrics
val_loss=val_loss/len(test_loader.dataset)
avg_dice=total_dice/len(test_loader)
val_losses.append(val_loss)
val_dices.append(avg_dice)
        # save the best model
if avg_dice>best_dice:
best_dice=avg_dice
torch.save(model.state_dict(), "unet_retina_best.pth")
print(f"保存最佳模型,Dice系数:{best_dice:.4f}")
print(f"Epoch [{epoch+1}/{epochs}] "
f"Train Loss:{epoch_train_loss:.4f},Val Loss:{val_loss:.4f},"
f"Val Dice:{avg_dice:.4f},LR:{scheduler.get_last_lr()[0]:.2e}")
    # plot the training and validation loss curves
plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
plt.plot(train_losses,label="Train Loss")
plt.plot(val_losses,label="Val Loss")
plt.title("训练集和验证集损失曲线")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.grid()
plt.legend()
    # Dice curve
plt.subplot(1,2,2)
plt.plot(val_dices,label="Val Dice")
plt.title("验证集的Dice系数曲线")
plt.xlabel("Epochs")
plt.ylabel("Dice")
plt.legend()
plt.grid()
plt.savefig("C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\UKAN\\Loss_Dice_curve.png")
plt.show()
    # final evaluation on the test set
model.eval()
total_dice=0.0
with torch.no_grad():
for images,labels in test_loader:
outputs=model(images)
            preds=(torch.sigmoid(outputs)>0.5).float()  # outputs are logits: sigmoid, then threshold
dice=dice_coeff(preds,labels)
total_dice+=dice.item()
print(f"测试集平均Dice系数:{total_dice/len(test_loader):.4f}")
#使用最佳模型进行预测并保存结果
predict_and_save(model,test_image_path,test_mask_path,test_label_path,transform)
return model
def predict_and_save(model,test_image_path,test_mask_path,test_label_path,transform):
    save_dir="C:\\Users\\Administrator\\PycharmProjects\\PythonProject2\\PythonProject\\Unet\\UKAN\\test_pred"
    os.makedirs(save_dir, exist_ok=True)  # make sure the output directory exists
    # load the best checkpoint
model.load_state_dict(torch.load("unet_retina_best.pth"))
model.eval()
    # test dataset that also returns file names
test_dataset=DriveDataset(test_image_path,test_mask_path,test_label_path,transform=transform,return_name=True)
test_loader=DataLoader(test_dataset,batch_size=1,shuffle=False)
with torch.no_grad():
for i,(images,labels,image_names) in enumerate(test_loader):
            # predict
outputs=model(images)
preds=(torch.sigmoid(outputs)>0.5).float()
            # to a 2-D numpy array
pred_np=preds.squeeze().cpu().numpy()
            # to a PIL image
pred_image=Image.fromarray((pred_np*255).astype(np.uint8))
            # build the output path
image_name=image_names[0].replace(".tif","_pred.png")
save_path=os.path.join(save_dir,image_name)
            # save the prediction
pred_image.save(save_path)
if __name__=="__main__":
model=train()
8. Results
Finally, running train.py produces:
Training samples: 20, test samples: 20
Batch size: 4, epochs: 20
Saved best model, Dice: 0.0630
Epoch [1/20] Train Loss:0.7768,Val Loss:0.7093,Val Dice:0.0630,LR:1.00e-04
Epoch [2/20] Train Loss:0.7220,Val Loss:0.7057,Val Dice:0.0630,LR:1.00e-04
Epoch [3/20] Train Loss:0.6849,Val Loss:0.6952,Val Dice:0.0630,LR:1.00e-04
Epoch [4/20] Train Loss:0.6539,Val Loss:0.6722,Val Dice:0.0630,LR:1.00e-04
Epoch [5/20] Train Loss:0.6274,Val Loss:0.6600,Val Dice:0.0630,LR:1.00e-04
Epoch [6/20] Train Loss:0.6060,Val Loss:0.6515,Val Dice:0.0630,LR:1.00e-04
Epoch [7/20] Train Loss:0.5901,Val Loss:0.6306,Val Dice:0.0630,LR:1.00e-04
Saved best model, Dice: 0.0630
Epoch [8/20] Train Loss:0.5755,Val Loss:0.5947,Val Dice:0.0630,LR:1.00e-04
Saved best model, Dice: 0.0637
Epoch [9/20] Train Loss:0.5643,Val Loss:0.5596,Val Dice:0.0637,LR:1.00e-04
Saved best model, Dice: 0.0659
Epoch [10/20] Train Loss:0.5537,Val Loss:0.5408,Val Dice:0.0659,LR:1.00e-04
Saved best model, Dice: 0.0757
Epoch [11/20] Train Loss:0.5546,Val Loss:0.5311,Val Dice:0.0757,LR:1.00e-04
Saved best model, Dice: 0.0843
Epoch [12/20] Train Loss:0.5448,Val Loss:0.5175,Val Dice:0.0843,LR:1.00e-04
Saved best model, Dice: 0.1166
Epoch [13/20] Train Loss:0.5370,Val Loss:0.5195,Val Dice:0.1166,LR:1.00e-04
Saved best model, Dice: 0.1839
Epoch [14/20] Train Loss:0.5286,Val Loss:0.5267,Val Dice:0.1839,LR:1.00e-04
Saved best model, Dice: 0.1891
Epoch [15/20] Train Loss:0.5229,Val Loss:0.5237,Val Dice:0.1891,LR:1.00e-04
Saved best model, Dice: 0.1962
Epoch [16/20] Train Loss:0.5190,Val Loss:0.5145,Val Dice:0.1962,LR:1.00e-04
Saved best model, Dice: 0.2030
Epoch [17/20] Train Loss:0.5120,Val Loss:0.5051,Val Dice:0.2030,LR:1.00e-04
Saved best model, Dice: 0.2138
Epoch [18/20] Train Loss:0.5071,Val Loss:0.5003,Val Dice:0.2138,LR:1.00e-04
Saved best model, Dice: 0.2666
Epoch [19/20] Train Loss:0.5035,Val Loss:0.5016,Val Dice:0.2666,LR:1.00e-04
Saved best model, Dice: 0.3075
Epoch [20/20] Train Loss:0.4979,Val Loss:0.5027,Val Dice:0.3075,LR:1.00e-04
Mean Dice on the test set: 0.3075
Process finished with exit code 0
The loss and Dice curves are shown below:
A sample predicted segmentation from the test set: