概述
图像风格迁移是一种用不同风格渲染图像语义内容的图像处理方法,简单地说,就是提供一幅表示风格的图片,然后在尽可能保留图像内容的情况下,将照片转化为该风格。传统非参数的风格迁移方法只能通过提取图像的低层次特征(色彩、纹理等)来进行纹理合成,无法提取图像高层的特征。非参数风格的方法包括基于笔划的渲染、图像类比、图像滤波、纹理合成等。
近年来,随着深度神经网络在其它视觉感知领域(例如物体与人脸识别)的广泛应用,其在风格迁移领域也具有良好的表现。神经网络具有良好的特征提取能力,能够提取到丰富的语义信息,而这些信息正是风格迁移的基础。
相关工作概述
早期使用神经网络做风格迁移的工作主要使用的是基于优化的方法。简单来说,就是输入一张随机噪音构成的底图,通过计算Style Loss和Content Loss,然后通过重复迭代来更改底图,使其风格纹理上与Style Image相似,内容上与原照片相似。Loss Network是一个已经训练好的网络,在迭代过程中锁定网络的参数,只更改底图的内容。这种方法的速度比较慢,需要迭代200-300次才能得到理想的结果。大致框架如下图所示:
为了提高速度,在工作Perceptual Losses for Real-Time Style Transfer and Super-Resolution 中,将风格迁移的过程使用一个自编码器来实现,大致框架如下:
但上述方法仍有比较大的局限性:每个model只能进行一种Style image的迁移,每增加一张新的风格图需要训练一个新的模型。
工作StyleBank: An Explicit Representation for Neural Image Style Transfer 可以在一个模型中实现多种风格迁移。它的思路是使用一个StyleBank参数层,保存每个风格所对应的参数,然后在风格迁移的过程中使用这些参数来控制图像风格,大致过程如下:
如果需要添加新的风格,只需要在模型中添加一组新的参数,对应于新风格的参数。然后,便可以使用增量学习的方法对模型进行训练,即锁定其它参数,只对新风格的参数进行调整。
但是这一工作仍然无法实现任意风格的迁移,因此之后又有了一些相关工作,可以在网络训练完成之后实现任意的风格迁移。下文详细介绍其中的两个工作。
任意风格迁移
下面详细介绍两个能够进行任意风格迁移的工作。它们使用了不同的思路,在网络训练完成之后,便可以向网络中同时输入内容图和风格图,从而得到使用改风格渲染的内容图。
网络结构
模型的整体结构如(a)所示,使用VGG网络作为Encoder,以及一个对称网络作为Decoder,然后训练这个网络以完成图像重建的任务。在DecoderX中,X代表不同的解码器;而在VGG Relu_X_1中,X代表VGG网络不同层的输出。在训练完成之后,便固定它们的参数。
在此之后,作者提出了Single-level stylization和Multi-level stylization两种风格迁移的方式,分别如(b)和(c)所示。multi-level的方式可以获得更高的图片质量。其中,C代表内容图片,S代表风格图片,通过WCT(Whitening & Coloring Transform)操作将两种照片进行融合,从而实现风格迁移。
算法细节
Decoder
Decoder的结构被设计为VGG Encoder的对称结构,使用最近邻上采样的方式对特征图的尺寸进行扩大。为了评估VGG不同网络层输出的特征图对于重建的影响,作者取5个ReLU_X_1层的输出作为Encoder输出的特征图来训练Decoder。
训练过程中的损失函数为: \[
L=||I_o-I_i||_2^2+\lambda||\Phi(I_o)-\Phi(I_i)||_2^2
\] 其中,\(I_o\) 和\(I_i\) 分别代表重建图像和输入图像,\(\Phi(\cdot)\) 代表VGG编码器ReLU_X_1层的输出。\(\lambda\) 为控制损失函数两部分的权重参数。
作者使用MS-COCO数据集来训练模型,训练过程中无需使用风格图片。
WCT
设内容图片C和风格图片S在经过Decode操作之后的输出分别为\(f_C\in \mathbb{R}^{C\times H_C\times W_C}\) 和\(f_S\in \mathbb{R}^{C\times H_C\times W_C}\) .WCT操作的目的是对\(f_C\) 进行转换,使其匹配\(f_S\) 的协方差矩阵(即Gram矩阵)。WCT操作分为两步,即Whitening和Coloring transform。
在Whitening操作之前,需要对\(f_C\) 做中心化操作,即对\(f_C\) 减去它的平均向量\(m_C\) 。之后做如下的线性变换: \[
\hat{f}_C=E_C D_C^{-\frac{1}{2}}E_C^Tf_C
\] 其中,\(D_C\) 和\(E_C\) 来源于协方差矩阵\(f_Cf_C^T\) 的特征值分解:\(f_Cf_C^T=E_CD_CE_C^T\) 。
作者通过分析将Whitening后的特征输入到Decoder中得到的图片,说明Whitening操作的目的主要是保留图片中的内容特征(即图片包含的物体信息),而删除图片中的风格特征(即颜色等细节)
同样地,Coloring Transform操作同样需要先对\(f_S\) 做中心化操作\(f_S=f_S-m_S\) 。然后做如下的矩阵运算: \[
\hat{f}_{CS}=E_S D_S^{\frac{1}{2}}E_S^T \hat{f}_C
\]
其中,\(D_S\) 和\(E_S\) 来源于协方差矩阵\(f_Sf_S^T\) 的特征值分解:\(f_Sf_S^T=E_S D_S E_S^T\) 。
通过这一变换,\(\hat{f}_{CS}\) 与特征图满足如下关系:\(\hat{f}_{CS}\hat{f}_{CS}^T=f_S f_S^T\) 。
最后,需要再进行\(\hat{f}_{CS}=\hat{f}_{CS}+m_S\) 这一运算,来完成重新中心化操作。
(Histogram matching)
在WCT操作之后,为了控制风格化程度,需要将\(\hat{f}_{CS}\) 与内容\(f_C\) 进行混合: \[
\hat{f}_{CS}=\alpha \hat{f}_{CS} +(1-\alpha)f_C
\] \(\alpha\) 的值越低,代表越倾向保留原始图片的特征;而值越高则越倾向于将原始图片风格化。
小结
这一工作能够在不使用风格图片的基础上对网络进行训练,并在之后实现任意的风格迁移。但是需要训练5个不同大小的自编码器,使得模型的参数量较大。
Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization
Instance Normalization
在之前的风格迁移实现中,Ulyanov等人在BatchNorm的基础上提出了Instance Normalization,使得效果得到了极大提高。它的计算公式如下: \[
\text{IN}(x)=\gamma \left( \frac{x-\mu(x)}{\sigma(x)} \right)+\beta
\] 但是与BatchNorm不同的是,Instance Norm则需要为每一个channel和每个sample都计算独立的\(\mu\) 和\(\sigma\) (BatchNorm只需要为每个channel计算即可): \[
\mu_{nc}(x)=\frac{1}{HW}\sum_{h=1}^H\sum_{w=1}^W x_{nchw} \\
\sigma_{nc}(x)=\sqrt{\frac{1}{HW}\sum_{h=1}^H\sum_{w=1}^W (x_{nchw}-\mu_{nc}(x))^2+\epsilon}
\] 在此基础上,Dumoulin提出了Conditional Instance Normalization操作,为每个风格\(s\) 设定独立的一组\(\gamma\) 和\(\beta\) 的取值: \[
CIN(x;s)=\gamma^s \left( \frac{x-\mu(x)}{\sigma(x)} \right)+\beta^s
\] 为了能够在不重新训练网络的前提下使用新的风格图片,作者提出了Adaptive Instance Normalization(AdaIN)操作: \[
AdaIN(x,y)=\sigma(y) \left( \frac{x-\mu(x)}{\sigma(x)} \right)+\mu(y)
\]
算法细节
网络结构
模型的整体结构比较简单,只包含了一个VGG Encoder和Decoder,并在二者中间加上了一个AdaIN层。Encoder取了VGG网络的前面几层(只到relu4_1),Decoder则取Encoder结构的镜像。为了避免棋盘效应,Decoder中所有的池化操作被替换为最近邻上采样。Decoder中如果使用normalization操作会影响性能,因此去掉了所有相关的层。
损失函数
模型训练的过程只需要调整Decoder的参数即可,损失函数为: \[
L=L_C+\lambda L_S
\] 其中\(L_C\) 代表内容损失,\(L_S\) 代表风格损失,它们的定义如下: \[
L_C=||f(g(t))-t||_2 \\
L_S=\sum_{i=1}^{L}||\mu(\phi_i(g(t)))-\mu(\phi_i(s))||_2+\sum_{i=1}^{L}||\sigma(\phi_i(g(t)))-\sigma(\phi_i(s))||_2
\] 其中,\(f\) 指的是Encoder,\(g\) 指的是Decoder,\(t=AdaIN(f(c),f(s))\) ,\(\phi_i\) 代表VGG-19中用于计算Style loss的那些网络层。
模型的训练过程使用MS-COCO作为内容图片,使用WikiArt作为风格图片进行训练。
小结
这一工作所使用的网络结构比较简单,但是训练过程中除了使用内容图片之外,还需要向网络中输入风格图片。但是在实际应用中,风格图片往往不如内容图片容易搜集。
程序示例
下面的程序是Universal Style Transfer via Feature Transforms这一工作的代码实现,由于整体结构中包含了5个大小不同的Autoencoder,而且要用COCO数据集进行训练,需要大量的计算资源。因此在下面的实现中,我们从https://github.com/pietrocarbo/deep-transfer中下载已经训练好的权重。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 import osimport numpy as npfrom PIL import Imageimport matplotlib.pyplot as pltimport torchimport torchvisionimport torch.nn as nnimport torch.linalgfrom torchvision.datasets import ImageFolderfrom torch.utils.data import DataLoaderfrom models.autoencoder_vgg19.vgg19_1.vgg_normalised_conv1_1 import vgg_normalised_conv1_1from models.autoencoder_vgg19.vgg19_2.vgg_normalised_conv2_1 import vgg_normalised_conv2_1from models.autoencoder_vgg19.vgg19_3.vgg_normalised_conv3_1 import vgg_normalised_conv3_1from models.autoencoder_vgg19.vgg19_4.vgg_normalised_conv4_1 import vgg_normalised_conv4_1from models.autoencoder_vgg19.vgg19_5.vgg_normalised_conv5_1 import vgg_normalised_conv5_1from models.autoencoder_vgg19.vgg19_1.feature_invertor_conv1_1 import feature_invertor_conv1_1from models.autoencoder_vgg19.vgg19_2.feature_invertor_conv2_1 import feature_invertor_conv2_1from models.autoencoder_vgg19.vgg19_3.feature_invertor_conv3_1 import feature_invertor_conv3_1from models.autoencoder_vgg19.vgg19_4.feature_invertor_conv4_1 import feature_invertor_conv4_1from models.autoencoder_vgg19.vgg19_5.feature_invertor_conv5_1 import feature_invertor_conv5_1
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class Encoder1 (nn.Module): def __init__ (self ): super (Encoder1,self).__init__() self.encoder=vgg_normalised_conv1_1 self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_1/vgg_normalised_conv1_1.pth' )) def forward (self,x ): return self.encoder(x) class Decoder1 (nn.Module): def __init__ (self ): super (Decoder1,self).__init__() self.decoder=feature_invertor_conv1_1 self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_1/feature_invertor_conv1_1.pth' )) def forward (self,x ): return self.decoder(x)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class Encoder2 (nn.Module): def __init__ (self ): super (Encoder2,self).__init__() self.encoder=vgg_normalised_conv2_1 self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_2/vgg_normalised_conv2_1.pth' )) def forward (self,x ): return self.encoder(x) class Decoder2 (nn.Module): def __init__ (self ): super (Decoder2,self).__init__() self.decoder=feature_invertor_conv2_1 self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_2/feature_invertor_conv2_1.pth' )) def forward (self,x ): return self.decoder(x)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class Encoder3 (nn.Module): def __init__ (self ): super (Encoder3,self).__init__() self.encoder=vgg_normalised_conv3_1 self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_3/vgg_normalised_conv3_1.pth' )) def forward (self,x ): return self.encoder(x) class Decoder3 (nn.Module): def __init__ (self ): super (Decoder3,self).__init__() self.decoder=feature_invertor_conv3_1 self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_3/feature_invertor_conv3_1.pth' )) def forward (self,x ): return self.decoder(x)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class Encoder4 (nn.Module): def __init__ (self ): super (Encoder4,self).__init__() self.encoder=vgg_normalised_conv4_1 self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_4/vgg_normalised_conv4_1.pth' )) def forward (self,x ): return self.encoder(x) class Decoder4 (nn.Module): def __init__ (self ): super (Decoder4,self).__init__() self.decoder=feature_invertor_conv4_1 self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_4/feature_invertor_conv4_1.pth' )) def forward (self,x ): return self.decoder(x)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class Encoder5 (nn.Module): def __init__ (self ): super (Encoder5,self).__init__() self.encoder=vgg_normalised_conv5_1 self.encoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_5/vgg_normalised_conv5_1.pth' )) def forward (self,x ): return self.encoder(x) class Decoder5 (nn.Module): def __init__ (self ): super (Decoder5,self).__init__() self.decoder=feature_invertor_conv5_1 self.decoder.load_state_dict(torch.load('models/autoencoder_vgg19/vgg19_5/feature_invertor_conv5_1.pth' )) def forward (self,x ): return self.decoder(x)
1 2 Encoders=[Encoder1(),Encoder2(),Encoder3(),Encoder4(),Encoder5()] Decoders=[Decoder1(),Decoder2(),Decoder3(),Decoder4(),Decoder5()]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 class WCT (nn.Module): def __init (self ): super (WCT,self).__init__() def forward (self, f_c, f_s, alpha ): channels=f_c.shape[-3 ] width=f_c.shape[-1 ] height=f_c.shape[-2 ] m_c=torch.mean(f_c,[-2 ,-1 ],keepdim=True ) f_c=f_c-m_c f_c=torch.flatten(f_c,-2 ,-1 ) gram_matrix_c=torch.bmm(f_c/width,f_c.transpose(-2 ,-1 )/height) d_c,e_c=torch.linalg.eigh(gram_matrix_c) selected_channels=(d_c>0.00001 ).squeeze() d_c=1.0 / torch.sqrt(d_c[...,selected_channels]) d_c=d_c.diag_embed() whitened=torch.matmul(torch.matmul(e_c[...,selected_channels],d_c),torch.matmul(e_c[...,selected_channels].transpose(-2 ,-1 ),f_c)) m_s=torch.mean(f_s,[-2 ,-1 ],keepdim=True ) f_s=f_s-m_s f_s=torch.flatten(f_s,-2 ,-1 ) gram_matrix_s=torch.bmm(f_s/width,f_s.transpose(-2 ,-1 )/height) d_s,e_s=torch.linalg.eigh(gram_matrix_s) selected_channels=(d_s>0.00001 ).squeeze() d_s=torch.sqrt(d_s[...,selected_channels]) d_s=d_s.diag_embed() colored=torch.matmul(torch.matmul(e_s[...,selected_channels],d_s),torch.matmul(e_s[...,selected_channels].transpose(-2 ,-1 ),whitened)) colored=colored+torch.flatten(m_s,-2 ,-1 ) colored=alpha*colored+(1 -alpha)*f_c return colored.reshape(-1 ,channels,height,width)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class StyleTransferNetwork (): def __init__ (self, encoders, decoders ): self.encoders=encoders self.decoders=decoders self.wct=WCT() def transfer (self, content_img, style_img, alpha ): i=5 while i>0 : i-=1 encoded_c=self.encoders[i](content_img) encoded_s=self.encoders[i](style_img) content_img_encoded=self.wct(encoded_c, encoded_s, alpha) content_img=self.decoders[i](content_img_encoded) return content_img
1 transfernetwork=StyleTransferNetwork(Encoders, Decoders)
1 2 3 4 style_folder='./figs/styles/' content_folder='./figs/content/' style_figs=os.listdir(style_folder) content_figs=os.listdir(content_folder)
1 2 3 4 5 6 7 8 9 10 11 style_dataset = ImageFolder(root=style_folder, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5 , 0.5 , 0.5 ),(0.5 , 0.5 , 0.5 )), ])) style_dataloader = DataLoader(style_dataset, batch_size=1 ) content_dataset = ImageFolder(root=content_folder, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5 , 0.5 , 0.5 ),(0.5 , 0.5 , 0.5 )), ])) content_dataloader = DataLoader(content_dataset, batch_size=1 )
1 2 3 4 5 res=[] for content_fig, _ in content_dataloader: for style_fig, _ in style_dataloader: img=transfernetwork.transfer(content_fig, style_fig, 0.5 ) res.append(img.detach())
在网络训练完成之后,我们使用了四张风格图与四张内容图,一共生成了16张风格迁移后的图像,如下图所示。其中,\(\alpha\) 的值被设置为0.5。
image-20210801205439800
参考
风格迁移综述_Hygge�的博客-CSDN博客_风格迁移现状
Style Transfer | 风格迁移综述 - 知乎 (zhihu.com)
图像风格迁移(Neural Style)简史 - 知乎 (zhihu.com)
如何用简单易懂的例子解释格拉姆矩阵/Gram matrix? - 知乎 (zhihu.com)
https://arxiv.org/pdf/1705.08086.pdf
https://arxiv.org/pdf/1703.06868.pdf
https://arxiv.org/pdf/1603.08155.pdf