Neural Style Transfer without pystiche

This example showcases how a basic Neural Style Transfer (NST), i.e. image-based optimization, could be performed without pystiche.

Note

This is an example of how to implement an NST, not a tutorial on how NST works. As such, it will not explain why a specific choice was made or how a component works. If you have never worked with NST before, we strongly suggest you read the Gist first.

Setup

We start this example by importing everything we need and setting the device we will be working on. torch and torchvision will be used for the actual NST. Furthermore, we use PIL.Image for the file input, and matplotlib.pyplot to show the images.

import itertools
import os.path
from collections import OrderedDict
from urllib.request import urlopen

import matplotlib.pyplot as plt
from PIL import Image

import torch
import torchvision
from torch import nn, optim
from torch.nn.functional import mse_loss
from torchvision import transforms
from torchvision.models import vgg19
from torchvision.transforms.functional import resize

print(f"I'm working with torch=={torch.__version__}")
print(f"I'm working with torchvision=={torchvision.__version__}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"I'm working with {device}")

Out:

I'm working with torch==1.9.0+cu111
I'm working with torchvision==0.10.0+cu111
I'm working with cuda

The core component of every NST algorithm is the perceptual loss, which is used as the optimization criterion. The perceptual loss is usually, and also in this example, calculated on feature maps, also called encodings. These encodings are generated by different layers of a Convolutional Neural Network (CNN), also called an encoder.

A common implementation strategy for the perceptual loss is to weave transparent loss layers into the encoder. These loss layers are called transparent since, from the outside, they simply pass the input through without altering it. Internally, though, they calculate the loss from the encodings of the previous layer and store it in themselves (a minimal sketch of such a layer follows the list below). After the forward pass is completed, the stored losses are aggregated and propagated backwards to the image. While this is simple to implement, the practice has two downsides:

  1. The calculated score is part of the current state but has to be stored inside the layer. This is generally not recommended.

  2. While the encoder is a part of the perceptual loss, it itself does not generate it. One should be able to use the same encoder with a different perceptual loss without modification.
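
To make the contrast concrete, here is a minimal, hypothetical sketch of such a transparent loss layer; the class ContentLossLayer and its use of mse_loss are our own illustration and are not used anywhere else in this example:

class ContentLossLayer(nn.Module):
    def __init__(self, target_enc):
        super().__init__()
        self.register_buffer("target_enc", target_enc.detach())
        self.loss = None

    def forward(self, enc):
        # calculate the loss and keep it as internal state ...
        self.loss = mse_loss(enc, self.target_enc)
        # ... but pass the encoding through unchanged
        return enc

Layers like this would be spliced into the encoder behind the layers of interest, and their loss attributes would be collected and summed after every forward pass.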

Thus, this example (and pystiche) follows a different approach and separates the encoder and the perceptual loss into individual entities.

Multi-layer Encoder

In a first step we define a MultiLayerEncoder that should have the following properties:

  1. Given an image and a set of layers, the MultiLayerEncoder should return the encodings of every given layer.

  2. Since the encodings have to be generated in every optimization step, they should be calculated in a single forward pass to keep the processing costs low.

  3. To reduce the static memory requirement, the MultiLayerEncoder should be trimmable in order to remove unused layers.

We achieve the main functionality by subclassing torch.nn.Sequential and defining a custom forward method, i.e. the behavior when the module is called. Besides the image, it also takes an iterable layer_cfgs containing multiple sequences of layers. In the method body we first find the deepest_layer that was requested. Subsequently, we calculate and store all encodings of the image up to that layer. Finally, we return all requested encodings without processing the same layer twice. A short usage sketch follows the class definition below.

class MultiLayerEncoder(nn.Sequential):
    def forward(self, image, *layer_cfgs):
        storage = {}
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        for layer, module in self.named_children():
            image = storage[layer] = module(image)
            if layer == deepest_layer:
                break

        return [[storage[layer] for layer in layers] for layers in layer_cfgs]

    def children_names(self):
        for name, module in self.named_children():
            yield name

    def _find_deepest_layer(self, *layer_cfgs):
        # find all unique requested layers
        req_layers = set(itertools.chain(*layer_cfgs))
        try:
            # find the deepest requested layer by indexing the layers within
            # the multi layer encoder
            children_names = list(self.children_names())
            return sorted(req_layers, key=children_names.index)[-1]
        except ValueError as error:
            layer = str(error).split()[0]
        raise ValueError(f"Layer {layer} is not part of the multi-layer encoder.")

    def trim(self, *layer_cfgs):
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        children_names = list(self.children_names())
        del self[children_names.index(deepest_layer) + 1 :]
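
To make the call signature concrete, here is a short usage sketch. Note that multi_layer_encoder is only instantiated further below, the layer names anticipate the VGG naming scheme introduced in the next section, and image stands for any image tensor of shape 1 x 3 x H x W:

# two layer configurations in a single forward pass: the first one for the
# content layers, the second one for the style layers
content_encs, style_encs = multi_layer_encoder(
    image, ("relu4_2",), ("relu1_1", "relu2_1", "relu3_1")
)
# content_encs holds one encoding, style_encs holds three; no layer of the
# encoder was evaluated more than once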

The pretrained models the MultiLayerEncoder is based on are usually trained on preprocessed images. In PyTorch all models expect images to be normalized with the per-channel mean = (0.485, 0.456, 0.406) and standard deviation std = (0.229, 0.224, 0.225). To include this in the MultiLayerEncoder, we implement it as a torch.nn.Module.

class Normalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, image):
        return (image - self.mean) / self.std


class TorchNormalize(Normalize):
    def __init__(self):
        super().__init__((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

In a last step we need to specify the structure of the MultiLayerEncoder. In this example we use a VGGMultiLayerEncoder based on the VGG19 CNN introduced by Simonyan and Zisserman [SZ14].

We only include the feature extraction stage (vgg_net.features), i.e. the convolutional stage, since the classifier stage (vgg_net.classifier) only accepts feature maps of a single size.

For convenience, we rename the layers using the same scheme the authors used instead of keeping the consecutive indices of a default torch.nn.Sequential. The first layer, however, is the TorchNormalize defined above.

class VGGMultiLayerEncoder(MultiLayerEncoder):
    def __init__(self, vgg_net):
        modules = OrderedDict((("preprocessing", TorchNormalize()),))

        block = depth = 1
        for module in vgg_net.features.children():
            if isinstance(module, nn.Conv2d):
                layer = f"conv{block}_{depth}"
            elif isinstance(module, nn.BatchNorm2d):
                layer = f"bn{block}_{depth}"
            elif isinstance(module, nn.ReLU):
                # without inplace=False the encodings of the previous layer would no
                # longer be accessible after the ReLU layer is executed
                module = nn.ReLU(inplace=False)
                layer = f"relu{block}_{depth}"
                # each ReLU layer increases the depth of the current block by one
                depth += 1
            elif isinstance(module, nn.MaxPool2d):
                layer = f"pool{block}"
                # each max pooling layer marks the end of the current block
                block += 1
                depth = 1
            else:
                msg = f"Type {type(module)} is not part of the VGG architecture."
                raise RuntimeError(msg)

            modules[layer] = module

        super().__init__(modules)


def vgg19_multi_layer_encoder():
    return VGGMultiLayerEncoder(vgg19(pretrained=True))


multi_layer_encoder = vgg19_multi_layer_encoder().to(device)
print(multi_layer_encoder)

Out:

VGGMultiLayerEncoder(
  (preprocessing): TorchNormalize()
  (conv1_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_1): ReLU()
  (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1_2): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_1): ReLU()
  (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2_2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_1): ReLU()
  (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_2): ReLU()
  (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_3): ReLU()
  (conv3_4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3_4): ReLU()
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu4_1): ReLU()
  (conv4_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu4_2): ReLU()
  (conv4_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu4_3): ReLU()
  (conv4_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu4_4): ReLU()
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv5_1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu5_1): ReLU()
  (conv5_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu5_2): ReLU()
  (conv5_3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu5_3): ReLU()
  (conv5_4): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu5_4): ReLU()
  (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

Perceptual Loss

In order to calculate the perceptual loss, i.e. the optimization criterion, we define a MultiLayerLoss to have a convenient interface. This will be subclassed later by the ContentLoss and StyleLoss.

If called with a sequence of input_encs, the MultiLayerLoss should calculate layerwise scores against the corresponding target_encs. For that, a MultiLayerLoss needs the ability to store the target_encs so that they can be reused in every call. The individual layer scores should be averaged over the number of encodings and finally weighted by a score_weight.

To achieve this we subclass torch.nn.Module. The target_encs are stored as buffers, since they are not trainable parameters. The actual functionality has to be defined in calculate_score by a subclass.

def mean(sized):
    return sum(sized) / len(sized)


class MultiLayerLoss(nn.Module):
    def __init__(self, score_weight=1e0):
        super().__init__()
        self.score_weight = score_weight
        self._numel_target_encs = 0

    def _target_enc_name(self, idx):
        return f"_target_encs_{idx}"

    def set_target_encs(self, target_encs):
        self._numel_target_encs = len(target_encs)
        for idx, enc in enumerate(target_encs):
            self.register_buffer(self._target_enc_name(idx), enc.detach())

    @property
    def target_encs(self):
        return tuple(
            getattr(self, self._target_enc_name(idx))
            for idx in range(self._numel_target_encs)
        )

    def forward(self, input_encs):
        if len(input_encs) != self._numel_target_encs:
            msg = (
                f"The number of given input encodings and stored target encodings "
                f"does not match: {len(input_encs)} != {self._numel_target_encs}"
            )
            raise RuntimeError(msg)

        layer_losses = [
            self.calculate_score(input, target)
            for input, target in zip(input_encs, self.target_encs)
        ]
        return mean(layer_losses) * self.score_weight

    def calculate_score(self, input, target):
        raise NotImplementedError

In this example we use the feature_reconstruction_loss introduced by Mahendran and Vedaldi [MV15] as ContentLoss as well as the gram_loss introduced by Gatys, Ecker, and Bethge [GEB16] as StyleLoss.
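
Written as formulas (our notation, intended only as a summary of the code below), let \hat{F} (input) and F (target) denote the encodings of a single layer, each flattened to C rows (channels) and N columns (spatial locations):

\mathcal{L}_{\mathrm{content}} = \frac{1}{CN} \sum_{c,n} \bigl(\hat{F}_{c,n} - F_{c,n}\bigr)^2,
\qquad
G(F) = \frac{1}{N} F F^{\mathsf{T}},
\qquad
\mathcal{L}_{\mathrm{style}} = \frac{1}{C^2} \sum_{c,c'} \bigl(G(\hat{F})_{c,c'} - G(F)_{c,c'}\bigr)^2

Both are thus mean squared errors, once over the raw encodings and once over their channelwise Gram matrices.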

def feature_reconstruction_loss(input, target):
    return mse_loss(input, target)


class ContentLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return feature_reconstruction_loss(input, target)


def channelwise_gram_matrix(x, normalize=True):
    x = torch.flatten(x, 2)
    G = torch.bmm(x, x.transpose(1, 2))
    if normalize:
        return G / x.size()[-1]
    else:
        return G


def gram_loss(input, target):
    return mse_loss(channelwise_gram_matrix(input), channelwise_gram_matrix(target))


class StyleLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return gram_loss(input, target)

Images

Before we can load the content and style image, we need to define some basic I/O utilities.

At import a fake batch dimension is added to the images so that they can be passed through the MultiLayerEncoder without further modification. This dimension is removed again upon export. Furthermore, all images will be resized so that their shorter edge is size=500 pixels.

import_from_pil = transforms.Compose(
    (
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.unsqueeze(0)),
        transforms.Lambda(lambda x: x.to(device)),
    )
)

export_to_pil = transforms.Compose(
    (
        transforms.Lambda(lambda x: x.cpu()),
        transforms.Lambda(lambda x: x.squeeze(0)),
        transforms.Lambda(lambda x: x.clamp(0.0, 1.0)),
        transforms.ToPILImage(),
    )
)


def download_image(url):
    file = os.path.abspath(os.path.basename(url))
    with open(file, "wb") as fh, urlopen(url) as response:
        fh.write(response.read())

    return file


def read_image(file, size=500):
    image = Image.open(file)
    image = resize(image, size)
    return import_from_pil(image)


def show_image(image, title=None):
    _, ax = plt.subplots()
    ax.axis("off")
    if title is not None:
        ax.set_title(title)

    image = export_to_pil(image)
    ax.imshow(image)

With the I/O utilities set up, we now download, read, and show the images that will be used in the NST.

Note

The images used in this example are licensed under the permissive Pixabay License.

content_url = "https://download.pystiche.org/images/bird1.jpg"
content_file = download_image(content_url)
content_image = read_image(content_file)
show_image(content_image, title="Content image")

[Figure: Content image]

style_url = "https://download.pystiche.org/images/paint.jpg"
style_file = download_image(style_url)
style_image = read_image(style_file)
show_image(style_image, title="Style image")

[Figure: Style image]

Neural Style Transfer

First we choose the content_layers and style_layers on which the encodings are compared. With them, we trim the multi_layer_encoder to remove unused layers that would otherwise occupy memory.

Afterwards we calculate the target content and style encodings. The calculation is performed without a gradient since the gradient of the target encodings is not needed for the optimization.

content_layers = ("relu4_2",)
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")

multi_layer_encoder.trim(content_layers, style_layers)

with torch.no_grad():
    target_content_encs = multi_layer_encoder(content_image, content_layers)[0]
    target_style_encs = multi_layer_encoder(style_image, style_layers)[0]

Next up, we instantiate the ContentLoss and StyleLoss with a corresponding weight. Afterwards we store the previously calculated target encodings.

content_weight = 1e0
content_loss = ContentLoss(score_weight=content_weight)
content_loss.set_target_encs(target_content_encs)

style_weight = 1e3
style_loss = StyleLoss(score_weight=style_weight)
style_loss.set_target_encs(target_style_encs)

We start the NST from the content_image since it converges quickly this way.

input_image = content_image.clone()
show_image(input_image, "Input image")

[Figure: Input image]

Note

If you want to start from a white noise image instead, use

input_image = torch.rand_like(content_image)

In a last preliminary step we create the optimizer that will perform the NST. Since we want to adapt the pixels of the input_image directly, we pass it as the optimization parameter.

optimizer = optim.LBFGS([input_image.requires_grad_(True)], max_iter=1)

Finally we run the NST. The loss calculation has to happen inside a closure since the LBFGS optimizer may need to reevaluate it multiple times per optimization step. This structure is also valid for all other optimizers.

num_steps = 500

for step in range(1, num_steps + 1):

    def closure():
        optimizer.zero_grad()

        input_encs = multi_layer_encoder(input_image, content_layers, style_layers)
        input_content_encs, input_style_encs = input_encs

        content_score = content_loss(input_content_encs)
        style_score = style_loss(input_style_encs)

        perceptual_loss = content_score + style_score
        perceptual_loss.backward()

        if step % 50 == 0:
            print(f"Step {step}")
            print(f"Content loss: {content_score.item():.3e}")
            print(f"Style loss:   {style_score.item():.3e}")
            print("-----------------------")

        return perceptual_loss

    optimizer.step(closure)

output_image = input_image.detach()

Out:

Step 50
Content loss: 6.720e+00
Style loss:   9.722e+01
-----------------------
Step 100
Content loss: 6.602e+00
Style loss:   3.271e+01
-----------------------
Step 150
Content loss: 6.396e+00
Style loss:   2.069e+01
-----------------------
Step 200
Content loss: 6.201e+00
Style loss:   1.480e+01
-----------------------
Step 250
Content loss: 6.050e+00
Style loss:   1.145e+01
-----------------------
Step 300
Content loss: 5.908e+00
Style loss:   9.339e+00
-----------------------
Step 350
Content loss: 5.804e+00
Style loss:   7.779e+00
-----------------------
Step 400
Content loss: 5.719e+00
Style loss:   6.387e+00
-----------------------
Step 450
Content loss: 5.651e+00
Style loss:   5.274e+00
-----------------------
Step 500
Content loss: 5.584e+00
Style loss:   4.520e+00
-----------------------

After the NST we show the resulting image.

show_image(output_image, title="Output image")

[Figure: Output image]

Conclusion

As has hopefully become clear, even in its simplest form an NST requires quite a lot of utilities and boilerplate code. This makes it hard to maintain and keep bug-free, as it is easy to lose track of everything.

Judging by the lines of code one could (falsely) conclude that the actual NST is just an appendix. If you feel the same you can stop worrying now: in Neural Style Transfer with pystiche we showcase how to achieve the same result with pystiche.

Total running time of the script: (1 minute 35.264 seconds)

Estimated memory usage: 2478 MB
