"> "> 计算机视觉-风格迁移 | Yufei Luo's Blog

计算机视觉-风格迁移

概述

图像风格迁移是一种用不同风格渲染图像语义内容的图像处理方法,简单地说,就是提供一幅表示风格的图片,然后在尽可能保留图像内容的情况下,将照片转化为该风格。传统非参数的风格迁移方法只能通过提取图像的低层次特征(色彩、纹理等)来进行纹理合成,无法提取图像高层的特征。非参数风格的方法包括基于笔划的渲染、图像类比、图像滤波、纹理合成等。

近年来,随着深度神经网络在其它视觉感知领域(例如物体与人脸识别)的广泛应用,其在风格迁移领域也具有良好的表现。神经网络具有良好的特征提取能力,能够提取到丰富的语义信息,而这些信息正是风格迁移的基础。

相关工作概述

早期使用神经网络做风格迁移的工作主要使用的是基于优化的方法。简单来说,就是输入一张随机噪音构成的底图,通过计算Style Loss和Content Loss,然后通过重复迭代来更改底图,使其风格纹理上与Style Image相似,内容上与原照片相似。Loss Network是一个已经训练好的网络,在迭代过程中锁定网络的参数,只更改底图的内容。这种方法的速度比较慢,需要迭代200-300次才能得到理想的结果。大致框架如下图所示:

preview

为了提高速度,在工作Perceptual Losses for Real-Time Style Transfer and Super-Resolution中,将风格迁移的过程使用一个自编码器来实现,大致框架如下:

preview

但上述方法仍有比较大的局限性:每个model只能进行一种Style image的迁移,每增加一张新的风格图需要训练一个新的模型。

工作StyleBank: An Explicit Representation for Neural Image Style Transfer可以在一个模型中实现多种风格迁移。它的思路是使用一个StyleBank参数层,保存每个风格所对应的参数,然后在风格迁移的过程中使用这些参数来控制图像风格,大致过程如下:

preview

如果需要添加新的风格,只需要在模型中添加一组新的参数,对应于新风格的参数。然后,便可以使用增量学习的方法对模型进行训练,即锁定其它参数,只对新风格的参数进行调整。

但是这一工作仍然无法实现任意风格的迁移,因此之后又有了一些相关工作,可以在网络训练完成之后实现任意的风格迁移。下文详细介绍其中的两个工作。

任意风格迁移

下面详细介绍两个能够进行任意风格迁移的工作。它们使用了不同的思路,在网络训练完成之后,便可以向网络中同时输入内容图和风格图,从而得到使用改风格渲染的内容图。

Universal Style Transfer via Feature Transforms

网络结构

image-20210729103637400

模型的整体结构如(a)所示,使用VGG网络作为Encoder,以及一个对称网络作为Decoder,然后训练这个网络以完成图像重建的任务。在DecoderX中,X​代表不同的解码器;而在VGG Relu_X_1中,X代表VGG网络不同层的输出。在训练完成之后,便固定它们的参数。

在此之后,作者提出了Single-level stylization和Multi-level stylization两种风格迁移的方式,分别如(b)和(c)所示。multi-level的方式可以获得更高的图片质量。其中,C代表内容图片,S代表风格图片,通过WCT(Whitening & Coloring Transform)操作将两种照片进行融合,从而实现风格迁移。

算法细节

Decoder

Decoder的结构被设计为VGG Encoder的对称结构,使用最近邻上采样的方式对特征图的尺寸进行扩大。为了评估VGG不同网络层输出的特征图对于重建的影响,作者取5个ReLU_X_1层的输出作为Encoder输出的特征图来训练Decoder。

训练过程中的损失函数为: \[ L=||I_o-I_i||_2^2+\lambda||\Phi(I_o)-\Phi(I_i)||_2^2 \] 其中,\(I_o\)\(I_i\)分别代表重建图像和输入图像,\(\Phi(\cdot)\)代表VGG编码器ReLU_X_1层的输出。\(\lambda\)为控制损失函数两部分的权重参数。

作者使用MS-COCO数据集来训练模型,训练过程中无需使用风格图片。

WCT

设内容图片C和风格图片S在经过Decode操作之后的输出分别为\(f_C\in \mathbb{R}^{C\times H_C\times W_C}\)\(f_S\in \mathbb{R}^{C\times H_C\times W_C}\).WCT操作的目的是对\(f_C\)进行转换,使其匹配\(f_S\)的协方差矩阵(即Gram矩阵)。WCT操作分为两步,即Whitening和Coloring transform。

