下面的函数将创建一个目录来保存结果 。
def make_dir():image_dir = 'MNIST_Out_Images'if not os.path.exists(image_dir):os.makedirs(image_dir)使用下面的函数,我们将保存模型生成的重建图像 。
def save_decod_img(img, epoch):img = img.view(img.size(0), 1, 28, 28)save_image(img, './MNIST_Out_Images/Autoencoder_image{}.png'.format(epoch))将调用下面的函数来训练模型 。
def training(model, train_loader, Epochs):train_loss = []for epoch in range(Epochs):running_loss = 0.0for data in train_loader:img, _ = dataimg = img.to(device)img = img.view(img.size(0), -1)optimizer.zero_grad()outputs = model(img)loss = criterion(outputs, img)loss.backward()optimizer.step()running_loss += loss.item()loss = running_loss / len(train_loader)train_loss.Append(loss)print('Epoch {} of {}, Train Loss: {:.3f}'.format(epoch+1, Epochs, loss))if epoch % 5 == 0:save_decod_img(outputs.cpu().data, epoch)return train_loss以下函数将对训练后的模型进行图像重建测试 。
def test_image_reconstruct(model, test_loader):for batch in test_loader:img, _ = batchimg = img.to(device)img = img.view(img.size(0), -1)outputs = model(img)outputs = outputs.view(outputs.size(0), 1, 28, 28).cpu().datasave_image(outputs, 'MNIST_reconstruction.png')break在训练之前,模型将被推送到CUDA环境中,并使用上面定义的函数创建目录来保存结果图像 。
device = get_device()model.to(device)make_dir()现在,将对模型进行训练 。
train_loss = training(model, train_loader, Epochs)

文章插图

文章插图
训练成功后,我们将在训练中可视化损失 。
plt.figure()plt.plot(train_loss)plt.title('Train Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.savefig('deep_ae_mnist_loss.png')
文章插图
我们将可视化训练期间保存的一些图像 。
Image.open('/content/MNIST_Out_Images/Autoencoder_image0.png')
文章插图
Image.open('/content/MNIST_Out_Images/Autoencoder_image50.png')
文章插图
Image.open('/content/MNIST_Out_Images/Autoencoder_image95.png')
文章插图
在最后一步,我们将测试我们的自编码器模型来重建图像 。
test_image_reconstruct(model, testloader)Image.open('/content/MNIST_reconstruction.png')
文章插图
所以,我们可以看到,自训练过程开始时,自编码器模型就开始重建图像 。第一个epoch以后,重建的质量不是很好,直到50 epoch后才得到改进 。
经过完整的训练,我们可以看到,在95 epoch以后生成的图像和测试中,它可以构造出与原始输入图像非常匹配的图像 。
我们根据loss值,可以知道epoch可以设置100或200 。
经过长时间的训练,有望获得更清晰的重建图像 。然而,通过这个演示,我们可以理解如何在PyTorch中实现用于图像重建的深度自编码器 。
参考文献:
- Sovit Ranjan Rath, “Implementing Deep Autoencoder in PyTorch”
- Abien Fred Agarap, “Implementing an Autoencoder in PyTorch”
- Reyhane Askari, “Auto Encoders”
【在PyTorch中使用深度自编码器实现图像重建】
推荐阅读
- 使用Pytorch和Matplotlib可视化卷积神经网络的特征
- 火山爆发啦 这座火山随时可能爆发
- 为什么说读书的人越来越少?
- 人脑中控制人平衡力的是什么?
- JavaScript的深拷贝实现
- 火星内部全是外星人 外星人不存在的证据
- 爱斯基摩人真的存在吗 爱摩斯基人还存在吗
- 月球发现外星生命 月球是人造卫星吗
- 无脊椎动物有哪几种 有哪些无脊椎动物
- 什么是几亿年前大量的低等生物经过长期复杂变化形成的 单细胞生物最早出现在地球上的时间
