Multi-layer Encoder

This example showcases the pystiche.enc.MultiLayerEncoder.

We start this example by importing everything we need.

import itertools
import time
from collections import OrderedDict
from math import floor, log10

import torch
from torch import nn
from torchvision import models

import pystiche
from pystiche import enc

print(f"I'm working with pystiche=={pystiche.__version__}")

In a second preliminary step we define some helper functions to ease the performance analysis later on.

SI_PREFIXES = {0: "", -3: "m", -6: "µ"}


def timeit(fn, times=10, cleanup=None):
    total = 0.0
    for _ in range(times):
        start = time.time()
        fn()
        stop = time.time()
        total += stop - start
        if cleanup:
            cleanup()
    return total / times


def feng(num, unit, digits=3):
    exp = int(floor(log10(num)))
    exp -= exp % 3
    sig = num * 10 ** -exp
    prec = digits - len(str(int(sig)))
    return f"{sig:.{prec}f} {SI_PREFIXES[exp]}{unit}"


def fsecs(seconds):
    return feng(seconds, "s")


def ftimeit(fn, msg="The execution took {seconds}.", **kwargs):
    return msg.format(seconds=fsecs(timeit(fn, **kwargs)))


def fdifftimeit(seq_fn, mle_fn, **kwargs):
    time_seq = timeit(seq_fn, **kwargs)
    time_mle = timeit(mle_fn, **kwargs)

    abs_diff = time_mle - time_seq
    rel_diff = abs_diff / time_seq

    if abs_diff >= 0:
        return (
            f"Encoding the input with the enc.MultiLayerEncoder was "
            f"{fsecs(abs_diff)} ({rel_diff:.0%}) slower."
        )
    else:
        return "\n".join(
            (
                "Due to the very rough timing method used here, ",
                "we detected a case where the encoding with the enc.MultiLayerEncoder ",
                "was actually faster than the boiler-plate nn.Sequential. ",
                "Since the enc.MultiLayerEncoder has some overhead, ",
                "this is a measuring error. ",
                "Still, this serves as indicator that the overhead is small enough, ",
                "to be well in the measuring tolerance.",
            )
        )
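To get a feeling for the engineering notation that feng produces, the following minimal sketch repeats the helper from above and formats a few hand-picked durations. The sample values are only illustrative and are not measurements.

```python
from math import floor, log10

SI_PREFIXES = {0: "", -3: "m", -6: "µ"}


def feng(num, unit, digits=3):
    # Same logic as the helper above: round the exponent down to a multiple
    # of three and keep `digits` significant digits.
    exp = int(floor(log10(num)))
    exp -= exp % 3
    sig = num * 10 ** -exp
    prec = digits - len(str(int(sig)))
    return f"{sig:.{prec}f} {SI_PREFIXES[exp]}{unit}"


# Illustrative durations in seconds (assumed values, not timings).
examples = [2.5, 0.045, 0.00123, 3e-5]
formatted = [feng(seconds, "s") for seconds in examples]
for seconds, text in zip(examples, formatted):
    print(f"{seconds!r:>8} s -> {text}")
```

For instance, 0.00123 seconds is rendered as "1.23 ms". Note that SI_PREFIXES only covers seconds, milliseconds, and microseconds, so a duration of 1000 s or more, or below 1 µs, raises a KeyError, and a non-positive value makes log10 fail.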

Next up, we define the device we will be testing on as well as the input dimensions.


We encourage the user to play with these parameters and see how the results change. In order to do that, you can use the download buttons at the bottom of this page.

device = torch.device("cpu")

batch_size = 32
num_channels = 3
height = width = 512

input = torch.rand((batch_size, num_channels, height, width), device=device)

As a toy example to showcase the MultiLayerEncoder capabilities, we will use a CNN with three layers.

conv = nn.Conv2d(num_channels, num_channels, 3, padding=1)
relu = nn.ReLU(inplace=False)
pool = nn.MaxPool2d(2)

modules = [("conv", conv), ("relu", relu), ("pool", pool)]

seq = nn.Sequential(OrderedDict(modules)).to(device)
mle = enc.MultiLayerEncoder(modules).to(device)
print(mle)

Before we dive into the additional functionalities of the MultiLayerEncoder, we perform a smoke test and assert that it indeed does the same as a torch.nn.Sequential with the same layers.

assert torch.allclose(mle(input), seq(input))
print(fdifftimeit(lambda: seq(input), lambda: mle(input)))

As we saw, the MultiLayerEncoder produces the same output as a torch.nn.Sequential but is slower. In the following, we will learn what other functionalities a MultiLayerEncoder has to offer that justify this overhead.

Intermediate feature maps

By calling the multi-layer encoder with a layer name in addition to the input, the intermediate layers of the MultiLayerEncoder can be accessed. This is helpful if one needs the feature maps from different layers of a model, as is often the case during an NST.

assert torch.allclose(mle(input, "conv"), conv(input))
assert torch.allclose(mle(input, "relu"), relu(conv(input)))
assert torch.allclose(mle(input, "pool"), pool(relu(conv(input))))
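Conceptually, requesting a named layer means running the modules in order and stopping at the requested one. The following is a minimal pure-Python sketch of that idea. TinyMultiLayerEncoder and the toy functions shift, relu, and pool are hypothetical stand-ins that operate on plain lists; they are not pystiche's implementation.

```python
from collections import OrderedDict


class TinyMultiLayerEncoder:
    """Hypothetical stand-in for illustration, not pystiche's implementation."""

    def __init__(self, modules):
        self.modules = OrderedDict(modules)

    def __call__(self, x, layer=None):
        if layer is not None and layer not in self.modules:
            raise ValueError(f"unknown layer {layer!r}")
        for name, module in self.modules.items():
            x = module(x)
            if name == layer:
                break  # stop as soon as the requested layer is reached
        return x


# Plain functions on lists play the role of the three layers.
def shift(values):
    return [v - 2 for v in values]


def relu(values):
    return [max(v, 0) for v in values]


def pool(values):
    return [max(values[i : i + 2]) for i in range(0, len(values), 2)]


toy = TinyMultiLayerEncoder([("shift", shift), ("relu", relu), ("pool", pool)])
data = [1, 4, 0, 3]

shift_out = toy(data, "shift")  # output of the first layer only
relu_out = toy(data, "relu")  # output after the first two layers
pool_out = toy(data)  # no layer name: run all layers

assert relu_out == relu(shift(data))
assert pool_out == pool(relu(shift(data)))
```

Without a layer name the sketch runs every module, which mirrors the smoke test against the sequential model above.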

For convenience, one can extract a pystiche.enc.SingleLayerEncoder as an interface to the multi-layer encoder for a specific layer.

sle = mle.extract_encoder("conv")
assert torch.allclose(sle(input), conv(input))


If access to intermediate feature maps is necessary, as is usually the case in an NST, it is important to compute every layer only once.

A MultiLayerEncoder enables this by caching already computed feature maps. Thus, once the feature map of a layer is cached for an input, retrieving it is a constant-time lookup.

In order to enable caching for a layer, it has to be registered first.


The internal cache is automatically cleared during the backward pass. Since we don’t perform one here, we need to clear it manually by calling clear_cache().


extract_encoder() automatically registers the layer for caching.

shallow_layers = ("conv", "relu")
for layer in shallow_layers:
    mle.register_layer(layer)

mle(input)

for layer in shallow_layers:
    print(
        ftimeit(
            lambda: mle(input, layer),
            (
                f"After the forward pass was completed once for the input, "
                f"extracting the encoding of the intermediate layer '{layer}' "
                f"took {{seconds}}."
            ),
        )
    )

mle.clear_cache()

Due to this caching, it doesn’t matter in which order the feature maps are requested:

  1. If a shallow layer is requested before a deeper one, the encoding is later resumed from the feature map of the shallow layer.

  2. If a deep layer is requested before a more shallow one, the feature map of the shallow one is cached while computing the deep layer.

def fn(layers):
    for layer in layers:
        mle(input, layer)


for permutation in itertools.permutations(("conv", "relu", "pool")):
    order = f"""'{"' -> '".join(permutation)}'"""
    print(
        ftimeit(
            lambda: fn(permutation),
            f"The encoding of layers {order} took {{seconds}}.",
            cleanup=mle.clear_cache,
        )
    )
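The two cases above can be made concrete by counting layer evaluations instead of measuring time. The following pure-Python sketch assumes a simple cache keyed by the input and the layer name, and resumes from the deepest cached layer. CachingEncoder, its calls counter, and the toy layers are hypothetical and only illustrate the idea; pystiche's actual caching may work differently.

```python
import itertools
from collections import OrderedDict


class CachingEncoder:
    """Hypothetical sketch of layer caching, not pystiche's implementation."""

    def __init__(self, modules):
        self.modules = OrderedDict(modules)
        self.registered = set()
        self.cache = {}
        self.calls = 0  # number of layers that were actually evaluated

    def register_layer(self, layer):
        self.registered.add(layer)

    def clear_cache(self):
        self.cache.clear()

    def __call__(self, x, layer):
        names = list(self.modules)
        stop = names.index(layer)
        # Assumption: inputs are flat lists, so a tuple copy can serve as key.
        input_key = tuple(x)
        out, start = x, 0
        # Resume from the deepest cached layer at or before the requested one.
        for idx in range(stop, -1, -1):
            key = (input_key, names[idx])
            if key in self.cache:
                out, start = self.cache[key], idx + 1
                break
        for name in names[start : stop + 1]:
            out = self.modules[name](out)
            self.calls += 1
            if name in self.registered:
                self.cache[(input_key, name)] = out
        return out


def shift(values):
    return [v - 2 for v in values]


def relu(values):
    return [max(v, 0) for v in values]


def pool(values):
    return [max(values[i : i + 2]) for i in range(0, len(values), 2)]


layers = [("shift", shift), ("relu", relu), ("pool", pool)]
encoder = CachingEncoder(layers)
for name, _ in layers:
    encoder.register_layer(name)

data = [1, 4, 0, 3]
calls_per_order = {}
for order in itertools.permutations(("shift", "relu", "pool")):
    encoder.calls = 0
    results = {name: encoder(data, name) for name in order}
    calls_per_order[order] = encoder.calls
    encoder.clear_cache()  # start every permutation with an empty cache

# Every request order evaluates each of the three layers exactly once.
assert set(calls_per_order.values()) == {3}
```

Whatever order the layers are requested in, the sketch evaluates each layer once per input, which is the property the timings above are meant to show.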

Real-world example

Up to this point we used a toy example to demonstrate the capabilities of a MultiLayerEncoder. In addition to the boiler-plate MultiLayerEncoder, pystiche has builtin implementations of some well-known CNN architectures that are commonly used in NST papers.


By default, vgg19_multi_layer_encoder() loads weights provided by torchvision. We disable this here, since we instead load the randomly initialized weights of the torchvision model to enable a comparison.


By default, vgg19_multi_layer_encoder() adds an internal_preprocessing so that the user can simply pass the image as is, without worrying about it. We disable this here to enable a comparison.


By default, vgg19_multi_layer_encoder() disallows in-place operations since after they are carried out, the previous encoding is no longer accessible. In order to enable a fair performance comparison, we allow them here, since they are also used in vgg19().


The fully connected stage of the original VGG19 architecture requires the input to be exactly 224 pixels wide and high [SZ2014]. Since this requirement can usually not be met in an NST, the builtin multi-layer encoder only comprises the size invariant convolutional stage. Thus, we only use vgg19().features to enable a comparison.

seq = models.vgg19()
mle = enc.vgg19_multi_layer_encoder(
    pretrained=False, internal_preprocessing=False, allow_inplace=True
)
mle.load_state_dict(seq.state_dict())

input = torch.rand((4, 3, 256, 256), device=device)

assert torch.allclose(mle(input), seq.features(input))
print(fdifftimeit(lambda: seq.features(input), lambda: mle(input)))