在Whitening操作之前,需要对\(f_C\)做中心化操作,即对\(f_C\)减去它的平均向量\(m_C\)。之后做如下的线性变换: \[ \hat{f}_C=E_C D_C^{-\frac{1}{2}}E_C^Tf_C \] 其中,\(D_C\)\(E_C\)来源于协方差矩阵\(f_Cf_C^T\)的特征值分解:\(f_Cf_C^T=E_CD_CE_C^T\)

作者通过分析将Whitening后的特征输入到Decoder中得到的图片,说明Whitening操作的目的主要是保留图片中的内容特征(即图片包含的物体信息),而删除图片中的风格特征(即颜色等细节)

同样地,Coloring Transform操作同样需要先对\(f_S\)做中心化操作\(f_S=f_S-m_S\)。然后做如下的矩阵运算: \[ \hat{f}_{CS}=E_S D_S^{\frac{1}{2}}E_S^T \hat{f}_C \]

其中,\(D_S\)\(E_S\)来源于协方差矩阵\(f_Sf_S^T\)的特征值分解:\(f_Sf_S^T=E_S D_S E_S^T\)

通过这一变换,\(\hat{f}_{CS}\)与特征图满足如下关系:\(\hat{f}_{CS}\hat{f}_{CS}^T=f_S f_S^T\)

最后,需要再进行\(\hat{f}_{CS}=\hat{f}_{CS}+m_S\)这一运算,来完成重新中心化操作。

(Histogram matching)

在WCT操作之后,为了控制风格化程度,需要将\(\hat{f}_{CS}\)与内容\(f_C\)进行混合: \[ \hat{f}_{CS}=\alpha \hat{f}_{CS} +(1-\alpha)f_C \] \(\alpha\)​的值越低,代表越倾向保留原始图片的特征;而值越高则越倾向于将原始图片风格化。

小结

这一工作能够在不使用风格图片的基础上对网络进行训练,并在之后实现任意的风格迁移。但是需要训练5个不同大小的自编码器,使得模型的参数量较大。

Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization

Instance Normalization

在之前的风格迁移实现中,Ulyanov等人在BatchNorm的基础上提出了Instance Normalization,使得效果得到了极大提高。它的计算公式如下: \[ \text{IN}(x)=\gamma \left( \frac{x-\mu(x)}{\sigma(x)} \right)+\beta \] 但是与BatchNorm不同的是,Instance Norm则需要为每一个channel和每个sample都计算独立的\(\mu\)\(\sigma\)(BatchNorm只需要为每个channel计算即可): \[ \mu_{nc}(x)=\frac{1}{HW}\sum_{h=1}^H\sum_{w=1}^W x_{nchw} \\ \sigma_{nc}(x)=\sqrt{\frac{1}{HW}\sum_{h=1}^H\sum_{w=1}^W (x_{nchw}-\mu_{nc}(x))^2+\epsilon} \] 在此基础上,Dumoulin提出了Conditional Instance Normalization操作,为每个风格\(s\)设定独立的一组\(\gamma\)\(\beta\)的取值: \[ CIN(x;s)=\gamma^s \left( \frac{x-\mu(x)}{\sigma(x)} \right)+\beta^s \] 为了能够在不重新训练网络的前提下使用新的风格图片,作者提出了Adaptive Instance Normalization(AdaIN)操作: \[ AdaIN(x,y)=\sigma(y) \left( \frac{x-\mu(x)}{\sigma(x)} \right)+\mu(y) \]

算法细节

网络结构

image-20210729145457642

模型的整体结构比较简单,只包含了一个VGG Encoder和Decoder,并在二者中间加上了一个AdaIN层。Encoder取了VGG网络的前面几层(只到relu4_1),Decoder则取Encoder结构的镜像。为了避免棋盘效应,Decoder中所有的池化操作被替换为最近邻上采样。Decoder中如果使用normalization操作会影响性能,因此去掉了所有相关的层。

损失函数

模型训练的过程只需要调整Decoder的参数即可,损失函数为: \[ L=L_C+\lambda L_S \] 其中\(L_C\)代表内容损失,\(L_S\)代表风格损失,它们的定义如下: \[ L_C=||f(g(t))-t||_2 \\ L_S=\sum_{i=1}^{L}||\mu(\phi_i(g(t)))-\mu(\phi_i(s))||_2+\sum_{i=1}^{L}||\sigma(\phi_i(g(t)))-\sigma(\phi_i(s))||_2 \] 其中,\(f\)指的是Encoder,\(g\)指的是Decoder,\(t=AdaIN(f(c),f(s))\)\(\phi_i\)代表VGG-19中用于计算Style loss的那些网络层。

