Contents
1. Principle
2. Code Implementation

1. Principle
Image style transfer is a technique that combines the content of one image with the style of another.
Style refers to the textures, colors, and visual patterns at different spatial scales in an image, while content refers to its high-level macrostructure.
The key concept behind style transfer is the same core idea shared by all deep learning algorithms: define a loss function that specifies the goal you want to achieve, then minimize that loss. Here the goal is known: preserve the content of the original image while adopting the style of the reference image.
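In practice, as in the implementation below, the two goals are combined into a single objective, total_loss = alpha * content_loss + beta * style_loss: the content loss compares the generated image's deep-layer activations with those of the content image, the style loss compares their Gram matrices (feature correlations) with those of the style image, and the weights alpha and beta set the content/style trade-off.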
In Python, this technique can be implemented with a deep-learning model. Neural style transfer can be built on top of any pretrained convolutional neural network; here we use the VGG19 network, the same one used by Gatys et al.
2. Code Implementation
The following is the implementation process of a simple VGG19-based image style transfer:
(1) Create a network that can simultaneously compute the VGG19 layer activations of the style-reference image, the target image, and the generated image.
(2) Use the layer activations computed on these three images to define the loss function described above (spelled out concretely right after this list); to achieve style transfer, this loss function must be minimized.
(3) Set up a gradient-descent process to minimize this loss function.
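Concretely, in the code below the content loss is the mean squared error between the generated image's and the content image's activations at layer conv4_2. The style representation of a layer is its Gram matrix: the feature map is reshaped into a matrix F of shape (d, h*w) and G = F·Fᵀ, which records the correlations between the d feature channels. The style loss then sums, over the layers conv1_1 through conv5_1, the per-layer weight times the mean squared error between the Gram matrices of the generated and style images, divided by d*h*w.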
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms, models

# Load the convolutional part (feature extractor) of a pretrained VGG19
vgg = models.vgg19(pretrained=True).features

# Freeze all VGG parameters; only the generated image will be optimized
for param in vgg.parameters():
    param.requires_grad_(False)

# Run on GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg.to(device)

def load_image(img_path, max_size=400):
    """Load an image, resize it if it is too large, and return it as a
    normalized tensor with a batch dimension."""
    image = Image.open(img_path).convert('RGB')  # ensure 3 channels
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)
    image_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # normalize with the ImageNet mean and std expected by VGG19
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    image = image_transform(image).unsqueeze(0)
    return image

# Load the content and style images and move them to the chosen device
content = load_image('dogs_and_cats.jpg').to(device)
style = load_image('picasso.jpg').to(device)

assert style.size() == content.size(), "The style image and the content image must have the same size"

plt.ion()

def imshow(tensor, title=None):
    """Undo the normalization and display a (1, C, H, W) tensor with matplotlib."""
    image = tensor.cpu().clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)  # keep pixel values in the displayable [0, 1] range
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.1)

plt.figure()
imshow(style, title='Style Image')

plt.figure()
imshow(content, title='Content Image')

def get_features(image, model, layers=None):
    """Run the image through the model and record activations at the chosen layers."""
    if layers is None:
        # Map torchvision VGG19 layer indices to the names used by Gatys et al.
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  # content representation
                  '28': 'conv5_1'}
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features

content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

def gram_matrix(tensor):
    """Compute the Gram matrix (channel-wise feature correlations) of a feature map."""
    _, d, h, w = tensor.size()
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())
    return gram

# Precompute the Gram matrices of the style image's features
style_grams = {}
for layer in style_features:
    style_grams[layer] = gram_matrix(style_features[layer])

import torch.nn.functional as F

def ContentLoss(target_features, content_features):
    # Mean squared error between target and content activations at conv4_2
    content_loss = F.mse_loss(target_features['conv4_2'], content_features['conv4_2'])
    return content_loss

def StyleLoss(target_features, style_grams, style_weights):
    # Weighted sum of Gram-matrix MSEs over the style layers
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        _, d, h, w = target_feature.shape
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * F.mse_loss(target_gram, style_gram)
        style_loss += layer_style_loss / (d * h * w)

    return style_loss

# Per-layer weights for the style loss (earlier layers weighted more heavily)
style_weights = {'conv1_1': 1.,
                 'conv2_1': 0.75,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

alpha = 1    # content weight
beta = 1e6   # style weight

show_every = 100  # display an intermediate result every 100 steps
steps = 2000      # number of optimization steps

# The generated image starts as a copy of the content image and is the only
# tensor being optimized
target = content.clone().requires_grad_(True).to(device)
optimizer = optim.Adam([target], lr=0.003)

for ii in range(1, steps + 1):
    target_features = get_features(target, vgg)
    content_loss = ContentLoss(target_features, content_features)
    style_loss = StyleLoss(target_features, style_grams, style_weights)
    total_loss = alpha * content_loss + beta * style_loss
    # Standard gradient-descent step on the pixels of the generated image
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        plt.figure()
        imshow(target)

# Show the final stylized result
plt.figure()
imshow(target, "Target Image")
plt.ioff()
plt.show()
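
If you also want to write the stylized result to disk, the following minimal sketch (not part of the original code; the filename 'stylized.png' is just an example) reuses the same un-normalization as imshow above and saves the image with matplotlib:

def save_result(tensor, path='stylized.png'):
    # Undo the VGG normalization, clamp to [0, 1], and save as an image file
    image = tensor.cpu().clone().detach().numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)
    plt.imsave(path, image)

save_result(target)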
That wraps up this walkthrough of implementing image style transfer with VGG19 in Python.