1. VAE 代码板块

教主 站长
发表于图像生成分类

1. net.py

网络结构。为了模仿厂商的 load 方式,方便以后添加新模块。

from torch import nn
import torch

# 新增ResnetBlock模块
class ResnetBlock(nn.Module):
    """Residual block: two 3x3 convolutions with instance normalization.

    When ``downsample`` is True the first convolution uses stride 2, halving
    the spatial resolution. A 1x1 projection shortcut is used whenever the
    channel count or the resolution changes; otherwise the shortcut is the
    identity.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        first_stride = 2 if downsample else 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=first_stride, padding=1)
        self.bn1 = nn.InstanceNorm2d(out_channels, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU(inplace=True)

        # Project the shortcut only when the output shape differs from the input.
        needs_projection = downsample or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=first_stride),
                nn.InstanceNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + residual
        return self.relu(out)

class Encoder(nn.Module):
    """Configurable downsampling encoder.

    ``configs['down']`` is a list of ``(block_type, params)`` pairs:
    'ResnetBlock' entries get the running channel count prepended to their
    params, 'Downsample' maps to ``nn.MaxPool2d``, and any other name is
    looked up directly on ``torch.nn``.
    """

    def __init__(self, configs):
        super().__init__()
        modules = []
        channels = 3  # assumes RGB input — TODO confirm with caller

        for kind, args in configs['down']:
            if kind == 'ResnetBlock':
                modules.append(ResnetBlock(channels, *args))
                channels = args[0]
            elif kind == 'Downsample':
                modules.append(nn.MaxPool2d(*args))
            else:
                block = getattr(nn, kind)(*args)
                if 'Conv' in kind:
                    channels = args[1]  # conv args are (in, out, ...)
                modules.append(block)
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        return self.layer(x)

class Decoder(nn.Module):
    """Configurable upsampling decoder.

    ``configs['up']`` lists ``(block_type, params)`` pairs. An 'UpBlock'
    expands to a stride-2 transposed convolution + ReLU followed by a 3x3
    conv + BatchNorm + ReLU; any other name is looked up on ``torch.nn``.
    """

    def __init__(self, configs):
        super().__init__()
        modules = []
        channels = configs['init_channels']

        for kind, args in configs['up']:
            if kind == 'UpBlock':
                target = args[0]
                # Stride-2 transposed conv doubles the spatial resolution.
                modules.append(nn.ConvTranspose2d(
                    channels, target,
                    kernel_size=3, stride=2, padding=1, output_padding=1,
                ))
                modules.append(nn.ReLU())
                # Refine the upsampled features with a 3x3 convolution.
                modules.append(nn.Conv2d(target, target, 3, padding=1))
                modules.append(nn.BatchNorm2d(target))
                modules.append(nn.ReLU())
                channels = target
            else:
                modules.append(getattr(nn, kind)(*args))
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        return self.layer(x)

class VAE(nn.Module):
    """Convolutional VAE assembled from a config dict.

    ``forward`` returns ``(reconstruction, mu, log-variance)``. NOTE(review):
    the tensor called ``sigma`` throughout this class is treated as a
    log-variance by ``reparameterize`` — the name is misleading but kept for
    interface compatibility.
    """

    def __init__(self, configs):
        super().__init__()
        self.encoder = Encoder(configs)

        # 1x1 convs mapping encoder features to the latent moments.
        self.mu = nn.Conv2d(*configs['mid']['mu'])
        self.sigma = nn.Conv2d(*configs['mid']['sigma'])

        # The decoder starts from the latent channel count (mu's out_channels).
        self.decoder = Decoder({
            'init_channels': configs['mid']['mu'][1],
            'up': configs['up'],
        })

    def reparameterize(self, mu, sigmoid):
        # NOTE(review): `sigmoid` is actually the log-variance (logvar):
        # std = exp(0.5 * logvar); z = mu + eps * std, eps ~ N(0, I).
        std = torch.exp(0.5 * sigmoid)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        features = self.encoder(x)

        # Latent distribution parameters.
        mu = self.mu(features)
        sigma = self.sigma(features)  # log-variance, despite the name
        z = self.reparameterize(mu, sigma)

        return self.decoder(z), mu, sigma

if __name__ == '__main__':
    # Smoke test: build the model from a sample config and round-trip a
    # random image through it.
    configs = {
        'down': [
            ('ResnetBlock', [64, False]),   # (out_channels, downsample)
            ('ResnetBlock', [128, True]),
            ('ResnetBlock', [256, True]),
        ],
        'mid': {
            # Conv2d args: (in_channels, out_channels, kernel, stride, padding)
            'mu': [256, 48, 1, 1, 0],
            'sigma': [256, 48, 1, 1, 0],
        },
        'up': [
            ('UpBlock', [128]),  # upsample ×2, project to 128 channels
            ('Conv2d', [128, 128, 3, 1, 1]),
            ('UpBlock', [64]),
            ('Conv2d', [64, 64, 3, 1, 1]),
            ('Conv2d', [64, 3, 3, 1, 1]),
            ('Sigmoid', []),
        ],
    }

    vae = VAE(configs)
    x = torch.randn(1, 3, 256, 256)
    recon, mu, logvar = vae(x)
    print("Input shape:", x.shape)
    print("Recon shape:", recon.shape)  # should match the input shape

    # NOTE(review): this saves the random INPUT x, not the reconstruction —
    # presumably a quick visual sanity check; confirm intent.
    img = torch.clip(x[0].permute(1, 2, 0) * 255, 0, 255)
    img = img.numpy().astype('uint8')
    from PIL import Image
    Image.fromarray(img).save('1.png')

2. loss.py

包含两部分损失:一个是潜在分布的 KL 散度损失,一个是重建输出的 BCE 损失。

from torch import nn
from torch.nn import functional as F
import torch

class VAELoss(nn.Module):
    """Standard VAE objective: reconstruction BCE plus a KL penalty.

    Both terms use sum reduction, so the magnitude scales with the total
    number of pixels / latent elements in the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, recon_x, mu, logvar):
        # Clamp the log-variance so exp() cannot overflow.
        logvar = torch.clamp(logvar, min=-30, max=20)
        recon_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # KL(q(z|x) || N(0, I)) in closed form.
        kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_term + kl_term