模型的训练过程使用MS-COCO作为内容图片,使用WikiArt作为风格图片进行训练。

小结

这一工作所使用的网络结构比较简单,但是训练过程中除了使用内容图片之外,还需要向网络中输入风格图片。但是在实际应用中,风格图片往往不如内容图片容易搜集。

程序示例

下面的程序是Universal Style Transfer via Feature Transforms这一工作的代码实现,由于整体结构中包含了5个大小不同的Autoencoder,而且要用COCO数据集进行训练,需要大量的计算资源。因此在下面的实现中,我们从https://github.com/pietrocarbo/deep-transfer中下载已经训练好的权重。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.linalg
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from models.autoencoder_vgg19.vgg19_1.vgg_normalised_conv1_1 import vgg_normalised_conv1_1
from models.autoencoder_vgg19.vgg19_2.vgg_normalised_conv2_1 import vgg_normalised_conv2_1
from models.autoencoder_vgg19.vgg19_3.vgg_normalised_conv3_1 import vgg_normalised_conv3_1
from models.autoencoder_vgg19.vgg19_4.vgg_normalised_conv4_1 import vgg_normalised_conv4_1
from models.autoencoder_vgg19.vgg19_5.vgg_normalised_conv5_1 import vgg_normalised_conv5_1

from models.autoencoder_vgg19.vgg19_1.feature_invertor_conv1_1 import feature_invertor_conv1_1
from models.autoencoder_vgg19.vgg19_2.feature_invertor_conv2_1 import feature_invertor_conv2_1
from models.autoencoder_vgg19.vgg19_3.feature_invertor_conv3_1 import feature_invertor_conv3_1
from models.autoencoder_vgg19.vgg19_4.feature_invertor_conv4_1 import feature_invertor_conv4_1
from models.autoencoder_vgg19.vgg19_5.feature_invertor_conv5_1 import feature_invertor_conv5_1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Encoder1(nn.Module):
def __init__(self):
super(Encoder1,self).__init__()
self.encoder=vgg_normalised_conv1_1
self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_1/vgg_normalised_conv1_1.pth'))
def forward(self,x):
return self.encoder(x)

