1. VAE 代码板块
1. net.py
网络结构。为了模仿大厂的 load 方式(按配置构建网络),方便以后添加模块。
from torch import nn
import torch
class ResnetBlock(nn.Module):
    """Residual block: two 3x3 convs with InstanceNorm, plus a 1x1 projection
    shortcut whenever the output shape differs from the input (channel change
    or spatial downsampling)."""

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        first_stride = 2 if downsample else 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=first_stride, padding=1)
        self.bn1 = nn.InstanceNorm2d(out_channels, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU(inplace=True)
        needs_projection = downsample or in_channels != out_channels
        if needs_projection:
            # 1x1 conv matches both the channel count and (via stride) the
            # spatial resolution of the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=first_stride),
                nn.InstanceNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + residual
        return self.relu(out)
class Encoder(nn.Module):
    """Downsampling stack built from a config dict.

    ``configs['down']`` is a list of ``(block_type, params)`` pairs:
    'ResnetBlock' (params = [out_channels, downsample]), 'Downsample'
    (max-pooling, params forwarded to nn.MaxPool2d), or any ``torch.nn``
    layer name with positional params. Input is assumed to be 3-channel.
    """

    def __init__(self, configs):
        super().__init__()
        modules = []
        channels = 3  # default input channel count (RGB)
        for kind, args in configs['down']:
            if kind == 'ResnetBlock':
                modules.append(ResnetBlock(channels, *args))
                channels = args[0]
            elif kind == 'Downsample':
                modules.append(nn.MaxPool2d(*args))
            else:
                # Anything else is resolved directly from torch.nn.
                built = getattr(nn, kind)(*args)
                if 'Conv' in kind:
                    # Conv positional args are (in, out, ...): track the new width.
                    channels = args[1]
                modules.append(built)
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        return self.layer(x)
class Decoder(nn.Module):
    """Upsampling stack built from a config dict.

    ``configs['init_channels']`` is the latent channel count; each
    ``('UpBlock', [out_channels])`` entry doubles the spatial resolution via a
    stride-2 transposed conv followed by a 3x3 conv + BatchNorm + ReLU.
    Any other entry is a ``torch.nn`` layer name with positional params.
    """

    def __init__(self, configs):
        super().__init__()
        modules = []
        channels = configs['init_channels']
        for kind, args in configs['up']:
            if kind == 'UpBlock':
                target = args[0]
                modules += [
                    # stride-2 transposed conv: H, W -> 2H, 2W
                    nn.ConvTranspose2d(channels, target, kernel_size=3,
                                       stride=2, padding=1, output_padding=1),
                    nn.ReLU(),
                    # 3x3 refinement conv at the new resolution
                    nn.Conv2d(target, target, 3, padding=1),
                    nn.BatchNorm2d(target),
                    nn.ReLU(),
                ]
                channels = target
            else:
                modules.append(getattr(nn, kind)(*args))
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        return self.layer(x)
class VAE(nn.Module):
    """Convolutional VAE: Encoder -> 1x1 convs for (mu, log-variance) ->
    reparameterized sample -> Decoder.

    forward returns ``(reconstruction, mu, logvar)``.
    """

    def __init__(self, configs):
        super().__init__()
        self.encoder = Encoder(configs)
        # 1x1 convs projecting encoder features to the latent statistics.
        self.mu = nn.Conv2d(*configs['mid']['mu'])
        self.sigma = nn.Conv2d(*configs['mid']['sigma'])
        # Decoder starts from the latent channel count (out-channels of `mu`).
        self.decoder = Decoder({
            'init_channels': configs['mid']['mu'][1],
            'up': configs['up'],
        })

    def reparameterize(self, mu, sigmoid):
        # NOTE(review): despite its name, `sigmoid` is treated as the
        # log-variance here (and by VAELoss) — confirm and consider renaming.
        std = torch.exp(0.5 * sigmoid)
        noise = torch.randn_like(std)
        return std * noise + mu

    def forward(self, x):
        features = self.encoder(x)
        mu = self.mu(features)
        sigma = self.sigma(features)
        latent = self.reparameterize(mu, sigma)
        return self.decoder(latent), mu, sigma
if __name__ == '__main__':
    # Example of the config structure consumed by VAE (see config.py).
    configs = {
        'down': [
            ('ResnetBlock', [64, False]),  # (out_channels, downsample)
            ('ResnetBlock', [128, True]),
            ('ResnetBlock', [256, True]),
        ],
        'mid': {
            'mu': [256, 48, 1, 1, 0],     # nn.Conv2d positional args
            'sigma': [256, 48, 1, 1, 0]   # nn.Conv2d positional args
        },
        'up': [
            ('UpBlock', [128]),            # upsample block: (out_channels,)
            ('Conv2d', [128, 128, 3, 1, 1]),
            ('UpBlock', [64]),
            ('Conv2d', [64, 64, 3, 1, 1]),
            ('Conv2d', [64, 3, 3, 1, 1]),
            ('Sigmoid', [])
        ]
    }
    vae = VAE(configs)
    x = torch.randn(1, 3, 256, 256)
    recon, mu, logvar = vae(x)
    print("Input shape:", x.shape)
    print("Recon shape:", recon.shape)  # should match the input shape
    # NOTE(review): this saves the random *input* x, not the reconstruction —
    # confirm whether `recon[0].detach()` was intended here.
    img = torch.clip(x[0].permute(1, 2, 0) * 255, 0, 255)
    img = img.numpy().astype('uint8')
    from PIL import Image
    Image.fromarray(img).save('1.png')
    # print("Mu shape:", mu.shape)  # [2, 48, H, W]
2. loss.py
两个损失嘛,一个是计算中间分布的损失,一个是计算输出的损失。
from torch import nn
from torch.nn import functional as F
import torch
class VAELoss(nn.Module):
    """Standard VAE objective: reconstruction BCE + KL divergence to N(0, I).

    Expects ``recon_x`` in (0, 1) (the model ends with Sigmoid) and
    ``logvar`` = log(sigma^2). Both terms use sum reduction, so the loss
    magnitude scales with image size.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, recon_x, mu, logvar):
        # Clamp log-variance so exp() below cannot overflow/underflow.
        logvar = torch.clamp(logvar, min=-30, max=20)
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # KL(N(mu, sigma^2) || N(0, 1)), summed over all latent elements.
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD
3. config.py
模型配置文件,可以改成 yaml 啊之类的东西,反正自己改代码吧。
configs = {
    # Encoder: three residual stages, 3 -> 64 -> 128 -> 256 channels; the
    # last two downsample by 2, so input H and W must be divisible by 4.
    'down': [
        ('ResnetBlock', [64, False]),   # (out_channels, downsample)
        ('ResnetBlock', [128, True]),
        ('ResnetBlock', [256, True]),
    ],
    # 1x1 convs producing the 48-channel latent mean and log-variance
    # (nn.Conv2d positional args: in, out, kernel, stride, padding).
    'mid': {
        'mu': [256, 48, 1, 1, 0],
        'sigma': [256, 48, 1, 1, 0]
    },
    # Decoder: two UpBlocks (x2 upsampling each) back to a 3-channel
    # output squashed into [0, 1] by the final Sigmoid.
    'up': [
        ('UpBlock', [128]),
        ('Conv2d', [128, 128, 3, 1, 1]),
        ('UpBlock', [64]),
        ('Conv2d', [64, 64, 3, 1, 1]),
        ('Conv2d', [64, 3, 3, 1, 1]),
        ('Sigmoid', [])
    ]
}
4. train.py
反正我损失降不下去,我随便丢的图片进去训练的。
from loss import VAELoss, torch, nn
from net import VAE
from dataloader import create_dataloader
from config import configs
from tqdm import tqdm
import os
# Debug aid: make CUDA errors surface at the offending call instead of
# asynchronously later (slows GPU execution; remove for production runs).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
def train(model: nn.Module, loss_fn, opt, dataloader, epochs=10, device=None):
    """Train ``model`` with gradient accumulation.

    Args:
        model: module returning ``(recon, mu, logvar)`` for an image batch.
        loss_fn: callable ``(images, recon, mu, logvar) -> scalar loss``.
        opt: optimizer over ``model.parameters()``.
        dataloader: yields image tensors (or ``(images, ...)`` tuples).
        epochs: number of passes over ``dataloader``.
        device: target device; auto-detected (cuda if available) when None.

    Side effects: saves a checkpoint ``"<epoch>.pth"`` after every epoch.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    accum_steps = 20  # batches to accumulate before each optimizer step
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        # BUGFIX: clear any stale gradients before the first accumulation
        # window (the original never zeroed before the first backward()).
        opt.zero_grad()
        epoch_bar = tqdm(enumerate(dataloader), total=len(dataloader),
                         desc=f"Epoch {epoch}/{epochs}", leave=True)
        for index, batch in epoch_bar:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)
            recon, mu, logvar = model(images)
            # Keep BCE inputs strictly inside (0, 1) while preserving gradients.
            recon = torch.clamp(recon, min=1e-7, max=1.0 - 1e-5)
            images = torch.clamp(images, min=0.0, max=1.0)  # defensive
            loss = loss_fn(images, recon, mu, logvar)
            # BUGFIX: scale by the accumulation window so the effective step
            # uses the *average* gradient, not a 20x-larger sum.
            (loss / accum_steps).backward()
            if (index + 1) % accum_steps == 0 or (index + 1) == len(dataloader):
                opt.step()
                opt.zero_grad()
            running_loss += loss.item()  # report the unscaled per-batch loss
            epoch_bar.set_postfix(batch_loss=loss.item(),
                                  avg_loss=running_loss / (index + 1))
        print(f"[Epoch {epoch}] Avg Loss: {running_loss / len(dataloader):.4f}")
        torch.save(model.state_dict(), f"{epoch}.pth")
