Project: https://github.com/researchmm/AOT-GAN-for-Inpainting (implemented in PyTorch)
Paper: https://arxiv.org/abs/2104.01431
Released: 2021
Overview: AOT-GAN-for-Inpainting is an open-source image inpainting project. Experiments on Places2 (1.8 million images of 365 complex scenes) show that the model clearly outperforms state-of-the-art models in terms of FID, and a user study further validates the superiority of AOT-GAN. The authors also evaluate AOT-GAN in practical applications such as logo removal, face inpainting and object removal, and the results show that the model performs well on a wide range of real-world data.
Pretrained models: CELEBA-HQ | Places2
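1. Main contributions of the paper
1.1 Introduction
Current image inpainting methods can produce distorted structures and blurry textures on high-resolution images (e.g., 512x512). These challenges arise mainly from:
(1) reasoning about image content from distant contexts, and
(2) synthesizing fine-grained textures for large missing regions.
To overcome these two challenges, the authors propose an enhanced GAN-based model for high-resolution image inpainting, called aggregated contextual-transformation GAN (AOT-GAN). To enhance context reasoning, the generator of AOT-GAN is built by stacking multiple layers of the proposed AOT block. The AOT block aggregates contextual transformations from various receptive fields, which lets it capture both informative distant image contexts and rich patterns of interest for context reasoning. To improve texture synthesis, the discriminator of AOT-GAN is trained with a tailored mask-prediction task. This training objective forces the discriminator to distinguish the detailed appearance of real and synthesized patches, which in turn pushes the generator to synthesize clear textures.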
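1.2 AOT-block
The AOT block is one of the main contributions of the paper. The authors argue that a plain residual block cannot capture global information, so they propose an ASPP-like multi-scale dilated-convolution structure and combine it with the residual path through a learnable gate. The ASPP-like branch gathers multi-scale context, while the gate g is computed from the input by a 3x3 convolution followed by normalization and a sigmoid. Because g is learned through backpropagation, it decides at each position how much of the multi-scale context versus the original input features to keep: the output is x * (1 - g) + out * g. The structure also seems well suited to scene parsing.
The corresponding code is as follows: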
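import torch
import torch.nn as nn


class AOTBlock(nn.Module):
    def __init__(self, dim, rates):
        super(AOTBlock, self).__init__()
        self.rates = rates
        # One dilated 3x3 conv branch per rate; each outputs dim//4 channels,
        # so with 4 rates the concatenation has dim channels again.
        for i, rate in enumerate(rates):
            self.__setattr__(
                'block{}'.format(str(i).zfill(2)),
                nn.Sequential(
                    nn.ReflectionPad2d(rate),
                    nn.Conv2d(dim, dim//4, 3, padding=0, dilation=rate),
                    nn.ReLU(True)))
        # fuse: 3x3 conv that mixes the concatenated multi-scale features
        self.fuse = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, padding=0, dilation=1))
        # gate: 3x3 conv on the input that produces the spatial gate g
        self.gate = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, padding=0, dilation=1))

    def forward(self, x):
        out = [self.__getattr__(f'block{str(i).zfill(2)}')(x) for i in range(len(self.rates))]
        out = torch.cat(out, 1)
        out = self.fuse(out)
        # my_layer_norm is defined elsewhere in the project (not shown here)
        mask = my_layer_norm(self.gate(x))
        mask = torch.sigmoid(mask)
        # gated residual: keep x where the gate is ~0, take the fused features where it is ~1
        return x * (1 - mask) + out * mask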
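The following minimal, self-contained sketch shows the gated residual in isolation. It is an illustrative re-implementation, not the repository's class. The project's my_layer_norm is not shown in this article, so spatial_norm below is a hypothetical stand-in that standardizes each channel over H and W. MiniAOTBlock also uses an nn.ModuleList instead of __setattr__.

```python
import torch
import torch.nn as nn


def spatial_norm(feat, eps=1e-9):
    # Hypothetical stand-in for the project's my_layer_norm: standardize each
    # channel over its spatial positions (H, W) before the sigmoid.
    mean = feat.mean((2, 3), keepdim=True)
    std = feat.std((2, 3), keepdim=True) + eps
    return (feat - mean) / std


class MiniAOTBlock(nn.Module):
    """Minimal AOT-style block: multi-rate dilated branches plus a spatial gate."""

    def __init__(self, dim, rates=(1, 2, 4, 8)):
        super().__init__()
        branch_dim = dim // len(rates)
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.ReflectionPad2d(r),  # pad by r so a dilated 3x3 conv keeps H, W
                nn.Conv2d(dim, branch_dim, 3, dilation=r),
                nn.ReLU(inplace=True))
            for r in rates])
        self.fuse = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(branch_dim * len(rates), dim, 3))
        self.gate = nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, 3))

    def forward(self, x):
        out = self.fuse(torch.cat([b(x) for b in self.branches], dim=1))
        g = torch.sigmoid(spatial_norm(self.gate(x)))  # per-pixel, per-channel gate in (0, 1)
        return x * (1 - g) + out * g


torch.manual_seed(0)
block = MiniAOTBlock(dim=32)
x = torch.randn(2, 32, 16, 16)
y = block(x)
print(y.shape)  # torch.Size([2, 32, 16, 16])
```

Each branch pads by its own dilation rate, so all branches keep the input's H and W. This is what lets the block be stacked and gated against x.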
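1.3 SM-PatchGAN
The authors contrast PatchGAN, which treats the whole inpainted image as fake, with another possible design for the mask-prediction task, HM-PatchGAN (Fig. 4 in the paper). HM-PatchGAN enhances the PatchGAN discriminator by training it with hard binary patch masks, without Gaussian filtering. It accounts for the fact that part of the inpainted image is real, but it ignores the irregularity of the masks. A patch on the boundary of a hole contains both real and generated pixels, yet it receives a single hard label. As a result, patches labeled as generated that sit next to real regions inevitably contain some real pixels.
The authors suspect that this design weakens discriminator training. To avoid the problem, the proposed SM-PatchGAN softens the hard patch masks with Gaussian filtering, and extensive ablation studies in the paper demonstrate its superiority.
As the ablation results show, SM-PatchGAN yields a clear improvement in FID.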
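The sketch below contrasts hard and soft targets for the fake image. Like smgan, it computes them at full resolution, after the discriminator output has been resized to the mask size. The mask convention (1 = hole) follows smgan, where the blurred mask is the target for D on fake images. The kernel (ksize=7, sigma=2) is an assumption sized for a 32x32 toy mask; smgan itself uses ksize=71 with sigma=10.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel2d(ksize, sigma):
    # Separable Gaussian: outer product of a normalized 1D kernel with itself.
    x = torch.arange(ksize, dtype=torch.float32) - ksize // 2
    k1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
    k1d = k1d / k1d.sum()
    return torch.outer(k1d, k1d)


def fake_image_labels(mask, ksize=7, sigma=2.0):
    """mask: (B, 1, H, W), 1 = hole (generated pixels), 0 = known (real) pixels."""
    hard = mask  # HM-PatchGAN-style target: a hard 0/1 label per position
    kernel = gaussian_kernel2d(ksize, sigma).view(1, 1, ksize, ksize)
    # SM-PatchGAN-style target: the blurred mask, so positions near the hole
    # boundary get labels strictly between 0 and 1.
    soft = F.conv2d(mask, kernel, padding=ksize // 2)
    return hard, soft


mask = torch.zeros(1, 1, 32, 32)
mask[..., 8:24, 8:24] = 1.0  # a square hole in the middle of the image
hard, soft = fake_image_labels(mask)
print(soft[0, 0, 16, 16].item())  # hole centre: ~1.0
print(soft[0, 0, 8, 16].item())   # hole boundary: between 0 and 1
print(soft[0, 0, 0, 0].item())    # far from the hole: 0.0
```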
The Gaussian blur code is shown below; it is used in the smgan loss in loss.py.
import torch
import torch.nn as nn
from torch.nn.functional import conv2d


def gaussian(window_size, sigma):
    # 1D Gaussian window centered at window_size // 2, normalized to sum to 1
    def gauss_fcn(x):
        return -(x - window_size // 2)**2 / float(2 * sigma**2)
    gauss = torch.stack([torch.exp(torch.tensor(gauss_fcn(x)))
                         for x in range(window_size)])
    return gauss / gauss.sum()


def get_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
    r"""Function that returns Gaussian filter coefficients.
    Args:
        kernel_size (int): filter size. It should be odd and positive.
        sigma (float): gaussian standard deviation.
    Returns:
        Tensor: 1D tensor with gaussian filter coefficients.
    Shape:
        - Output: :math:`(\text{kernel_size})`
    Examples::
        >>> kornia.image.get_gaussian_kernel(3, 2.5)
        tensor([0.3243, 0.3513, 0.3243])
        >>> kornia.image.get_gaussian_kernel(5, 1.5)
        tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
    """
    if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
        raise TypeError(
            "kernel_size must be an odd positive integer. Got {}".format(kernel_size))
    window_1d: torch.Tensor = gaussian(kernel_size, sigma)
    return window_1d


def get_gaussian_kernel2d(kernel_size, sigma):
    r"""Function that returns Gaussian filter matrix coefficients.
    Args:
        kernel_size (Tuple[int, int]): filter sizes in the x and y direction.
            Sizes should be odd and positive.
        sigma (Tuple[int, int]): gaussian standard deviation in the x and y
            direction.
    Returns:
        Tensor: 2D tensor with gaussian filter matrix coefficients.
    Shape:
        - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
    Examples::
        >>> kornia.image.get_gaussian_kernel2d((3, 3), (1.5, 1.5))
        tensor([[0.0947, 0.1183, 0.0947],
                [0.1183, 0.1478, 0.1183],
                [0.0947, 0.1183, 0.0947]])
        >>> kornia.image.get_gaussian_kernel2d((3, 5), (1.5, 1.5))
        tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
                [0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
                [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
    """
    if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
        raise TypeError(
            "kernel_size must be a tuple of length two. Got {}".format(kernel_size))
    if not isinstance(sigma, tuple) or len(sigma) != 2:
        raise TypeError(
            "sigma must be a tuple of length two. Got {}".format(sigma))
    ksize_x, ksize_y = kernel_size
    sigma_x, sigma_y = sigma
    kernel_x: torch.Tensor = get_gaussian_kernel(ksize_x, sigma_x)
    kernel_y: torch.Tensor = get_gaussian_kernel(ksize_y, sigma_y)
    # separable kernel: outer product of the two 1D kernels
    kernel_2d: torch.Tensor = torch.matmul(
        kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
    return kernel_2d


class GaussianBlur(nn.Module):
    r"""Creates an operator that blurs a tensor using a Gaussian filter.
    The operator smooths the given tensor with a gaussian kernel by convolving
    it to each channel. It supports batched operation.
    Arguments:
        kernel_size (Tuple[int, int]): the size of the kernel.
        sigma (Tuple[float, float]): the standard deviation of the kernel.
    Returns:
        Tensor: the blurred tensor.
    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`
    Examples::
        >>> input = torch.rand(2, 4, 5, 5)
        >>> gauss = kornia.filters.GaussianBlur((3, 3), (1.5, 1.5))
        >>> output = gauss(input)  # 2x4x5x5
    """

    def __init__(self, kernel_size, sigma):
        super(GaussianBlur, self).__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self._padding = self.compute_zero_padding(kernel_size)
        self.kernel = get_gaussian_kernel2d(kernel_size, sigma)

    @staticmethod
    def compute_zero_padding(kernel_size):
        """Computes zero padding tuple."""
        computed = [(k - 1) // 2 for k in kernel_size]
        return computed[0], computed[1]

    def forward(self, x):  # type: ignore
        if not torch.is_tensor(x):
            raise TypeError(
                "Input x type is not a torch.Tensor. Got {}".format(type(x)))
        if not len(x.shape) == 4:
            raise ValueError(
                "Invalid input shape, we expect BxCxHxW. Got: {}".format(x.shape))
        # prepare kernel
        b, c, h, w = x.shape
        tmp_kernel: torch.Tensor = self.kernel.to(x.device).to(x.dtype)
        kernel: torch.Tensor = tmp_kernel.repeat(c, 1, 1, 1)

        # TODO: explore solution when using jit.trace since it raises a warning
        # because the shape is converted to a tensor instead to a int.
        # convolve tensor with gaussian kernel (groups=c: one kernel per channel)
        return conv2d(x, kernel, padding=self._padding, stride=1, groups=c)


######################
# functional interface
######################

def gaussian_blur(input, kernel_size, sigma):
    r"""Function that blurs a tensor using a Gaussian filter.
    See :class:`~kornia.filters.GaussianBlur` for details.
    """
    return GaussianBlur(kernel_size, sigma)(input)
2. Model architecture
2.1 Generator
The main model code is defined in src\model\aotgan.py:
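import torch
import torch.nn as nn

# BaseNetwork is defined in common.py (section 2.3), AOTBlock in section 1.2;
# UpConv is defined elsewhere in aotgan.py (not shown here).


class InpaintGenerator(BaseNetwork):
    def __init__(self, args):
        super(InpaintGenerator, self).__init__()

        # Encoder: 4 input channels (RGB image + 1-channel mask); the two
        # stride-2 convs reduce H and W by a factor of 4 in total.
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(4, 64, 7),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(True)
        )

        # Middle: args.block_num AOT blocks using the dilation rates in args.rates
        self.middle = nn.Sequential(*[AOTBlock(256, args.rates) for _ in range(args.block_num)])

        # Decoder: upsample back to the input resolution and predict RGB
        self.decoder = nn.Sequential(
            UpConv(256, 128),
            nn.ReLU(True),
            UpConv(128, 64),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1)
        )

        self.init_weights()

    def forward(self, x, mask):
        x = torch.cat([x, mask], dim=1)
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = torch.tanh(x)  # output in [-1, 1], the same range as the normalized inputs
        return x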
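In the corresponding network diagram (a figure in the original post), the green part is middle (the stack of AOT blocks), with the encoder and decoder on either side.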
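The shapes in the diagram can be checked with the standard Conv2d output-size formula. The sketch below traces a 512x512 input through the encoder, the AOT blocks and the decoder. UpConv is not shown in this article, so the trace assumes that each UpConv doubles H and W, which is what lets the output match the input resolution.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    # Output-size formula used by nn.Conv2d (floor division).
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def trace_generator(size=512):
    shapes = [('input (RGB + mask)', 4, size)]
    s = conv_out(size + 2 * 3, 7)  # ReflectionPad2d(3) + Conv2d(4, 64, 7)
    shapes.append(('enc conv7x7', 64, s))
    s = conv_out(s, 4, stride=2, padding=1)  # Conv2d(64, 128, 4, stride=2, padding=1)
    shapes.append(('enc down1', 128, s))
    s = conv_out(s, 4, stride=2, padding=1)  # Conv2d(128, 256, 4, stride=2, padding=1)
    shapes.append(('enc down2', 256, s))
    # AOT blocks: reflection padding equal to the dilation keeps H and W unchanged
    for rate in (1, 2, 4, 8):
        assert conv_out(s + 2 * rate, 3, dilation=rate) == s
    shapes.append(('middle (AOT blocks)', 256, s))
    # Decoder: assumes each UpConv doubles the resolution
    s = s * 2
    shapes.append(('dec up1', 128, s))
    s = s * 2
    shapes.append(('dec up2', 64, s))
    s = conv_out(s, 3, padding=1)  # Conv2d(64, 3, 3, stride=1, padding=1)
    shapes.append(('output (tanh)', 3, s))
    return shapes


for name, c, s in trace_generator(512):
    print(f'{name:22s} {c:4d} x {s} x {s}')
```

The middle blocks therefore operate on 256 x 128 x 128 feature maps for a 512x512 input.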
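2.2 Discriminator
Compared with the complex generator, the discriminator is fairly simple. Its notable feature is spectral_norm (for an introduction see https://zhuanlan.zhihu.com/p/63957812). spectral_norm is PyTorch's built-in spectral normalization (torch.nn.utils.spectral_norm). It divides a layer's weight by an estimate of the weight's largest singular value. In GANs this constrains the Lipschitz constant of the discriminator D, limiting how sharply its output can change and making training more stable. It is one of the most useful tools for training GANs.
In a GAN, the stronger the discriminator becomes, the more severely the generator's gradients vanish. A GAN therefore needs a simple, stable discriminator, and spectral_norm helps achieve this.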
import torch.nn as nn
from torch.nn.utils import spectral_norm

# BaseNetwork is defined in common.py (section 2.3).


class Discriminator(BaseNetwork):
    def __init__(self):
        super(Discriminator, self).__init__()
        inc = 3
        # PatchGAN-style: the output is a spatial map of logits, one per patch.
        # Spectral normalization wraps every conv except the last one.
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2d(inc, 64, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(256, 512, 4, stride=1, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, stride=1, padding=1)
        )

        self.init_weights()

    def forward(self, x):
        feat = self.conv(x)
        return feat
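The following sketch shows what spectral_norm does. In training mode, each forward pass runs a single power-iteration step by default. After enough passes, the normalized weight, viewed as an (out_channels, in_channels*k*k) matrix, has a largest singular value close to 1. The layer mirrors the discriminator's first conv. The 300 iterations are an arbitrary choice for this sketch.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
# Same configuration as the discriminator's first layer
conv = spectral_norm(nn.Conv2d(3, 64, 4, stride=2, padding=1, bias=False))
x = torch.randn(1, 3, 32, 32)

# In training mode (the default) every forward pass runs one power-iteration
# step, refining the estimate of the weight's largest singular value.
with torch.no_grad():
    for _ in range(300):
        conv(x)

# spectral_norm views the weight as an (out_channels, in_channels*k*k) matrix.
w_mat = conv.weight.detach().reshape(64, -1)
raw_mat = conv.weight_orig.detach().reshape(64, -1)
print('raw weight sigma:       ', torch.linalg.svdvals(raw_mat)[0].item())
print('normalized weight sigma:', torch.linalg.svdvals(w_mat)[0].item())  # close to 1.0
```

This bounds each normalized layer's weight matrix. Together with LeakyReLU (which is 1-Lipschitz), it limits how quickly D's output can change.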
2.3 common.py
This file holds nothing critical; it mainly implements model weight initialization.
import torch
import torch.nn as nn


class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).' % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)
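Here is a minimal sketch of the default 'normal' branch of init_func. It is applied with nn.Module.apply to a toy network built from the generator's first layers. normal_init is a hypothetical helper that mirrors only that branch; it is not part of the project.

```python
import torch
import torch.nn as nn


def normal_init(module, gain=0.02):
    # Mirrors the 'normal' branch of init_func: N(0, gain) weights and zero bias,
    # applied only to layers whose class name contains 'Conv' or 'Linear'.
    name = module.__class__.__name__
    if hasattr(module, 'weight') and ('Conv' in name or 'Linear' in name):
        nn.init.normal_(module.weight.data, 0.0, gain)
        if module.bias is not None:
            nn.init.constant_(module.bias.data, 0.0)


torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(4, 64, 7), nn.ReLU(True), nn.Conv2d(64, 128, 4, stride=2, padding=1))
net.apply(normal_init)  # apply() visits every submodule recursively, then the module itself
for m in net:
    if isinstance(m, nn.Conv2d):
        print(m.__class__.__name__, round(m.weight.std().item(), 4), m.bias.abs().max().item())
```

The printed standard deviations should be close to gain=0.02, and every bias should be exactly 0.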
3. Data loader
3.1 Pretrained models
The paper trains on three datasets, but only two pretrained models have been released; the logo-removal model was perhaps withheld for commercial reasons.
CELEBA-HQ | Places2
The paper describes the underlying datasets as follows:
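- Places2 [26] contains 1.8 million images from 365 scene categories. Because of its complex scenes, it is one of the most challenging datasets for image inpainting. Following the setting used by most inpainting models [13,17,21], we use the standard train/test split (i.e., 1.8M/36,500).
- CELEBA-HQ [50] is a high-quality face dataset. The high-frequency details of hair and skin help evaluate the model's fine-grained texture synthesis. Following the common setting [13,17], we use 28,000 images for training and 2,000 images for testing.
- QMUL-OpenLogo [51] contains 27,083 images from 352 logo classes. Each image has fine-grained bounding-box annotations of its logos. We use the 15,975 training images for training and the 2,777 validation images for testing.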
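3.2 Training data examples
For details see https://blog.csdn.net/qq_45790998/article/details/128741301. Judging from the data examples, use the CELEBA-HQ model for face inpainting and the Places2 model for general image inpainting.
CELEBA-HQ is a dataset of high-resolution face images with associated attribute labels. It contains 30,000 high-resolution (1024x1024) face images of more than 1,000 different celebrities.
Places2 is a large-scale scene dataset of about 10 million high-quality images covering more than 400 scene categories. The 365-category subset described above is the one used for training.
3.3 Dataloader code
The dataloader code is shown below. By default it uses pconv masks (an irregular-mask dataset stored as PNG images). For images without a mask dataset, set args.mask_type to any other value. The loader then generates a mask covering the central region of the image.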
import os
import math
import numpy as np
from glob import glob
from random import shuffle
from types import SimpleNamespace
from PIL import Image, ImageFilter

import torch
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


class InpaintingData(Dataset):
    def __init__(self, args):
        super(InpaintingData, self).__init__()
        self.w = self.h = args.image_size
        self.mask_type = args.mask_type

        # image and mask
        self.image_path = []
        for ext in ['*.jpg', '*.png']:
            self.image_path.extend(glob(os.path.join(args.dir_image, args.data_train, ext)))
        self.mask_path = glob(os.path.join(args.dir_mask, args.mask_type, '*.png'))

        # augmentation
        self.img_trans = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
            transforms.ToTensor()])
        # NEAREST interpolation keeps the mask binary
        self.mask_trans = transforms.Compose([
            transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(
                (0, 45), interpolation=transforms.InterpolationMode.NEAREST),
        ])

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, index):
        # load image
        image = Image.open(self.image_path[index]).convert('RGB')
        filename = os.path.basename(self.image_path[index])

        if self.mask_type == 'pconv':
            # pick a random mask from the mask dataset
            index = np.random.randint(0, len(self.mask_path))
            mask = Image.open(self.mask_path[index])
            mask = mask.convert('L')
        else:
            # square hole covering the central half of each side; filled with 255
            # because to_tensor divides uint8 images by 255 (a fill of 1 would give 1/255)
            mask = np.zeros((self.h, self.w)).astype(np.uint8)
            mask[self.h//4:self.h//4*3, self.w//4:self.w//4*3] = 255
            mask = Image.fromarray(mask).convert('L')

        # augment
        image = self.img_trans(image) * 2. - 1.  # [0, 1] -> [-1, 1]
        mask = F.to_tensor(self.mask_trans(mask))

        return image, mask, filename


if __name__ == '__main__':
    # SimpleNamespace replaces attrdict.AttrDict, which fails to import on Python 3.10+
    args = {
        'dir_image': '../../../dataset',
        'data_train': 'places2',
        'dir_mask': '../../../dataset',
        'mask_type': 'pconv',
        'image_size': 512
    }
    args = SimpleNamespace(**args)

    data = InpaintingData(args)
    print(len(data), len(data.mask_path))
    img, mask, filename = data[0]
    print(img.size(), mask.size(), filename)
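This NumPy-only sketch mirrors the non-pconv branch. It builds the centre mask and scales it the way torchvision's to_tensor scales a uint8 'L' image (dividing by 255). It then maps the image range from [0, 1] to [-1, 1]. It also shows why the fill value matters: a uint8 fill of 1 becomes 1/255 after scaling, so a fill of 255 is needed to get a mask of exactly 1. The helper names are hypothetical.

```python
import numpy as np


def center_mask(h, w, fill=255):
    # Same layout as the non-pconv branch: a hole covering the middle half of each side.
    mask = np.zeros((h, w), dtype=np.uint8)
    mask[h // 4:h // 4 * 3, w // 4:w // 4 * 3] = fill
    return mask


def to_unit_range(mask_u8):
    # What to_tensor does to a uint8 'L' image: convert to float and scale by 1/255.
    return mask_u8.astype(np.float32) / 255.0


def to_model_range(img01):
    # [0, 1] -> [-1, 1], matching the generator's tanh output range.
    return img01 * 2.0 - 1.0


m = to_unit_range(center_mask(512, 512))
print(m.max(), m.mean())  # max 1.0, mean 0.25 (the hole covers a quarter of the image)
print(to_unit_range(center_mask(512, 512, fill=1)).max())  # about 0.0039, not 1.0
```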
For this dataloader, you could also generate random polygon masks to enrich the training data, and switch to more complex transforms for data augmentation once training has stabilized.
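One way to generate random polygon masks, as suggested above, is sketched here. random_polygon_mask is a hypothetical helper, not part of the project. It places one vertex in each equal angular sector around a random centre, so the star-shaped outline never crosses itself, and draws it with PIL.ImageDraw. It returns a 0/255 'L' image, the same format as the pconv PNG masks, so the result can go through mask_trans unchanged.

```python
import numpy as np
from PIL import Image, ImageDraw


def random_polygon_mask(size=512, num_vertices=8, max_radius_ratio=0.3, rng=None):
    """Hypothetical helper: one random star-shaped polygon as a 0/255 'L' mask."""
    if num_vertices < 4:
        raise ValueError('need at least 4 vertices so every angular gap is below pi')
    rng = np.random.default_rng() if rng is None else rng
    cx, cy = rng.uniform(0.25, 0.75, size=2) * size
    # One jittered angle per sector keeps the vertices in angular order and every
    # gap between consecutive vertices below pi, so the polygon is simple.
    sector = 2 * np.pi / num_vertices
    angles = (np.arange(num_vertices) + rng.uniform(0, 1, num_vertices)) * sector
    radii = rng.uniform(0.3, 1.0, num_vertices) * max_radius_ratio * size
    points = [(float(cx + r * np.cos(a)), float(cy + r * np.sin(a))) for a, r in zip(angles, radii)]
    mask = Image.new('L', (size, size), 0)
    ImageDraw.Draw(mask).polygon(points, fill=255)
    return mask


arr = np.asarray(random_polygon_mask(rng=np.random.default_rng(0)))
print(arr.shape, np.unique(arr), (arr > 0).mean())  # shape, pixel values, hole ratio
```

Calling it several times and combining the results with np.maximum gives more irregular holes.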
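4. Loss implementation
4.1 Code
There are four loss terms: L1, Style, Perceptual and the adversarial loss Ladv. Ladv corresponds to the smgan class in the code, i.e., the SM-PatchGAN proposed by the authors (smgan is the default --gan_type in option.py). nsgan is a standard non-saturating GAN loss that does not use the masks.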
import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import VGG19, gaussian_blur


class L1():
    def __init__(self):
        self.calc = torch.nn.L1Loss()

    def __call__(self, x, y):
        return self.calc(x, y)


class Perceptual(nn.Module):
    def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
        super(Perceptual, self).__init__()
        self.vgg = VGG19().cuda()
        self.criterion = torch.nn.L1Loss()
        self.weights = weights

    def __call__(self, x, y):
        # weighted L1 distance between VGG19 features at relu1_1 ... relu5_1
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        prefix = [1, 2, 3, 4, 5]
        for i in range(5):
            content_loss += self.weights[i] * self.criterion(
                x_vgg[f'relu{prefix[i]}_1'], y_vgg[f'relu{prefix[i]}_1'])
        return content_loss


class Style(nn.Module):
    def __init__(self):
        super(Style, self).__init__()
        self.vgg = VGG19().cuda()
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        # Gram matrix of channel correlations, normalized by C*H*W
        b, c, h, w = x.size()
        f = x.view(b, c, w * h)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (h * w * c)
        return G

    def __call__(self, x, y):
        # L1 distance between Gram matrices at relu2_2, relu3_4, relu4_4, relu5_2
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        prefix = [2, 3, 4, 5]
        posfix = [2, 4, 4, 2]
        for pre, pos in list(zip(prefix, posfix)):
            style_loss += self.criterion(
                self.compute_gram(x_vgg[f'relu{pre}_{pos}']), self.compute_gram(y_vgg[f'relu{pre}_{pos}']))
        return style_loss


class nsgan():
    # Standard non-saturating GAN loss (does not use the masks)
    def __init__(self):
        self.loss_fn = torch.nn.Softplus()

    def __call__(self, netD, fake, real):
        fake_detach = fake.detach()
        d_fake = netD(fake_detach)
        d_real = netD(real)
        dis_loss = self.loss_fn(-d_real).mean() + self.loss_fn(d_fake).mean()

        g_fake = netD(fake)
        gen_loss = self.loss_fn(-g_fake).mean()

        return dis_loss, gen_loss


class smgan():
    # SM-PatchGAN: the target for D on fake images is a Gaussian-blurred mask
    def __init__(self, ksize=71):
        self.ksize = ksize
        self.loss_fn = nn.MSELoss()

    def __call__(self, netD, fake, real, masks):
        fake_detach = fake.detach()

        g_fake = netD(fake)
        d_fake = netD(fake_detach)
        d_real = netD(real)

        _, _, h, w = g_fake.size()
        b, c, ht, wt = masks.size()

        # Handle inconsistent size between outputs and masks
        if h != ht or w != wt:
            g_fake = F.interpolate(g_fake, size=(ht, wt), mode='bilinear', align_corners=True)
            d_fake = F.interpolate(d_fake, size=(ht, wt), mode='bilinear', align_corners=True)
            d_real = F.interpolate(d_real, size=(ht, wt), mode='bilinear', align_corners=True)

        # Targets: blurred mask (1 inside holes) for d_fake, 0 for d_real, 1 for g_fake
        d_fake_label = gaussian_blur(masks, (self.ksize, self.ksize), (10, 10)).detach().cuda()
        d_real_label = torch.zeros_like(d_real).cuda()
        g_fake_label = torch.ones_like(g_fake).cuda()

        dis_loss = self.loss_fn(d_fake, d_fake_label) + self.loss_fn(d_real, d_real_label)
        # Note: MSELoss already reduces to a scalar, so "* masks / torch.mean(masks)"
        # followed by .mean() returns that same scalar; a per-pixel loss
        # (reduction='none') would be needed for the mask to act as a weight.
        gen_loss = self.loss_fn(g_fake, g_fake_label) * masks / torch.mean(masks)

        return dis_loss.mean(), gen_loss.mean()
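Two facts about the losses above can be checked numerically. First, nsgan's softplus terms are exactly binary cross-entropy on logits: softplus(-x) = -log sigmoid(x) for the "real" target and softplus(x) = -log(1 - sigmoid(x)) for the "fake" target. Second, Style.compute_gram produces a symmetric C x C matrix per image. The gram function below is a standalone copy of that normalization, written for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 1, 8, 8)

# softplus(-d) equals BCE-with-logits against target 1 ...
ns_real = F.softplus(-logits).mean()
bce_real = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))
# ... and softplus(d) equals BCE-with-logits against target 0.
ns_fake = F.softplus(logits).mean()
bce_fake = F.binary_cross_entropy_with_logits(logits, torch.zeros_like(logits))
print(torch.allclose(ns_real, bce_real), torch.allclose(ns_fake, bce_fake))


def gram(x):
    # Same normalization as Style.compute_gram: divide by C*H*W.
    b, c, h, w = x.shape
    f = x.reshape(b, c, h * w)
    return f @ f.transpose(1, 2) / (c * h * w)


feat = torch.randn(2, 16, 10, 12)
G = gram(feat)
print(G.shape, torch.allclose(G, G.transpose(1, 2)))
```

Because the Gram matrix discards spatial positions, the style loss compares texture statistics rather than pixel layouts.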
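4.2 VGG19
The Perceptual and Style losses in 4.1 both use VGG19 to extract features before computing the loss. The following code in src\loss\common.py wraps VGG19 stage by stage and returns the activation after every conv layer (relu{stage}_{index}). prefix gives the stage and posfix the position of the conv within that stage.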
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.nn.functional import conv2d


class VGG19(nn.Module):
    def __init__(self, resize_input=False):
        super(VGG19, self).__init__()
        # ImageNet-pretrained VGG19 (newer torchvision prefers
        # weights=models.VGG19_Weights.IMAGENET1K_V1 over pretrained=True)
        features = models.vgg19(pretrained=True).features

        self.resize_input = resize_input
        # ImageNet mean and std used to normalize the input
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
        self.std = torch.Tensor([0.229, 0.224, 0.225]).cuda()
        # relu{prefix}_{posfix}: prefix = stage, posfix = conv index inside the stage
        prefix = [1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
        posfix = [1, 2, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
        names = list(zip(prefix, posfix))
        self.relus = []
        for pre, pos in names:
            self.relus.append('relu{}_{}'.format(pre, pos))
            self.__setattr__('relu{}_{}'.format(
                pre, pos), torch.nn.Sequential())

        # indices into vgg19().features for each slice; every slice ends at a ReLU
        nums = [[0, 1], [2, 3], [4, 5, 6], [7, 8],
                [9, 10, 11], [12, 13], [14, 15], [16, 17],
                [18, 19, 20], [21, 22], [23, 24], [25, 26],
                [27, 28, 29], [30, 31], [32, 33], [34, 35]]

        for i, layer in enumerate(self.relus):
            for num in nums[i]:
                self.__getattr__(layer).add_module(str(num), features[num])

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        # resize and normalize input for pretrained vgg19
        x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = (x - self.mean.view(1, 3, 1, 1)) / (self.std.view(1, 3, 1, 1))
        if self.resize_input:
            x = F.interpolate(
                x, size=(256, 256), mode='bilinear', align_corners=True)
        features = []
        for layer in self.relus:
            x = self.__getattr__(layer)(x)  # slices are applied one after another
            features.append(x)
        out = {key: value for (key, value) in list(zip(self.relus, features))}
        return out
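The hard-coded nums table can be derived from the standard VGG19 configuration. torchvision's vgg19().features is a flat list of Conv2d/ReLU pairs with a MaxPool2d after each stage. The sketch below computes the index of each relu{stage}_{pos} layer from that configuration without downloading any weights. Each group in nums is then simply the run of indices ending at that ReLU.

```python
# VGG19 configuration: numbers are conv output channels, 'M' is max-pooling.
VGG19_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
             512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']


def vgg19_relu_indices(cfg=VGG19_CFG):
    """Map 'relu{stage}_{pos}' to its index in torchvision's vgg19().features."""
    names, idx, stage, pos = {}, 0, 1, 1
    for v in cfg:
        if v == 'M':
            idx += 1  # MaxPool2d occupies one slot and ends the stage
            stage, pos = stage + 1, 1
        else:
            idx += 1  # Conv2d
            names[f'relu{stage}_{pos}'] = idx  # the ReLU right after the conv
            idx += 1
            pos += 1
    return names


relu_idx = vgg19_relu_indices()
print(relu_idx['relu1_1'], relu_idx['relu2_1'], relu_idx['relu5_4'])
```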
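5. Evaluation metrics
All evaluation-metric code is in src\metric\metric.py and covers MAE, PSNR, SSIM and FID. FID is the most complex. It involves the InceptionV3 model and three functions: calculate_activation_statistics, get_activations and calculate_frechet_distance.
A pattern worth learning from this code is the use of Pool.imap_unordered to process image pairs in parallel across worker processes (multiprocessing.Pool uses processes, not threads), combined with tqdm to display a progress bar: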
def compare_psnr(pairs):
    real, fake = pairs
    return peak_signal_noise_ratio(real, fake)


def psnr(reals, fakes, num_worker=8):
    error = 0
    pool = Pool(num_worker)
    for val in tqdm(pool.imap_unordered(compare_psnr, zip(reals, fakes)), total=len(reals), desc='compare_psnr'):
        error += val
    return error / len(reals)
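Here is the same imap_unordered pattern as a reusable sketch. It uses multiprocessing.pool.ThreadPool, which has the same API as multiprocessing.Pool, so it runs anywhere without pickling concerns. For CPU-bound metrics such as SSIM, passing pool_cls=multiprocessing.Pool inside an if __name__ == '__main__': guard gives real parallelism. Unlike the original, this version closes the pool with a with block. tqdm is left out to keep the example dependency-free; wrapping the iterator in tqdm(..., total=len(reals)) works as above. mean_metric is a hypothetical name.

```python
from multiprocessing.pool import ThreadPool

import numpy as np


def compare_mae(pairs):
    # Same normalized MAE as in metric.py: sum|real - fake| / sum(real + fake)
    real, fake = pairs
    real, fake = real.astype(np.float32), fake.astype(np.float32)
    return np.sum(np.abs(real - fake)) / np.sum(real + fake)


def mean_metric(fn, reals, fakes, num_worker=4, pool_cls=ThreadPool):
    """Hypothetical helper: average fn over (real, fake) pairs using a worker pool."""
    # imap_unordered yields each result as soon as its worker finishes; the order
    # does not matter because the results are only summed. The with block closes
    # the pool after every result has been consumed.
    with pool_cls(num_worker) as pool:
        total = sum(pool.imap_unordered(fn, zip(reals, fakes)))
    return total / len(reals)


rng = np.random.default_rng(0)
reals = [rng.integers(0, 256, (32, 32, 3), dtype=np.uint8) for _ in range(8)]
fakes = [np.clip(r.astype(np.int16) + rng.integers(-10, 11, r.shape), 0, 255).astype(np.uint8)
         for r in reals]
print(mean_metric(compare_mae, reals, reals))  # identical images give 0.0
print(mean_metric(compare_mae, reals, fakes))
```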
The full code is as follows:
import os
import pickle
import numpy as np
from tqdm import tqdm
from scipy import linalg
from multiprocessing import Pool
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio

import torch
from torch.autograd import Variable
from torch.nn.functional import adaptive_avg_pool2d

from .inception import InceptionV3


# ============================

def compare_mae(pairs):
    real, fake = pairs
    real, fake = real.astype(np.float32), fake.astype(np.float32)
    # normalized by sum(real + fake): a relative error rather than a plain mean
    return np.sum(np.abs(real - fake)) / np.sum(real + fake)


def compare_psnr(pairs):
    real, fake = pairs
    return peak_signal_noise_ratio(real, fake)


def compare_ssim(pairs):
    real, fake = pairs
    # newer scikit-image versions replace multichannel=True with channel_axis=-1
    return structural_similarity(real, fake, multichannel=True)


# ================================

def mae(reals, fakes, num_worker=8):
    error = 0
    pool = Pool(num_worker)
    for val in tqdm(pool.imap_unordered(compare_mae, zip(reals, fakes)), total=len(reals), desc='compare_mae'):
        error += val
    return error / len(reals)


def psnr(reals, fakes, num_worker=8):
    error = 0
    pool = Pool(num_worker)
    for val in tqdm(pool.imap_unordered(compare_psnr, zip(reals, fakes)), total=len(reals), desc='compare_psnr'):
        error += val
    return error / len(reals)


def ssim(reals, fakes, num_worker=8):
    error = 0
    pool = Pool(num_worker)
    for val in tqdm(pool.imap_unordered(compare_ssim, zip(reals, fakes)), total=len(reals), desc='compare_ssim'):
        error += val
    return error / len(reals)


def fid(reals, fakes, num_worker=8, real_fid_path=None):
    dims = 2048
    batch_size = 4
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx]).cuda()

    if real_fid_path is None:
        real_fid_path = 'places2_fid.pt'

    # statistics of the real images are cached in real_fid_path
    if os.path.isfile(real_fid_path):
        data = pickle.load(open(real_fid_path, 'rb'))
        real_m, real_s = data['mu'], data['sigma']
    else:
        reals = (np.array(reals).astype(np.float32) / 255.0).transpose((0, 3, 1, 2))
        real_m, real_s = calculate_activation_statistics(reals, model, batch_size, dims)
        with open(real_fid_path, 'wb') as f:
            pickle.dump({'mu': real_m, 'sigma': real_s}, f)

    # calculate fid statistics for fake images
    fakes = (np.array(fakes).astype(np.float32) / 255.0).transpose((0, 3, 1, 2))
    fake_m, fake_s = calculate_activation_statistics(fakes, model, batch_size, dims)

    fid_value = calculate_frechet_distance(real_m, real_s, fake_m, fake_s)
    return fid_value


def calculate_activation_statistics(images, model, batch_size=64,
                                    dims=2048, cuda=True, verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
    -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
                     must lie between 0 and 1.
    -- model       : Instance of inception model
    -- batch_size  : The images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size
                     depends on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU
    -- verbose     : If set to True and parameter out_step is given, the
                     number of calculated batches is reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(images, model, batch_size, dims, cuda, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def get_activations(images, model, batch_size=64, dims=2048, cuda=True, verbose=False):
    """Calculates the activations of the pool_3 layer for all images.
    Params:
    -- images      : Numpy array of dimension (n_images, 3, hi, wi). The values
                     must lie between 0 and 1.
    -- model       : Instance of inception model
    -- batch_size  : the images numpy array is split into batches with
                     batch size batch_size. A reasonable batch size depends
                     on the hardware.
    -- dims        : Dimensionality of features returned by Inception
    -- cuda        : If set to True, use GPU
    -- verbose     : If set to True and parameter out_step is given, the number
                     of calculated batches is reported.
    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    d0 = images.shape[0]
    if batch_size > d0:
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = d0

    n_batches = d0 // batch_size
    n_used_imgs = n_batches * batch_size

    pred_arr = np.empty((n_used_imgs, dims))
    for i in tqdm(range(n_batches), desc='calculate activations'):
        if verbose:
            print('\rPropagating batch %d/%d' %
                  (i + 1, n_batches), end='', flush=True)
        start = i * batch_size
        end = start + batch_size

        batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
        batch = Variable(batch)
        # call is_available(); the bare function object is always truthy
        if torch.cuda.is_available():
            batch = batch.cuda()
        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.shape[2] != 1 or pred.shape[3] != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)

    if verbose:
        print(' done')
    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on an
               representive data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on an
               representive data set.
    Returns:
    --   : The Frechet Distance.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)
    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'
    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)
    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
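A NumPy-only sketch of the Frechet distance makes it easy to check the formula on small inputs. Instead of scipy.linalg.sqrtm, it uses the identity Tr((S1 S2)^(1/2)) = sum of the square roots of the eigenvalues of S1 S2. This holds because S1 S2 has the same eigenvalues as the positive semi-definite matrix S1^(1/2) S2 S1^(1/2). The swap is this sketch's own choice, not the project's; for 2048-dimensional Inception features, the sqrtm route above with its eps fallback remains the reference.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """d^2 = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))."""
    diff = mu1 - mu2
    # Eigenvalues of S1 S2 are real and non-negative up to rounding; clip before sqrt.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


rng = np.random.default_rng(0)
a = rng.normal(size=(500, 8))
mu, sigma = a.mean(axis=0), np.cov(a, rowvar=False)
shift = np.full(8, 0.5)
print(frechet_distance(mu, sigma, mu, sigma))          # ~0 for identical statistics
print(frechet_distance(mu, sigma, mu + shift, sigma))  # ~||shift||^2 = 8 * 0.25 = 2.0
```

With diagonal covariances the trace term reduces to a sum over per-dimension standard deviations. For example, diag(1, 4) against diag(9, 16) gives (1 + 4 + 9 + 16) - 2 * (1*3 + 2*4) = 8.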
6. Using the project
6.1 Configuration file
The options used for training, validation and testing are defined in src\utils\option.py; edit the defaults there.
import argparse

parser = argparse.ArgumentParser(description='Image Inpainting')

# data specifications
parser.add_argument('--dir_image', type=str, default='../../dataset',
                    help='image dataset directory')
parser.add_argument('--dir_mask', type=str, default='../../dataset',
                    help='mask dataset directory')
parser.add_argument('--data_train', type=str, default='places2',
                    help='dataname used for training')
parser.add_argument('--data_test', type=str, default='places2',
                    help='dataname used for testing')
parser.add_argument('--image_size', type=int, default=512,
                    help='image size used during training')
parser.add_argument('--mask_type', type=str, default='pconv',
                    help='mask used during training')

# model specifications
parser.add_argument('--model', type=str, default='aotgan',
                    help='model name')
parser.add_argument('--block_num', type=int, default=8,
                    help='number of AOT blocks')
parser.add_argument('--rates', type=str, default='1+2+4+8',
                    help='dilation rates used in AOT block')
parser.add_argument('--gan_type', type=str, default='smgan',
                    help='discriminator types')

# hardware specifications
parser.add_argument('--seed', type=int, default=2021,
                    help='random seed')
parser.add_argument('--num_workers', type=int, default=4,
                    help='number of workers used in data loader')

# optimization specifications
parser.add_argument('--lrg', type=float, default=1e-4,
                    help='learning rate for generator')
parser.add_argument('--lrd', type=float, default=1e-4,
                    help='learning rate for discriminator')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--beta1', type=float, default=0.5,
                    help='beta1 in optimizer')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='beta2 in optimizer')

# loss specifications
parser.add_argument('--rec_loss', type=str, default='1*L1+250*Style+0.1*Perceptual',
                    help='losses for reconstruction')
parser.add_argument('--adv_weight', type=float, default=0.01,
                    help='loss weight for adversarial loss')

# training specifications
parser.add_argument('--iterations', type=int, default=1e6,
                    help='the number of iterations for training')
parser.add_argument('--batch_size', type=int, default=8,
                    help='batch size in each mini-batch')
parser.add_argument('--port', type=int, default=22334,
                    help='tcp port for distributed training')
parser.add_argument('--resume', action='store_true',
                    help='resume from previous iteration')

# log specifications
parser.add_argument('--print_every', type=int, default=10,
                    help='frequency for updating progress bar')
# note: the default 1e4 stays a float (10000.0) unless given on the command line
parser.add_argument('--save_every', type=int, default=1e4,
                    help='frequency for saving models')
parser.add_argument('--save_dir', type=str, default='../experiments',
                    help='directory for saving models and logs')
parser.add_argument('--tensorboard', action='store_true',
                    help='default: false, since it will slow training. use it for debugging')

# test and demo specifications
parser.add_argument('--pre_train', type=str, default=None,
                    help='path to pretrained models')
parser.add_argument('--outputs', type=str, default='../outputs',
                    help='path to save results')
parser.add_argument('--thick', type=int, default=15,
                    help='the thick of pen for free-form drawing')
parser.add_argument('--painter', default='freeform', choices=('freeform', 'bbox'),
                    help='different painters for demo ')

# ----------------------------------
args = parser.parse_args()
# argparse does not apply type=int to non-string defaults, so 1e6 stays a float; cast it
args.iterations = int(args.iterations)

# '1+2+4+8' -> [1, 2, 4, 8]
args.rates = list(map(int, list(args.rates.split('+'))))

# '1*L1+250*Style+0.1*Perceptual' -> {'L1': 1.0, 'Style': 250.0, 'Perceptual': 0.1}
losses = list(args.rec_loss.split('+'))
args.rec_loss = {}
for l in losses:
    weight, name = l.split('*')
    args.rec_loss[name] = float(weight)
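The string parsing at the end of option.py can be exercised on its own. parse_rec_loss and parse_rates reproduce that logic as functions. total_generator_loss is a hypothetical helper showing how such weights would typically combine into one generator objective: a weighted sum of the reconstruction terms plus adv_weight times the adversarial term. The actual trainer code is not shown in this article.

```python
def parse_rec_loss(spec):
    # '1*L1+250*Style+0.1*Perceptual' -> {'L1': 1.0, 'Style': 250.0, 'Perceptual': 0.1}
    weights = {}
    for term in spec.split('+'):
        weight, name = term.split('*')
        weights[name] = float(weight)
    return weights


def parse_rates(spec):
    # '1+2+4+8' -> [1, 2, 4, 8]: one dilation rate per AOT-block branch
    return [int(r) for r in spec.split('+')]


def total_generator_loss(rec_terms, rec_weights, adv_loss, adv_weight=0.01):
    """Hypothetical helper: weighted reconstruction terms plus the weighted adversarial term."""
    rec = sum(rec_weights[name] * value for name, value in rec_terms.items())
    return rec + adv_weight * adv_loss


w = parse_rec_loss('1*L1+250*Style+0.1*Perceptual')
print(w, parse_rates('1+2+4+8'))
print(total_generator_loss({'L1': 0.2, 'Style': 0.001, 'Perceptual': 1.5}, w, adv_loss=0.7))
```

The large Style weight (250) compensates for the small magnitude of Gram-matrix differences, which are normalized by C*H*W.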
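6.2 Training, validation and testing
The training, validation and testing scripts are in the src directory. Since the released models already perform well, they are not examined in depth here.
Follow the official README to run them.
6.3 Inpainting images with the demo
Download the generator (G) model released by the authors from https://drive.google.com/drive/folders/1bSOH-2nB3feFRyDEmiX81CEiWkghss3i (as shown in the figure in the original post) and put it in the src directory.
Create a test_data directory under src and put your own test images (.jpg or .png) in it.
Modify demo.py as follows: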
if __name__ == '__main__':
    args.pre_train = "src/G0000000.pt"
    args.dir_image = "src/test_data"
    args.painter = "bbox"  # 'freeform' or 'bbox'
    demo(args)
freeform means free-hand drawing and bbox means drawing rectangles. Hold down the mouse button to draw in the input window. Press Space to run inpainting, r to clear the mask and redraw, n to go to the next image, and s to save the result.