class Decoder1(nn.Module):
def __init__(self):
super(Decoder1,self).__init__()
self.decoder=feature_invertor_conv1_1
self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_1/feature_invertor_conv1_1.pth'))
def forward(self,x):
return self.decoder(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Encoder2(nn.Module):
def __init__(self):
super(Encoder2,self).__init__()
self.encoder=vgg_normalised_conv2_1
self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_2/vgg_normalised_conv2_1.pth'))
def forward(self,x):
return self.encoder(x)

class Decoder2(nn.Module):
def __init__(self):
super(Decoder2,self).__init__()
self.decoder=feature_invertor_conv2_1
self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_2/feature_invertor_conv2_1.pth'))
def forward(self,x):
return self.decoder(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Encoder3(nn.Module):
def __init__(self):
super(Encoder3,self).__init__()
self.encoder=vgg_normalised_conv3_1
self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_3/vgg_normalised_conv3_1.pth'))
def forward(self,x):
return self.encoder(x)

class Decoder3(nn.Module):
def __init__(self):
super(Decoder3,self).__init__()
self.decoder=feature_invertor_conv3_1
self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_3/feature_invertor_conv3_1.pth'))
def forward(self,x):
return self.decoder(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Encoder4(nn.Module):
def __init__(self):
super(Encoder4,self).__init__()
self.encoder=vgg_normalised_conv4_1
self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_4/vgg_normalised_conv4_1.pth'))
def forward(self,x):
return self.encoder(x)

class Decoder4(nn.Module):
def __init__(self):
super(Decoder4,self).__init__()
self.decoder=feature_invertor_conv4_1
self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_4/feature_invertor_conv4_1.pth'))
def forward(self,x):
return self.decoder(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Encoder5(nn.Module):
def __init__(self):
super(Encoder5,self).__init__()
self.encoder=vgg_normalised_conv5_1
self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_5/vgg_normalised_conv5_1.pth'))
def forward(self,x):
return self.encoder(x)

class Decoder5(nn.Module):
def __init__(self):
super(Decoder5,self).__init__()
self.decoder=feature_invertor_conv5_1
self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_5/feature_invertor_conv5_1.pth'))
def forward(self,x):
return self.decoder(x)
1
2
Encoders=[Encoder1(),Encoder2(),Encoder3(),Encoder4(),Encoder5()]
Decoders=[Decoder1(),Decoder2(),Decoder3(),Decoder4(),Decoder5()]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class WCT(nn.Module):
def __init(self):
super(WCT,self).__init__()

def forward(self, f_c, f_s, alpha):
channels=f_c.shape[-3]
width=f_c.shape[-1]
height=f_c.shape[-2]

#Whitening transform
m_c=torch.mean(f_c,[-2,-1],keepdim=True)
f_c=f_c-m_c
f_c=torch.flatten(f_c,-2,-1)
gram_matrix_c=torch.bmm(f_c/width,f_c.transpose(-2,-1)/height) #此处除以tensor的大小是为了防止累加过程中出现数值上溢
d_c,e_c=torch.linalg.eigh(gram_matrix_c)
selected_channels=(d_c>0.00001).squeeze() #需要筛选掉过小的特征值,否则会出现数值不稳定的情况
d_c=1.0 / torch.sqrt(d_c[...,selected_channels])
d_c=d_c.diag_embed()
whitened=torch.matmul(torch.matmul(e_c[...,selected_channels],d_c),torch.matmul(e_c[...,selected_channels].transpose(-2,-1),f_c))

#Coloring transformation
m_s=torch.mean(f_s,[-2,-1],keepdim=True)
f_s=f_s-m_s
f_s=torch.flatten(f_s,-2,-1)
gram_matrix_s=torch.bmm(f_s/width,f_s.transpose(-2,-1)/height)
d_s,e_s=torch.linalg.eigh(gram_matrix_s)
selected_channels=(d_s>0.00001).squeeze()
d_s=torch.sqrt(d_s[...,selected_channels])
d_s=d_s.diag_embed()
colored=torch.matmul(torch.matmul(e_s[...,selected_channels],d_s),torch.matmul(e_s[...,selected_channels].transpose(-2,-1),whitened))
colored=colored+torch.flatten(m_s,-2,-1)

colored=alpha*colored+(1-alpha)*f_c


return colored.reshape(-1,channels,height,width)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class StyleTransferNetwork():
def __init__(self, encoders, decoders):
self.encoders=encoders
self.decoders=decoders
self.wct=WCT()

def transfer(self, content_img, style_img, alpha):
i=5
while i>0:
i-=1
encoded_c=self.encoders[i](content_img)
encoded_s=self.encoders[i](style_img)
content_img_encoded=self.wct(encoded_c, encoded_s, alpha)
content_img=self.decoders[i](content_img_encoded)
return content_img
1
transfernetwork=StyleTransferNetwork(Encoders, Decoders)
1
2
3
4
style_folder='./figs/styles/'
content_folder='./figs/content/'
style_figs=os.listdir(style_folder)
content_figs=os.listdir(content_folder)
1
2
3
4
5
6
7
8
9
10
11
style_dataset = ImageFolder(root=style_folder, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)),
]))
style_dataloader = DataLoader(style_dataset, batch_size=1)

content_dataset = ImageFolder(root=content_folder, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)),
]))
content_dataloader = DataLoader(content_dataset, batch_size=1)
1
2
3
4
5
res=[]
for content_fig, _ in content_dataloader:
for style_fig, _ in style_dataloader:
img=transfernetwork.transfer(content_fig, style_fig, 0.5)
res.append(img.detach())

在网络训练完成之后,我们使用了四张风格图与四张内容图,一共生成了16张风格迁移后的图像,如下图所示。其中,\(\alpha\)的值被设置为0.5。

image-20210801205439800

参考

  1. 风格迁移综述_Hygge�的博客-CSDN博客_风格迁移现状
  2. Style Transfer | 风格迁移综述 - 知乎 (zhihu.com)
  3. 图像风格迁移(Neural Style)简史 - 知乎 (zhihu.com)
  4. 如何用简单易懂的例子解释格拉姆矩阵/Gram matrix? - 知乎 (zhihu.com)
  5. https://arxiv.org/pdf/1705.08086.pdf
  6. https://arxiv.org/pdf/1703.06868.pdf
  7. https://arxiv.org/pdf/1603.08155.pdf