if __name__ == "__main__":
loss_fn = VAELoss()
vae = VAE(configs)
opt = torch.optim.AdamW(vae.parameters(), lr=0.03)
train(vae, loss_fn, opt, create_dataloader("/home/xiongjiexing/inputs/LoraDataset"), 10, 'cuda')5. dataloader
加载数据的模块。
from torch.utils.data import dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class VaeDataSet(dataset.Dataset):
    """Recursively collects .jpg/.png files under ``path`` and yields each as
    a 3-channel float tensor in [0, 1], resized so both sides are multiples
    of 4 (the encoder downsamples twice, so H and W must divide by 4).
    """

    # Accepted image extensions (matched case-insensitively).
    VALID_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, path):
        super(VaeDataSet, self).__init__()
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.imgs = []
        for folder, _, filenames in os.walk(path):
            for filename in filenames:
                # BUGFIX: case-insensitive match so '.JPG'/'.PNG' are not skipped.
                if filename.lower().endswith(self.VALID_EXTENSIONS):
                    self.imgs.append(os.path.join(folder, filename))

    def __getitem__(self, index):
        # BUGFIX: force RGB — grayscale/RGBA/palette files would otherwise
        # yield tensors whose channel count breaks the 3-channel model.
        img = Image.open(self.imgs[index]).convert('RGB')
        w, h = img.size
        # Round each side down to a multiple of 4, but never below 4 so
        # tiny images don't produce a zero-sized resize.
        w = max(4, w // 4 * 4)
        h = max(4, h // 4 * 4)
        return self.transform(img.resize((w, h)))

    def __len__(self):
        return len(self.imgs)
def create_dataloader(path, batch_size=1, shuffle=True, num_workers=0):
    """Build a DataLoader over ``VaeDataSet(path)``.

    Args:
        path: root directory searched recursively for images.
        batch_size: images per batch. Default 1 — images may differ in size,
            so larger batches require uniformly sized inputs.
        shuffle: reshuffle every epoch (default True, matching the original).
        num_workers: DataLoader worker processes (new, backward-compatible).
    """
    return DataLoader(VaeDataSet(path), batch_size=batch_size,
                      shuffle=shuffle, num_workers=num_workers)
评论列表
${content}