3. config.py

模型配置文件,也可以改成 YAML 等格式,按需自行修改代码即可。

# Model configuration consumed by net.py:
#   'down' -> Encoder blocks as (block_type, params); ResnetBlock params are
#             (out_channels, downsample).
#   'mid'  -> Conv2d argument lists (in, out, kernel, stride, padding) for
#             the latent mu / sigma heads.
#   'up'   -> Decoder blocks; UpBlock takes (out_channels,) and doubles the
#             spatial resolution.
configs = {
    'down': [
        ('ResnetBlock', [64, False]),
        ('ResnetBlock', [128, True]),
        ('ResnetBlock', [256, True]),
    ],
    'mid': {
        'mu': [256, 48, 1, 1, 0],
        'sigma': [256, 48, 1, 1, 0]
    },
    'up': [
        ('UpBlock', [128]),
        ('Conv2d', [128, 128, 3, 1, 1]),
        ('UpBlock', [64]),
        ('Conv2d', [64, 64, 3, 1, 1]),
        ('Conv2d', [64, 3, 3, 1, 1]),
        ('Sigmoid', [])
    ]
}

4. train.py

反正我损失降不下去,我随便丢的图片进去训练的。

from loss import VAELoss, torch, nn
from net import VAE
from dataloader import create_dataloader
from config import configs
from tqdm import tqdm
import os
os.environ['CUDA_LAUNCH_BLOCKING']='1'

def train(model: nn.Module, loss_fn, opt, dataloader, epochs=10, device=None):
    """Train a VAE with gradient accumulation and per-epoch checkpoints.

    Args:
        model: module whose forward returns (recon, mu, logvar).
        loss_fn: callable(images, recon, mu, logvar) -> scalar loss.
        opt: optimizer over model.parameters().
        dataloader: yields image tensors (or (images, ...) tuples); must
            support len().
        epochs: number of passes over the data.
        device: target device; auto-detects CUDA when None.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    batch_nums = 20  # number of batches to accumulate per optimizer step

    # BUGFIX: clear any stale gradients before the first backward pass —
    # the original only zeroed after each step.
    opt.zero_grad()

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0

        epoch_bar = tqdm(enumerate(dataloader), total=len(dataloader),
                         desc=f"Epoch {epoch}/{epochs}", leave=True)

        for index, batch in epoch_bar:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)

            recon, mu, logvar = model(images)
            # Keep BCE inputs strictly inside (0, 1) without detaching.
            recon = torch.clamp(recon, min=1e-7, max=1.0 - 1e-5)
            images = torch.clamp(images, min=0.0, max=1.0)  # defensive
            loss = loss_fn(images, recon, mu, logvar)

            # BUGFIX: scale by the accumulation window so the accumulated
            # gradient matches a single large-batch step; the original
            # backpropagated the unscaled loss, inflating the effective
            # step by batch_nums.
            (loss / batch_nums).backward()

            if (index + 1) % batch_nums == 0 or (index + 1) == len(dataloader):
                opt.step()
                opt.zero_grad()

            # Log the unscaled per-batch loss for readability.
            running_loss += loss.item()
            epoch_bar.set_postfix(batch_loss=loss.item(),
                                  avg_loss=running_loss / (index + 1))

        print(f"[Epoch {epoch}] Avg Loss: {running_loss / len(dataloader):.4f}")
        torch.save(model.state_dict(), f"{epoch}.pth")

if __name__ == "__main__":
    # Wire up the loss, model, optimizer and data, then train on CUDA.
    loss_fn = VAELoss()
    vae = VAE(configs)
    # NOTE(review): lr=0.03 is very high for AdamW; with a sum-reduced VAE
    # loss a lr around 1e-4 is typical — likely why the loss plateaus.
    opt = torch.optim.AdamW(vae.parameters(), lr=0.03)
    train(vae, loss_fn, opt, create_dataloader("/home/xiongjiexing/inputs/LoraDataset"), 10, 'cuda')

5. dataloader

加载数据的模块。

from torch.utils.data import dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class VaeDataSet(dataset.Dataset):
    def __init__(self, path):
        super(VaeDataSet, self).__init__()
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        valid_file_format = ('.jpg', '.png')
        self.imgs = []
        for folder, _, list_filename in os.walk(path):
            for filename in list_filename:
                if not filename.endswith(valid_file_format):
                    continue
                self.imgs.append(os.path.join(folder, filename))

    def __getitem__(self, index):
        img = Image.open(self.imgs[index])
        w, h = img.size
        
        w = w //4 * 4
        h = h //4 * 4
        
        return self.transform(img.resize((w, h)))

    def __len__(self):
        return len(self.imgs)

def create_dataloader(path, batch_size=1):
    """Build a shuffled DataLoader over every image found under `path`."""
    image_dataset = VaeDataSet(path)
    return DataLoader(image_dataset, batch_size, shuffle=True)
评论列表
加载更多
登录 分类