22,298
社区成员
发帖
与我相关
我的任务
分享
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import torch.nn.functional as F
class ConvResidualLayer(nn.Module):
    """Residual block: 1x1 projection, GroupNorm/ReLU, 3x3 conv, plus a skip connection.

    The 1x1 convolution both maps ``in_channels`` to ``out_channels`` and serves as
    the residual branch, so the skip connection works even when the channel counts
    differ.
    """

    def __init__(self, in_channels, out_channels):
        super(ConvResidualLayer, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same')
        self.gn1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')
        self.gn2 = nn.GroupNorm(8, out_channels)

    def forward(self, inputs):
        # NOTE: the original had a dead bare statement `inputs` here (a no-op); removed.
        residual = self.conv1(inputs)
        x = self.gn1(residual)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.gn2(x)
        out = x + residual
        # Scale the sum down (~sqrt(2)) to keep activation variance roughly stable
        # after adding the two branches.
        return out / 1.44
class SimpleDDPMModel(nn.Module):
    """Minimal DDPM for 1x32x32 images with a small U-Net-style noise predictor.

    ``pred_noisy`` implements the training side (noise the input, predict the
    noise, return the diffusion loss); ``forward`` runs the full reverse chain
    to sample an image from pure noise.
    """

    def __init__(self, max_time_step=100):
        super(SimpleDDPMModel, self).__init__()
        self.max_time_step = max_time_step
        # Linear beta schedule and derived cumulative quantities.
        betas = torch.linspace(1e-4, 0.02, max_time_step, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        betas_bar = 1.0 - alphas_bar
        # Register the schedule tensors as buffers so model.cuda()/model.to(device)
        # moves them along with the parameters. Plain tensor attributes stay on the
        # CPU and cause device-mismatch errors on GPU (the bug reported with this
        # code). Attribute access (self.betas etc.) is unchanged.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_bar", alphas_bar)
        self.register_buffer("betas_bar", betas_bar)
        filter_nums = [64, 128, 256]
        self.img_size = 32
        # Encoder: three conv-residual stages, each halving spatial size (32->16->8->4).
        self.encoders = nn.ModuleList([
            nn.Sequential(
                ConvResidualLayer(num_in, num_out),
                nn.MaxPool2d(2)
            ) for num_in, num_out in zip([1] + filter_nums[:-1], filter_nums)])
        self.mid_conv = ConvResidualLayer(filter_nums[-1], filter_nums[-1])
        # Decoder: transposed convs upsample back (4->8->16) with mirrored channels.
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(num_in, num_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvResidualLayer(num_out, num_out),
                ConvResidualLayer(num_out, num_out)
            ) for num_in, num_out in zip(filter_nums[::-1][:], filter_nums[::-1][1:])
        ])
        # Final decoder stage: upsample 16->32 and project back to one channel.
        self.decoders.append(
            nn.Sequential(
                nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvResidualLayer(64, 64),
                nn.Conv2d(64, 1, kernel_size=3, padding='same')
            )
        )
        self.final_conv = nn.Sequential(
            ConvResidualLayer(1, 64),
            nn.Conv2d(64, 1, kernel_size=3, padding='same')
        )
        # One small MLP per encoder stage mapping the scalar timestep to a
        # per-channel additive embedding.
        self.time_embeddings = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1, num),
                nn.LeakyReLU()
            ) for num in filter_nums])

    def extract(self, sources, t):
        """Gather schedule values at timesteps ``t`` and reshape to (bs, 1, 1, 1) for broadcasting."""
        bs = t.size(0)
        targets = [source[t] for source in sources]
        return tuple(map(lambda x: x.view(bs, 1, 1, 1), targets))

    def q_noisy_sample(self, x_0, t, noisy):
        """Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
        alpha_bar, beta_bar = self.extract([self.alphas_bar, self.betas_bar], t)
        sqrt_alpha_bar, sqrt_beta_bar = torch.sqrt(alpha_bar), torch.sqrt(beta_bar)
        return sqrt_alpha_bar * x_0 + sqrt_beta_bar * noisy

    def p_real_sample(self, x_t, t, pred_noisy):
        """One reverse-diffusion step: sample x_{t-1} from x_t and the predicted noise."""
        # Derive the batch size from the input; it was hard-coded to 16, which
        # broke sampling for any other batch size.
        bs = x_t.size(0)
        alpha, beta, beta_bar = self.extract([self.alphas, self.betas, self.betas_bar], t)
        noisy = torch.randn_like(x_t)
        noisy_weight = torch.sqrt(beta)
        # No extra noise is added on the final step (t == 0).
        noisy_mask = 1 - torch.eq(t, 0).float().view(bs, 1, 1, 1)
        noisy_weight = noisy_weight * noisy_mask
        x_t_1 = ((x_t - beta * pred_noisy / torch.sqrt(beta_bar)) / torch.sqrt(alpha) + noisy * noisy_weight)
        return x_t_1

    def encoder(self, noisy_img, t):
        """Run the downsampling path, injecting a time embedding after each stage.

        Returns the list of skip features and the last stage's output.
        """
        xs = []
        for idx, conv in enumerate(self.encoders):
            noisy_img = conv(noisy_img)
            t = t.float()
            time_embedding = self.time_embeddings[idx](t)
            time_embedding = time_embedding.view(-1, time_embedding.size(-1), 1, 1)
            noisy_img += time_embedding
            xs.append(noisy_img)
        return xs, noisy_img

    def decoder(self, noisy_img, xs):
        """Run the upsampling path, adding the (reversed) encoder skip features stage by stage."""
        xs.reverse()
        for idx, conv in enumerate(self.decoders):
            noisy_img = conv(noisy_img + xs[idx])
        return noisy_img

    def forward(self, inputs):
        """Sample images shaped like ``inputs`` by running the full reverse chain from noise."""
        bs = inputs.size(0)
        x_t = torch.randn_like(inputs)
        for i in reversed(range(0, self.max_time_step)):
            # Build the timestep tensor on the data's device so the time-embedding
            # layers and the schedule buffers see matching devices (the original
            # always built it on the CPU).
            t = torch.full((bs, 1), i, dtype=torch.long, device=inputs.device)
            p = self.pred_noisy({"img_data": x_t, "t": t})
            x_t = self.p_real_sample(x_t, t, p["pred_noisy"])
        return x_t

    def pred_noisy(self, data):
        """Predict the noise contained in ``data["img_data"]``.

        If ``data["t"]`` is absent (training), a random timestep is drawn per
        sample and the image is noised first; if present (sampling), the image
        is assumed to already be x_t. Returns the prediction, the target noise,
        and the scalar diffusion loss.
        """
        img = data["img_data"]
        bs = img.size(0)
        noisy = torch.randn_like(img)
        t = data.get("t", None)
        if t is None:
            # Draw timesteps on the image's device to avoid CPU/GPU mismatches.
            t = torch.randint(low=0, high=self.max_time_step, size=(bs, 1),
                              dtype=torch.long, device=img.device)
            noisy_img = self.q_noisy_sample(img, t, noisy)
        else:
            noisy_img = img
        xs, noisy_img = self.encoder(noisy_img, t)
        x = self.mid_conv(xs[-1])
        x = self.decoder(x, xs)
        pred_noisy = self.final_conv(x)
        return {
            "pred_noisy": pred_noisy,
            "noisy": noisy,
            # Per-sample squared error summed over pixels, averaged over the batch.
            "loss": torch.mean(torch.sum((pred_noisy - noisy) ** 2, dim=(1, 2, 3)), dim=-1)
        }
model = SimpleDDPMModel()

# Move the model — parameters AND registered schedule buffers — to the best
# available device. Using .to(device) instead of a hard-coded .cuda() keeps the
# script runnable on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Data loading: MNIST, resized to 32x32 and normalized to roughly [-1, 1].
data_path = "./data"
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)

# Training loop. DDPM is trained on the noise-prediction objective returned by
# pred_noisy() (Ho et al. 2020, Algorithm 1). The original script instead called
# model(inputs) — the full 100-step reverse sampling chain — and took an MSE
# between the sample and the input, which is both extremely slow and not a valid
# training signal for the diffusion model.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as tqdm_loader:
        for inputs, _ in tqdm_loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            # Noise each image at a random timestep, predict the noise, get the loss.
            result = model.pred_noisy({"img_data": inputs})
            loss = result["loss"]
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            tqdm_loader.set_postfix(loss=loss.item())
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss:.4f}")
print("Training finished!")
已经尝试过在 SimpleDDPMModel 的类定义里对所有张量变量调用 .cuda(),仍然不行。正确的做法是用 self.register_buffer() 注册 betas、alphas 等调度张量:普通的张量属性不会被 model.cuda() / model.to(device) 搬移,只有参数和已注册的 buffer 才会随模型一起迁移设备。
做 Python 人工智能开发最好配一块好显卡:常用的深度学习框架依赖 NVIDIA 的 CUDA,因此需要 NVIDIA 显卡加上版本配套的 CUDA 驱动和工具包;版本不匹配就会出现各种报错。