Neural Style Transfer without pystiche

This example showcases how a basic Neural Style Transfer (NST), i.e. image-based optimization, could be performed without pystiche.


This is an example of how to implement an NST, not a tutorial on how NST works. As such, it will not explain why a specific choice was made or how a component works. If you have never worked with NST before, we strongly suggest that you read the Gist first.


We start this example by importing everything we need and setting the device we will be working on. torch and torchvision will be used for the actual NST. Furthermore, we use PIL.Image for the file input, and matplotlib.pyplot to show the images.

import itertools
import os.path
from collections import OrderedDict
from urllib.request import urlopen

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
import torchvision
from torch import nn, optim
from torch.nn.functional import mse_loss
from torchvision import transforms
from torchvision.models import vgg19
from torchvision.transforms.functional import resize

print(f"I'm working with torch=={torch.__version__}")
print(f"I'm working with torchvision=={torchvision.__version__}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"I'm working with {device}")

The core component of different NSTs is the perceptual loss, which is used as the optimization criterion. The perceptual loss is usually, and also in this example, calculated on feature maps, also called encodings. These encodings are generated by different layers of a Convolutional Neural Net (CNN), also called the encoder.

A common implementation strategy for the perceptual loss is to weave transparent loss layers into the encoder. These loss layers are called transparent since, from an outside view, they simply pass the input through without alteration. Internally though, they calculate the loss from the encodings of the previous layer and store it in themselves. After the forward pass is completed, the stored losses are aggregated and propagated backwards to the image. While this is simple to implement, this practice has two downsides:

  1. The calculated score is part of the current state but has to be stored inside the layer. This is generally not recommended.

  2. While the encoder is a part of the perceptual loss, it itself does not generate it. One should be able to use the same encoder with a different perceptual loss without modification.

Thus, this example (and pystiche) follows a different approach and separates the encoder and the perceptual loss into individual entities.
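For contrast, the transparent loss layer strategy described above might look like the following minimal sketch. The class name TransparentContentLoss, the toy encoder with a single convolution, and the tensor sizes are assumptions made for illustration; they are not part of this example or of pystiche.

```python
import torch
from torch import nn
from torch.nn.functional import mse_loss


class TransparentContentLoss(nn.Module):
    """Hypothetical transparent loss layer: passes its input through unchanged."""

    def __init__(self, target_enc):
        super().__init__()
        self.register_buffer("target_enc", target_enc.detach())
        # downside 1: the score belongs to the current step but lives in the layer
        self.score = None

    def forward(self, input_enc):
        self.score = mse_loss(input_enc, self.target_enc)
        return input_enc


torch.manual_seed(0)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

with torch.no_grad():
    target_enc = conv(torch.rand(1, 3, 8, 8))

# downside 2: the loss layer is woven into the encoder itself
loss_layer = TransparentContentLoss(target_enc)
encoder = nn.Sequential(conv, loss_layer, nn.ReLU())

input_image = torch.rand(1, 3, 8, 8, requires_grad=True)
output = encoder(input_image)

# after the forward pass the stored score is propagated back to the image
loss_layer.score.backward()
```

Swapping in a different loss here means rebuilding the encoder, which is what the separation into encoder and perceptual loss avoids.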

Multi-layer Encoder

In a first step we define a MultiLayerEncoder that should have the following properties:

  1. Given an image and a set of layers, the MultiLayerEncoder should return the encodings of every given layer.

  2. Since the encodings have to be generated in every optimization step they should be calculated in a single forward pass to keep the processing costs low.

  3. To reduce the static memory requirement, the MultiLayerEncoder should be trimmable in order to remove unused layers.

We achieve the main functionality by subclassing torch.nn.Sequential and defining a custom forward method, i.e. different behavior when called. Besides the image, it takes layer_cfgs, i.e. one or more sequences of layer names. In the method body we first find the deepest_layer that was requested. Subsequently, we calculate and store all encodings of the image up to that layer. Finally, we return all requested encodings without processing the same layer twice.

class MultiLayerEncoder(nn.Sequential):
    def forward(self, image, *layer_cfgs):
        storage = {}
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        for layer, module in self.named_children():
            image = storage[layer] = module(image)
            if layer == deepest_layer:
                break

        return [[storage[layer] for layer in layers] for layers in layer_cfgs]

    def children_names(self):
        for name, module in self.named_children():
            yield name

    def _find_deepest_layer(self, *layer_cfgs):
        # find all unique requested layers
        req_layers = set(itertools.chain(*layer_cfgs))
        try:
            # find the deepest requested layer by indexing the layers within
            # the multi layer encoder
            children_names = list(self.children_names())
            return sorted(req_layers, key=children_names.index)[-1]
        except ValueError as error:
            layer = str(error).split()[0]
        raise ValueError(f"Layer {layer} is not part of the multi-layer encoder.")

    def trim(self, *layer_cfgs):
        deepest_layer = self._find_deepest_layer(*layer_cfgs)
        children_names = list(self.children_names())
        del self[children_names.index(deepest_layer) + 1 :]
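The control flow of the forward method can be tried out without torch. The following sketch mirrors it with plain Python callables standing in for layers. The function encode, the toy layer names, and the call counter are hypothetical and only serve the illustration.

```python
import itertools
from collections import OrderedDict


def encode(named_layers, value, *layer_cfgs):
    names = list(named_layers)
    # the deepest requested layer is the requested layer with the largest position
    requested = set(itertools.chain(*layer_cfgs))
    deepest_layer = max(requested, key=names.index)

    storage = {}
    num_calls = 0
    for name, layer in named_layers.items():
        value = storage[name] = layer(value)
        num_calls += 1
        if name == deepest_layer:
            break

    encodings = [[storage[name] for name in layers] for layers in layer_cfgs]
    return encodings, num_calls


toy_layers = OrderedDict(
    (
        ("double", lambda x: 2 * x),
        ("increment", lambda x: x + 1),
        ("square", lambda x: x * x),
        ("negate", lambda x: -x),  # deeper than any requested layer, never executed
    )
)

encodings, num_calls = encode(toy_layers, 3, ("increment",), ("double", "square"))
print(encodings)  # [[7], [6, 49]]
print(num_calls)  # 3
```

Each layer runs at most once, and the result keeps the grouping of the layer_cfgs that were passed in.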

The pretrained models the MultiLayerEncoder is based on are usually trained on preprocessed images. In PyTorch all models expect images that are normalized by a per-channel mean (mean = (0.485, 0.456, 0.406)) and standard deviation (std = (0.229, 0.224, 0.225)). To include this step in a MultiLayerEncoder, we implement it as a torch.nn.Module.

class Normalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, image):
        return (image - self.mean) / self.std


class TorchNormalize(Normalize):
    def __init__(self):
        super().__init__((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
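As a worked example of the arithmetic, the sketch below applies the same per-channel normalization to single RGB pixels in plain Python. The helper normalize_pixel is hypothetical; the module above does the same for whole image batches by broadcasting the buffers reshaped with view(1, -1, 1, 1).

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    # same arithmetic as Normalize.forward, applied to one RGB pixel
    return tuple((value - m) / s for value, m, s in zip(rgb, mean, std))


# a pixel equal to the mean maps to zero in every channel
print(normalize_pixel(MEAN))  # (0.0, 0.0, 0.0)

# a pure white pixel lies roughly 2.2 to 2.7 standard deviations above the mean
white = normalize_pixel((1.0, 1.0, 1.0))
```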

In a last step we need to specify the structure of the MultiLayerEncoder. In this example we use a VGGMultiLayerEncoder based on the VGG19 CNN introduced by Simonyan and Zisserman [SZ2014].

We only include the feature extraction stage (vgg_net.features), i.e. the convolutional stage, since the classifier stage (vgg_net.classifier) only accepts feature maps of a single size.

For our convenience we rename the layers following the scheme the authors used instead of keeping the consecutive index of a default torch.nn.Sequential. The first layer, however, is the TorchNormalize defined above.

class VGGMultiLayerEncoder(MultiLayerEncoder):
    def __init__(self, vgg_net):
        modules = OrderedDict((("preprocessing", TorchNormalize()),))

        block = depth = 1
        for module in vgg_net.features.children():
            if isinstance(module, nn.Conv2d):
                layer = f"conv{block}_{depth}"
            elif isinstance(module, nn.BatchNorm2d):
                layer = f"bn{block}_{depth}"
            elif isinstance(module, nn.ReLU):
                # without inplace=False the encodings of the previous layer would no
                # longer be accessible after the ReLU layer is executed
                module = nn.ReLU(inplace=False)
                layer = f"relu{block}_{depth}"
                # each ReLU layer increases the depth of the current block by one
                depth += 1
            elif isinstance(module, nn.MaxPool2d):
                layer = f"pool{block}"
                # each max pooling layer marks the end of the current block
                block += 1
                depth = 1
            else:
                msg = f"Type {type(module)} is not part of the VGG architecture."
                raise RuntimeError(msg)

            modules[layer] = module

        super().__init__(modules)


def vgg19_multi_layer_encoder():
    return VGGMultiLayerEncoder(vgg19(pretrained=True))


multi_layer_encoder = vgg19_multi_layer_encoder().to(device)
print(multi_layer_encoder)
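To see the naming scheme in isolation, the following sketch applies the same block and depth bookkeeping to a list of layer type strings instead of real modules. The function name_vgg_layers and the string labels "conv", "relu", and "pool" are hypothetical stand-ins for the isinstance checks above.

```python
def name_vgg_layers(layer_types):
    names = []
    block = depth = 1
    for layer_type in layer_types:
        if layer_type == "conv":
            names.append(f"conv{block}_{depth}")
        elif layer_type == "relu":
            names.append(f"relu{block}_{depth}")
            # each ReLU layer increases the depth of the current block by one
            depth += 1
        elif layer_type == "pool":
            names.append(f"pool{block}")
            # each pooling layer ends the current block
            block += 1
            depth = 1
        else:
            raise ValueError(f"Unknown layer type {layer_type!r}")
    return names


names = name_vgg_layers(["conv", "relu", "conv", "relu", "pool", "conv", "relu"])
print(names)
# ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1']
```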

Perceptual Loss

In order to calculate the perceptual loss, i.e. the optimization criterion, we define a MultiLayerLoss to have a convenient interface. This will be subclassed later by the ContentLoss and StyleLoss.

If called with a sequence of input_encs the MultiLayerLoss should calculate layerwise scores together with the corresponding target_encs. For that a MultiLayerLoss needs the ability to store the target_encs so that they can be reused for every call. The individual layer scores should be averaged by the number of encodings and finally weighted by a score_weight.

To achieve this we subclass torch.nn.Module. The target_encs are stored as buffers, since they are not trainable parameters. The actual functionality has to be defined in calculate_score by a subclass.

def mean(sized):
    return sum(sized) / len(sized)


class MultiLayerLoss(nn.Module):
    def __init__(self, score_weight=1e0):
        super().__init__()
        self.score_weight = score_weight
        self._numel_target_encs = 0

    def _target_enc_name(self, idx):
        return f"_target_encs_{idx}"

    def set_target_encs(self, target_encs):
        self._numel_target_encs = len(target_encs)
        for idx, enc in enumerate(target_encs):
            self.register_buffer(self._target_enc_name(idx), enc.detach())

    @property
    def target_encs(self):
        return tuple(
            getattr(self, self._target_enc_name(idx))
            for idx in range(self._numel_target_encs)
        )

    def forward(self, input_encs):
        if len(input_encs) != self._numel_target_encs:
            msg = (
                f"The number of given input encodings and stored target encodings "
                f"does not match: {len(input_encs)} != {self._numel_target_encs}"
            )
            raise RuntimeError(msg)

        layer_losses = [
            self.calculate_score(input, target)
            for input, target in zip(input_encs, self.target_encs)
        ]
        return mean(layer_losses) * self.score_weight

    def calculate_score(self, input, target):
        raise NotImplementedError
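The aggregation performed in forward can be checked by hand with scalars in place of encodings. In the sketch below, multi_layer_score and squared_error are hypothetical helpers that only reproduce the averaging and weighting, not the buffer handling of the module.

```python
def mean(sized):
    return sum(sized) / len(sized)


def multi_layer_score(input_encs, target_encs, calculate_score, score_weight=1.0):
    if len(input_encs) != len(target_encs):
        raise RuntimeError("number of input and target encodings does not match")
    layer_scores = [
        calculate_score(input, target)
        for input, target in zip(input_encs, target_encs)
    ]
    return mean(layer_scores) * score_weight


def squared_error(input, target):
    # hypothetical stand-in for calculate_score on scalar "encodings"
    return (input - target) ** 2


# layer scores are 1.0, 0.0, 4.0 and 1.0, so the mean is 1.5
score = multi_layer_score(
    [1.0, 2.0, 5.0, 1.0], [0.0, 2.0, 3.0, 0.0], squared_error, score_weight=2.0
)
print(score)  # 3.0
```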

In this example we use the feature_reconstruction_loss introduced by Mahendran and Vedaldi [MV2015] as ContentLoss as well as the gram_loss introduced by Gatys, Ecker, and Bethge [GEB2016] as StyleLoss.

def feature_reconstruction_loss(input, target):
    return mse_loss(input, target)


class ContentLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return feature_reconstruction_loss(input, target)


def channelwise_gram_matrix(x, normalize=True):
    x = torch.flatten(x, 2)
    G = torch.bmm(x, x.transpose(1, 2))
    if normalize:
        return G / x.size()[-1]
    else:
        return G


def gram_loss(input, target):
    return mse_loss(channelwise_gram_matrix(input), channelwise_gram_matrix(target))


class StyleLoss(MultiLayerLoss):
    def calculate_score(self, input, target):
        return gram_loss(input, target)
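As a small worked example of what channelwise_gram_matrix computes, the sketch below builds the Gram matrix of a single feature map in plain Python. The helper gram_matrix and the toy values are assumptions for illustration. Each channel is already flattened to a list of spatial positions, and there is no batch dimension.

```python
def gram_matrix(features, normalize=True):
    # features: list of channels, each a flat list of spatial positions
    num_positions = len(features[0])
    gram = [
        [sum(a * b for a, b in zip(row, col)) for col in features]
        for row in features
    ]
    if normalize:
        gram = [[value / num_positions for value in row] for row in gram]
    return gram


features = [[1.0, 2.0, 3.0, 2.0], [0.0, 1.0, 0.0, 1.0]]
print(gram_matrix(features, normalize=False))  # [[18.0, 4.0], [4.0, 2.0]]
print(gram_matrix(features))  # [[4.5, 1.0], [1.0, 0.5]]

# reordering the spatial positions identically in all channels changes nothing
shuffled = [[3.0, 2.0, 2.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
print(gram_matrix(shuffled) == gram_matrix(features))  # True
```

The last line shows that the Gram matrix only records how channels correlate, not where in the image the features occur.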


Before we can load the content and style image, we need to define some basic I/O utilities.

At import a fake batch dimension is added to the images so that they can be passed through the MultiLayerEncoder without further modification. This dimension is removed again upon export. Furthermore, all images will be resized to size=500 pixels.

import_from_pil = transforms.Compose(
    (
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.unsqueeze(0)),
        transforms.Lambda(lambda x: x.to(device)),
    )
)

export_to_pil = transforms.Compose(
    (
        transforms.Lambda(lambda x: x.cpu()),
        transforms.Lambda(lambda x: x.squeeze(0)),
        transforms.Lambda(lambda x: x.clamp(0.0, 1.0)),
        transforms.ToPILImage(),
    )
)


def download_image(url):
    file = os.path.abspath(os.path.basename(url))
    with open(file, "wb") as fh, urlopen(url) as response:
        fh.write(response.read())

    return file


def read_image(file, size=500):
    image = Image.open(file)
    image = resize(image, size)
    return import_from_pil(image)


def show_image(image, title=None):
    _, ax = plt.subplots()
    ax.axis("off")
    if title is not None:
        ax.set_title(title)

    image = export_to_pil(image)
    ax.imshow(image)
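The effect of the fake batch dimension and of the clamping on export can be seen on a random tensor. This is a minimal sketch. The tensor size and the artificial out-of-range values are assumptions standing in for a real image and a real optimization result.

```python
import torch

image = torch.rand(3, 4, 5)  # (channels, height, width), the layout ToTensor produces

batched = image.unsqueeze(0)  # add the fake batch dimension
print(tuple(batched.shape))  # (1, 3, 4, 5)

# an optimization may push values outside [0, 1], so they are clamped before export
modified = batched * 2.0 - 0.5
restored = modified.squeeze(0).clamp(0.0, 1.0)
print(tuple(restored.shape))  # (3, 4, 5)
```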

With the I/O utilities set up, we now download, read, and show the images that will be used in the NST.


The images used in this example are licensed under the permissive Pixabay License.

content_url = ""
content_file = download_image(content_url)
content_image = read_image(content_file)
show_image(content_image, title="Content image")

style_url = ""
style_file = download_image(style_url)
style_image = read_image(style_file)
show_image(style_image, title="Style image")

Neural Style Transfer

First we choose the content_layers and style_layers on which the encodings are compared. With them we trim the multi_layer_encoder to remove unused layers that otherwise occupy memory.

Afterwards we calculate the target content and style encodings. The calculation is performed without a gradient since the gradient of the target encodings is not needed for the optimization.

content_layers = ("relu4_2",)
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")

multi_layer_encoder.trim(content_layers, style_layers)

with torch.no_grad():
    target_content_encs = multi_layer_encoder(content_image, content_layers)[0]
    target_style_encs = multi_layer_encoder(style_image, style_layers)[0]
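What the trim call removes can be estimated on layer names alone. The sketch below assumes the VGG19 variant without batch normalization, i.e. 2, 2, 4, 4, and 4 convolutions per block, and the naming scheme introduced above. Both helper functions are hypothetical and operate on a plain list instead of the actual encoder.

```python
import itertools


def vgg19_layer_names():
    # assumption: VGG19 without batch norm, i.e. 2, 2, 4, 4, 4 convolutions per block
    names = ["preprocessing"]
    for block, num_convs in enumerate((2, 2, 4, 4, 4), start=1):
        for depth in range(1, num_convs + 1):
            names += [f"conv{block}_{depth}", f"relu{block}_{depth}"]
        names.append(f"pool{block}")
    return names


def trim(names, *layer_cfgs):
    requested = set(itertools.chain(*layer_cfgs))
    deepest_layer = max(requested, key=names.index)
    return names[: names.index(deepest_layer) + 1]


content_layers = ("relu4_2",)
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1")

names = vgg19_layer_names()
kept = trim(names, content_layers, style_layers)
removed = names[len(kept):]

print(len(names), len(kept))  # 38 31
print(removed)
# ['conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']
```

Under these assumptions only the layers after relu5_1 are dropped, since it is the deepest requested layer.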

Next up, we instantiate the ContentLoss and StyleLoss with a corresponding weight. Afterwards we store the previously calculated target encodings.

content_weight = 1e0
content_loss = ContentLoss(score_weight=content_weight)
content_loss.set_target_encs(target_content_encs)

style_weight = 1e3
style_loss = StyleLoss(score_weight=style_weight)
style_loss.set_target_encs(target_style_encs)

We start NST from the content_image since this way it converges quickly.

input_image = content_image.clone()
show_image(input_image, "Input image")


If you want to start from a white noise image instead, use

input_image = torch.rand_like(content_image)

In a last preliminary step we create the optimizer that will perform the NST. Since we want to adapt the pixels of the input_image directly, we pass it as the optimization parameter.

optimizer = optim.LBFGS([input_image.requires_grad_(True)], max_iter=1)

Finally we run the NST. The loss calculation has to happen inside a closure since the LBFGS optimizer may need to reevaluate it multiple times per optimization step. This structure is also valid for all other optimizers.

num_steps = 500

with tqdm(desc="Image optimization", total=num_steps) as progress_bar:
    for _ in range(num_steps):

        def closure():
            optimizer.zero_grad()

            input_encs = multi_layer_encoder(input_image, content_layers, style_layers)
            input_content_encs, input_style_encs = input_encs

            content_score = content_loss(input_content_encs)
            style_score = style_loss(input_style_encs)

            perceptual_loss = content_score + style_score
            perceptual_loss.backward()

            progress_bar.set_postfix(
                loss=f"{float(perceptual_loss):.3e}", refresh=False
            )
            progress_bar.update()

            return perceptual_loss

        optimizer.step(closure)

output_image = input_image.detach()
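The closure pattern can be tried on a toy problem that needs neither images nor an encoder. The sketch below minimizes (x - 3)^2 with the same optimizer settings. The objective, the number of steps, and the call counter are assumptions for illustration. How often the closure is evaluated per step depends on the optimizer configuration.

```python
import torch
from torch import optim

x = torch.tensor([0.0], requires_grad=True)
optimizer = optim.LBFGS([x], max_iter=1)

num_steps = 10
calls = {"count": 0}  # counts how often the optimizer evaluates the closure

for _ in range(num_steps):

    def closure():
        calls["count"] += 1
        optimizer.zero_grad()
        loss = ((x - 3.0) ** 2).sum()
        loss.backward()
        return loss

    optimizer.step(closure)

print(f"x = {x.item():.4f} after {calls['count']} closure evaluations")
```

Because the closure zeroes the gradients, recomputes the loss, and calls backward itself, the optimizer can reevaluate it whenever it needs a fresh loss and gradient.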

After the NST we show the resulting image.

show_image(output_image, title="Output image")


As has hopefully become clear, an NST requires quite a lot of utilities and boilerplate code even in its simplest form. This makes it hard to maintain and keep bug-free, as it is easy to lose track of everything.

Judging by the lines of code one could (falsely) conclude that the actual NST is just an appendix. If you feel the same you can stop worrying now: in Neural Style Transfer with pystiche we showcase how to achieve the same result with pystiche.
