Neural Style Transfer without pystiche
This example shows how a basic Neural Style Transfer (NST), i.e. image-based optimization, could be performed without pystiche.
Note
This is an example of how to implement an NST, not a tutorial on how NST works. As such, it will not explain why a specific choice was made or how a component works. If you have never worked with NST before, we strongly suggest that you read the Gist first.
Setup
We start this example by importing everything we need and setting the device we will be working on. torch and torchvision will be used for the actual NST. Furthermore, we use PIL.Image for the file input and matplotlib.pyplot to show the images.
```python
import itertools
import os.path
from collections import OrderedDict
from urllib.request import urlopen

import matplotlib.pyplot as plt
from PIL import Image

import torch
import torchvision
from torch import nn, optim
from torch.nn.functional import mse_loss
from torchvision import transforms
from torchvision.models import vgg19
from torchvision.transforms.functional import resize

print(f"I'm working with torch=={torch.__version__}")
print(f"I'm working with torchvision=={torchvision.__version__}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"I'm working with {device}")
```
Out:
I'm working with torch==1.9.0+cu111
I'm working with torchvision==0.10.0+cu111
I'm working with cuda
The core component of different NSTs is the perceptual loss, which is used as the optimization criterion. The perceptual loss is usually, and also in this example, calculated on feature maps, also called encodings. These encodings are generated by different layers of a Convolutional Neural Network (CNN), also called the encoder.
A common implementation strategy for the perceptual loss is to weave transparent loss layers into the encoder. These loss layers are called transparent since, from an outside view, they simply pass the input through without alteration. Internally, though, they calculate the loss from the encodings of the previous layer and store it within themselves. After the forward pass is completed, the stored losses are aggregated and propagated backwards to the image. While this is simple to implement, this practice has two downsides:
The calculated score is part of the current state but has to be stored inside the layer. This is generally not recommended.
While the encoder is part of the perceptual loss, it does not generate the loss itself. One should be able to use the same encoder with a different perceptual loss without modification.
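For contrast, here is a minimal sketch of the transparent-loss-layer pattern described above. The class TransparentContentLoss and the toy encoder are hypothetical illustrations, not part of this example or of pystiche. Note how the loss ends up stored inside the layer and has to be read back out after the forward pass.

```python
import torch
from torch import nn
from torch.nn.functional import mse_loss


class TransparentContentLoss(nn.Module):
    """Hypothetical loss layer that passes its input through unchanged."""

    def __init__(self, target):
        super().__init__()
        self.register_buffer("target", target.detach())
        # the score is part of the current state, but lives inside the layer
        self.loss = None

    def forward(self, input):
        self.loss = mse_loss(input, self.target)
        return input  # transparent: the output is the unmodified input


# a toy encoder with a loss layer woven in after the first convolution
torch.manual_seed(0)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
image = torch.rand(1, 3, 8, 8)
with torch.no_grad():
    target = conv(torch.rand(1, 3, 8, 8))
loss_layer = TransparentContentLoss(target)
encoder = nn.Sequential(conv, loss_layer, nn.ReLU())

output = encoder(image)
# the loss has to be fished out of the layer after the forward pass
print(loss_layer.loss)
```

Because the encoder and the loss are fused here, reusing the same encoder with a different loss means rebuilding the nn.Sequential, which is exactly the second downside listed above.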
Thus, this example (and pystiche) follows a different approach and separates the encoder and the perceptual loss into individual entities.
Multi-layer Encoder
In a first step we define a MultiLayerEncoder that should have the following properties:
Given an image and a set of layers, the MultiLayerEncoder should return the encodings of every given layer. Since the encodings have to be generated in every optimization step, they should be calculated in a single forward pass to keep the processing costs low.
To reduce the static memory requirement, the MultiLayerEncoder should be trimmable in order to remove unused layers.
We achieve the main functionality by subclassing torch.nn.Sequential and defining a custom forward method, i.e. different behavior when called. Besides the image, it also takes a variable number of layer_cfgs, each containing a sequence of layers. In the method body we first find the deepest_layer that was requested. Subsequently, we calculate and store all encodings of the image up to that layer. Finally, we can return all requested encodings without processing the same layer twice.
```python
class MultiLayerEncoder(nn.Sequential):
    def forward(self, image, *layer_cfgs):
        storage = {}
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        for layer, module in self.named_children():
            image = storage[layer] = module(image)
            if layer == deepest_layer:
                break

        return [[storage[layer] for layer in layers] for layers in layer_cfgs]

    def children_names(self):
        for name, module in self.named_children():
            yield name

    def _find_deepest_layer(self, *layer_cfgs):
        # find all unique requested layers
        req_layers = set(itertools.chain(*layer_cfgs))
        try:
            # find the deepest requested layer by indexing the layers within
            # the multi layer encoder
            children_names = list(self.children_names())
            return sorted(req_layers, key=children_names.index)[-1]
        except ValueError as error:
            layer = str(error).split()[0]
            raise ValueError(f"Layer {layer} is not part of the multi-layer encoder.")

    def trim(self, *layer_cfgs):
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        children_names = list(self.children_names())
        del self[children_names.index(deepest_layer) + 1 :]
```
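The bookkeeping in _find_deepest_layer and trim boils down to index lookups on the ordered layer names. The following plain-Python sketch assumes layers are identified only by their names; the helpers find_deepest_layer and trimmed_names are hypothetical and operate on a list of names instead of an nn.Sequential.

```python
import itertools


def find_deepest_layer(names, *layer_cfgs):
    """Return the requested layer that comes last in ``names``."""
    requested = set(itertools.chain(*layer_cfgs))
    unknown = requested - set(names)
    if unknown:
        raise ValueError(f"Layers {sorted(unknown)} are not part of the encoder.")
    # max() with the position as key picks the same layer as sorted(...)[-1]
    return max(requested, key=names.index)


def trimmed_names(names, *layer_cfgs):
    """Names kept after dropping every layer behind the deepest requested one."""
    deepest = find_deepest_layer(names, *layer_cfgs)
    return names[: names.index(deepest) + 1]


names = ["conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1"]
print(find_deepest_layer(names, ("relu1_1",), ("conv1_2", "relu1_1")))
print(trimmed_names(names, ("relu1_1",), ("conv1_2",)))
```

Since the forward pass stops at the deepest requested layer anyway, trimming only saves memory; it does not change the returned encodings.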
The pretrained models the MultiLayerEncoder is based on are usually trained on preprocessed images. The pretrained torchvision models expect images normalized with a per-channel mean (mean = (0.485, 0.456, 0.406)) and standard deviation (std = (0.229, 0.224, 0.225)). To include this in a MultiLayerEncoder, we implement it as a torch.nn.Module.
```python
class Normalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, image):
        return (image - self.mean) / self.std


class TorchNormalize(Normalize):
    def __init__(self):
        super().__init__((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
```
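As a worked example of what TorchNormalize does, the sketch below applies the same per-channel (x - mean) / std arithmetic to a single mid-gray RGB pixel in plain Python. The helpers normalize_pixel and denormalize_pixel are hypothetical and exist only for illustration.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Apply the per-channel (x - mean) / std to one RGB pixel in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(values):
    """Invert normalize_pixel."""
    return tuple(v * s + m for v, m, s in zip(values, MEAN, STD))


gray = (0.5, 0.5, 0.5)
print([round(v, 3) for v in normalize_pixel(gray)])  # → [0.066, 0.196, 0.418]
```

A pixel equal to the channel means maps to zero in every channel, which is why the images themselves can stay in [0, 1] and normalization can live inside the encoder.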
In a last step we need to specify the structure of the MultiLayerEncoder. In this example we use a VGGMultiLayerEncoder based on the VGG19 CNN introduced by Simonyan and Zisserman [SZ14].
We only include the feature extraction stage (vgg_net.features), i.e. the convolutional stage, since the classifier stage (vgg_net.classifier) only accepts feature maps of a single size.
For convenience, we rename the layers following the scheme the authors used instead of keeping the consecutive indices of a default torch.nn.Sequential. The first layer, however, is the TorchNormalize defined above.
```python
class VGGMultiLayerEncoder(MultiLayerEncoder):
    def __init__(self, vgg_net):
        modules = OrderedDict((("preprocessing", TorchNormalize()),))

        block = depth = 1
        for module in vgg_net.features.children():
            if isinstance(module, nn.Conv2d):
                layer = f"conv{block}_{depth}"
            elif isinstance(module, nn.BatchNorm2d):
                layer = f"bn{block}_{depth}"
            elif isinstance(module, nn.ReLU):
                # an in-place ReLU would overwrite the encodings of the previous
                # layer, so we replace it with an out-of-place one
                module = nn.ReLU(inplace=False)
                layer = f"relu{block}_{depth}"
                # each ReLU layer increases the depth of the current block by one
                depth += 1
            elif isinstance(module, nn.MaxPool2d):
                layer = f"pool{block}"
                # each max pooling layer marks the end of the current block
                block += 1
                depth = 1
            else:
                msg = f"Type {type(module)} is not part of the VGG architecture."
                raise RuntimeError(msg)

            modules[layer] = module

        super().__init__(modules)


def vgg19_multi_layer_encoder():
    return VGGMultiLayerEncoder(vgg19(pretrained=True))


multi_layer_encoder = vgg19_multi_layer_encoder().to(device)
print(multi_layer_encoder)
```
Out:
VGGMultiLayerEncoder(
(preprocessing): TorchNormalize()
(conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1_1): ReLU()
(conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1_2): ReLU()
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2_1): ReLU()
(conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2_2): ReLU()
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3_1): ReLU()
(conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3_2): ReLU()
(conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3_3): ReLU()
(conv3_4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3_4): ReLU()
(pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4_1): ReLU()
(conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4_2): ReLU()
(conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4_3): ReLU()
(conv4_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4_4): ReLU()
(pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5_1): ReLU()
(conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5_2): ReLU()
(conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5_3): ReLU()
(conv5_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5_4): ReLU()
(pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
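The block/depth renaming rule can be checked without downloading any weights. This sketch is a hypothetical mimic of the loop in VGGMultiLayerEncoder that works on string labels ("conv", "relu", "pool") instead of nn modules and omits the batch-norm branch.

```python
def vgg_layer_names(module_types):
    """Mimic the block/depth naming rule of VGGMultiLayerEncoder on string labels."""
    names = []
    block = depth = 1
    for kind in module_types:
        if kind == "conv":
            names.append(f"conv{block}_{depth}")
        elif kind == "relu":
            names.append(f"relu{block}_{depth}")
            depth += 1  # each ReLU closes a conv/ReLU pair
        elif kind == "pool":
            names.append(f"pool{block}")
            block += 1  # each pooling layer ends the current block
            depth = 1
        else:
            raise ValueError(f"Unknown module type {kind!r}")
    return names


first_two_blocks = ["conv", "relu", "conv", "relu", "pool", "conv", "relu"]
print(vgg_layer_names(first_two_blocks))
```

The resulting names match the first entries of the printed encoder above, from conv1_1 through relu2_1.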
Perceptual Loss
To calculate the perceptual loss, i.e. the optimization criterion, we define a MultiLayerLoss that provides a convenient interface. It will be subclassed later by the ContentLoss and StyleLoss.
When called with a sequence of input_encs, the MultiLayerLoss should calculate layerwise scores against the corresponding target_encs. For that, a MultiLayerLoss needs to store the target_encs so that they can be reused for every call. The individual layer scores should be averaged over the number of encodings and finally weighted by a score_weight.
To achieve this we subclass torch.nn.Module. The target_encs are stored as buffers, since they are not trainable parameters. The actual scoring has to be defined in calculate_score by a subclass.
```python
def mean(sized):
    return sum(sized) / len(sized)


class MultiLayerLoss(nn.Module):
    def __init__(self, score_weight=1e0):
        super().__init__()
        self.score_weight = score_weight
        self._numel_target_encs = 0

    def _target_enc_name(self, idx):
        return f"_target_encs_{idx}"

    def set_target_encs(self, target_encs):
        self._numel_target_encs = len(target_encs)
        for idx, enc in enumerate(target_encs):
            self.register_buffer(self._target_enc_name(idx), enc.detach())

    @property
    def target_encs(self):
        return tuple(
            getattr(self, self._target_enc_name(idx))
            for idx in range(self._numel_target_encs)
        )

    def forward(self, input_encs):
        if len(input_encs) != self._numel_target_encs:
            msg = (
                f"The number of given input encodings and stored target encodings "
                f"does not match: {len(input_encs)} != {self._numel_target_encs}"
            )
            raise RuntimeError(msg)

        layer_losses = [
            self.calculate_score(input, target)
            for input, target in zip(input_encs, self.target_encs)
        ]
        return mean(layer_losses) * self.score_weight

    def calculate_score(self, input, target):
        raise NotImplementedError
```
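Storing the targets as buffers rather than parameters has two practical effects: they are excluded from named_parameters() (so an optimizer built from the module's parameters never touches them), yet they still follow .to() calls and appear in the state_dict. A minimal sketch with a hypothetical TargetHolder module that registers a single target the same way set_target_encs does:

```python
import torch
from torch import nn


class TargetHolder(nn.Module):
    """Hypothetical module storing one target encoding like MultiLayerLoss does."""

    def __init__(self, target):
        super().__init__()
        # detach() so the stored target carries no gradient history
        self.register_buffer("_target_encs_0", target.detach())


holder = TargetHolder(torch.ones(1, 2, 3, 3, requires_grad=True))
print([name for name, _ in holder.named_parameters()])  # []
print([name for name, _ in holder.named_buffers()])  # ['_target_encs_0']
print(list(holder.state_dict()))  # ['_target_encs_0']
```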
In this example we use the feature_reconstruction_loss introduced by Mahendran and Vedaldi [MV15] as ContentLoss, as well as the gram_loss introduced by Gatys, Ecker, and Bethge [GEB16] as StyleLoss.
```python
def feature_reconstruction_loss(input, target):
    return mse_loss(input, target)


class ContentLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return feature_reconstruction_loss(input, target)


def channelwise_gram_matrix(x, normalize=True):
    x = torch.flatten(x, 2)
    G = torch.bmm(x, x.transpose(1, 2))
    if normalize:
        return G / x.size()[-1]
    else:
        return G


def gram_loss(input, target):
    return mse_loss(channelwise_gram_matrix(input), channelwise_gram_matrix(target))


class StyleLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return gram_loss(input, target)
```
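To make channelwise_gram_matrix concrete, the sketch below repeats it so the snippet runs on its own, checks it on a hand-computable 2-channel feature map, and compares it with an equivalent torch.einsum formulation. The function gram_einsum is a hypothetical alternative, not something this example or pystiche uses.

```python
import torch


def channelwise_gram_matrix(x, normalize=True):
    # repeated from above: (B, C, H, W) -> (B, C, H*W) -> (B, C, C)
    x = torch.flatten(x, 2)
    G = torch.bmm(x, x.transpose(1, 2))
    return G / x.size()[-1] if normalize else G


def gram_einsum(x):
    """Hypothetical alternative: sum over all spatial positions, divide by H*W."""
    _, _, h, w = x.shape
    return torch.einsum("bchw,bdhw->bcd", x, x) / (h * w)


# two channels on a 1x2 feature map: channel 0 = [1, 2], channel 1 = [3, 4]
x = torch.tensor([[[[1.0, 2.0]], [[3.0, 4.0]]]])
print(channelwise_gram_matrix(x))
```

Each entry is the mean over spatial positions of the product of two channels, e.g. (1*3 + 2*4) / 2 = 5.5 for the off-diagonal entries, so the Gram matrix discards where features occur and keeps only how they co-occur.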
Images
Before we can load the content and style images, we need to define some basic I/O utilities.
On import, a fake batch dimension is added to the images so that they can be passed through the MultiLayerEncoder without further modification. This dimension is removed again on export. Furthermore, all images are resized so that their shorter edge is size=500 pixels.
```python
import_from_pil = transforms.Compose(
    (
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.unsqueeze(0)),
        transforms.Lambda(lambda x: x.to(device)),
    )
)

export_to_pil = transforms.Compose(
    (
        transforms.Lambda(lambda x: x.cpu()),
        transforms.Lambda(lambda x: x.squeeze(0)),
        transforms.Lambda(lambda x: x.clamp(0.0, 1.0)),
        transforms.ToPILImage(),
    )
)


def download_image(url):
    file = os.path.abspath(os.path.basename(url))
    with open(file, "wb") as fh, urlopen(url) as response:
        fh.write(response.read())

    return file


def read_image(file, size=500):
    image = Image.open(file)
    image = resize(image, size)
    return import_from_pil(image)


def show_image(image, title=None):
    _, ax = plt.subplots()
    ax.axis("off")
    if title is not None:
        ax.set_title(title)

    image = export_to_pil(image)
    ax.imshow(image)
```
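Passing an integer size to resize matches the shorter edge to size and scales the longer edge to keep the aspect ratio. The plain-Python sketch below computes the resulting dimensions, assuming the longer edge is truncated to an integer as torchvision's PIL backend does; shorter_edge_resize is a hypothetical helper, not a torchvision function.

```python
def shorter_edge_resize(width, height, size=500):
    """Return (new_width, new_height) when the shorter edge is scaled to ``size``."""
    short, long = (width, height) if width <= height else (height, width)
    # the longer edge keeps the aspect ratio and is truncated to an int
    new_short, new_long = size, int(size * long / short)
    return (new_short, new_long) if width <= height else (new_long, new_short)


print(shorter_edge_resize(1920, 1080))  # landscape: height becomes 500
print(shorter_edge_resize(600, 800))  # portrait: width becomes 500
```

Consequently, the content and style images generally do not share the same width and height, which is fine here because the Gram matrices only depend on the number of channels.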
With the I/O utilities set up, we now download, read, and show the images that will be used in the NST.
Note
The images used in this example are licensed under the permissive Pixabay License.
```python
content_url = "https://download.pystiche.org/images/bird1.jpg"
content_file = download_image(content_url)
content_image = read_image(content_file)
show_image(content_image, title="Content image")
```

```python
style_url = "https://download.pystiche.org/images/paint.jpg"
style_file = download_image(style_url)
style_image = read_image(style_file)
show_image(style_image, title="Style image")
```
Neural Style Transfer
First, we choose the content_layers and style_layers on which the encodings are compared. With them, we trim the multi_layer_encoder to remove unused layers that would otherwise occupy memory.
Afterwards we calculate the target content and style encodings. The calculation is performed without tracking gradients, since the gradient of the target encodings is not needed for the optimization.
```python
content_layers = ("relu4_2",)
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")

multi_layer_encoder.trim(content_layers, style_layers)

with torch.no_grad():
    target_content_encs = multi_layer_encoder(content_image, content_layers)[0]
    target_style_encs = multi_layer_encoder(style_image, style_layers)[0]
```
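The effect of torch.no_grad() can be seen on a single layer: outputs computed inside the context carry no autograd graph, while the same computation outside it records one through the layer's parameters. A minimal sketch with a hypothetical toy convolution standing in for the encoder:

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # parameters require grad
image = torch.rand(1, 3, 16, 16)

with torch.no_grad():
    target = layer(image)  # no graph is recorded

encoding = layer(image)  # a graph is recorded through the parameters

print(target.requires_grad, encoding.requires_grad)  # False True
```

Skipping the graph for the targets saves the memory that autograd would otherwise keep for the backward pass.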
Next up, we instantiate the ContentLoss and StyleLoss with a corresponding weight. Afterwards we store the previously calculated target encodings.
```python
content_weight = 1e0
content_loss = ContentLoss(score_weight=content_weight)
content_loss.set_target_encs(target_content_encs)

style_weight = 1e3
style_loss = StyleLoss(score_weight=style_weight)
style_loss.set_target_encs(target_style_encs)
```
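Putting the pieces together, the perceptual loss minimized below can be written as follows. The notation is ours, not the source's: $\hat{x}$ is the input image, $x_c$ and $x_s$ are the content and style images, $\phi_l$ is the encoding at layer $l$, $\mathcal{C}$ and $\mathcal{S}$ are the sets of content_layers and style_layers, $\lambda_c$ = content_weight = 1 and $\lambda_s$ = style_weight = $10^3$. $G$ is the normalized channelwise Gram matrix, with $F_l$ the encoding flattened to $N_l \times M_l$ ($N_l$ channels, $M_l = H_l W_l$ positions).

```latex
\mathcal{L}(\hat{x}) =
  \lambda_c \, \frac{1}{|\mathcal{C}|} \sum_{l \in \mathcal{C}}
    \operatorname{MSE}\big(\phi_l(\hat{x}), \phi_l(x_c)\big)
  + \lambda_s \, \frac{1}{|\mathcal{S}|} \sum_{l \in \mathcal{S}}
    \operatorname{MSE}\big(G(\phi_l(\hat{x})), G(\phi_l(x_s))\big),
\qquad
G(F_l) = \frac{F_l F_l^{\top}}{M_l}
```

Each sum corresponds to one MultiLayerLoss: the average over layers is the mean() in its forward method, and the leading factor is its score_weight.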
We start the NST from the content_image, since it converges faster this way.
```python
input_image = content_image.clone()
show_image(input_image, "Input image")
```
Note
If you want to start from a white noise image instead, use
input_image = torch.rand_like(content_image)
In a last preliminary step we create the optimizer that will perform the NST. Since we want to adapt the pixels of the input_image directly, we pass it as the optimization parameter.
```python
optimizer = optim.LBFGS([input_image.requires_grad_(True)], max_iter=1)
```
Finally, we run the NST. The loss calculation has to happen inside a closure, since the LBFGS optimizer may need to reevaluate it multiple times per optimization step. This structure is also valid for all other optimizers.
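The same closure pattern can be tried on a toy problem first. This sketch assumes a simple quadratic objective in place of the perceptual loss and optimizes a tensor directly, as the NST does with input_image. With max_iter=1, LBFGS evaluates the closure once per step, and on this quadratic the quasi-Newton update reaches the target within a few steps.

```python
import torch
from torch import optim

torch.manual_seed(0)
target = torch.rand(3, 4)
# optimize the tensor itself, like input_image in the NST
x = torch.zeros(3, 4, requires_grad=True)
optimizer = optim.LBFGS([x], max_iter=1)

evals = []  # records every closure evaluation

for step in range(20):

    def closure():
        evals.append(step)
        optimizer.zero_grad()
        loss = ((x - target) ** 2).sum()  # stand-in for the perceptual loss
        loss.backward()
        return loss

    optimizer.step(closure)

print(len(evals), ((x - target) ** 2).sum().item())
```

Swapping in, e.g., optim.Adam([x]) works with the same loop, since every torch optimizer accepts an optional closure in step().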
```python
num_steps = 500

for step in range(1, num_steps + 1):

    def closure():
        optimizer.zero_grad()

        input_encs = multi_layer_encoder(input_image, content_layers, style_layers)
        input_content_encs, input_style_encs = input_encs

        content_score = content_loss(input_content_encs)
        style_score = style_loss(input_style_encs)

        perceptual_loss = content_score + style_score
        perceptual_loss.backward()

        if step % 50 == 0:
            print(f"Step {step}")
            print(f"Content loss: {content_score.item():.3e}")
            print(f"Style loss: {style_score.item():.3e}")
            print("-----------------------")

        return perceptual_loss

    optimizer.step(closure)

output_image = input_image.detach()
```
Out:
Step 50
Content loss: 6.720e+00
Style loss: 9.722e+01
-----------------------
Step 100
Content loss: 6.602e+00
Style loss: 3.271e+01
-----------------------
Step 150
Content loss: 6.396e+00
Style loss: 2.069e+01
-----------------------
Step 200
Content loss: 6.201e+00
Style loss: 1.480e+01
-----------------------
Step 250
Content loss: 6.050e+00
Style loss: 1.145e+01
-----------------------
Step 300
Content loss: 5.908e+00
Style loss: 9.339e+00
-----------------------
Step 350
Content loss: 5.804e+00
Style loss: 7.779e+00
-----------------------
Step 400
Content loss: 5.719e+00
Style loss: 6.387e+00
-----------------------
Step 450
Content loss: 5.651e+00
Style loss: 5.274e+00
-----------------------
Step 500
Content loss: 5.584e+00
Style loss: 4.520e+00
-----------------------
After the NST we show the resulting image.
```python
show_image(output_image, title="Output image")
```
Conclusion
As has hopefully become clear, even in its simplest form an NST requires quite a lot of utilities and boilerplate code. This makes it hard to maintain and keep bug-free, as it is easy to lose track of everything.
Judging by the lines of code, one could (falsely) conclude that the actual NST is just an appendix. If you feel the same, you can stop worrying now: in Neural Style Transfer with pystiche we show how to achieve the same result with pystiche.
Total running time of the script: ( 1 minutes 35.264 seconds)
Estimated memory usage: 2478 MB