Multi-layer Encoder

This example showcases the pystiche.enc.MultiLayerEncoder.

We start this example by importing everything we need.

import itertools
import time
from collections import OrderedDict
from math import floor, log10

import torch
from torch import nn
from torchvision import models

import pystiche
from pystiche import enc

print(f"I'm working with pystiche=={pystiche.__version__}")

Out:

I'm working with pystiche==1.1.0.dev44+gd9e3fd8

In a second preliminary step we define some helper functions to ease the performance analysis later on.

# SI prefixes for the magnitudes this example deals with (s, ms, µs)
SI_PREFIXES = {0: "", -3: "m", -6: "µ"}


def timeit(fn, times=10, cleanup=None):
    # Returns the average execution time of fn over multiple runs.
    total = 0.0
    for _ in range(times):
        start = time.time()
        fn()
        stop = time.time()
        total += stop - start
        if cleanup:
            cleanup()
    return total / times


def feng(num, unit, digits=3):
    # Formats num with the matching SI prefix and the given number of
    # significant digits, e.g. feng(0.000123, "s") == "123 µs".
    exp = int(floor(log10(num)))
    exp -= exp % 3
    sig = num * 10 ** -exp
    prec = digits - len(str(int(sig)))
    return f"{sig:.{prec}f} {SI_PREFIXES[exp]}{unit}"


def fsecs(seconds):
    return feng(seconds, "s")


def ftimeit(fn, msg="The execution took {seconds}.", **kwargs):
    return msg.format(seconds=fsecs(timeit(fn, **kwargs)))


def fdifftimeit(seq_fn, mle_fn, **kwargs):
    # Times both callables and formats their absolute and relative difference.
    time_seq = timeit(seq_fn, **kwargs)
    time_mle = timeit(mle_fn, **kwargs)

    abs_diff = time_mle - time_seq
    rel_diff = abs_diff / time_seq

    if abs_diff >= 0:
        return (
            f"Encoding the input with the enc.MultiLayerEncoder was "
            f"{fsecs(abs_diff)} ({rel_diff:.0%}) slower."
        )
    else:
        return "\n".join(
            (
                "Due to the very rough timing method used here,",
                "we detected a case where the encoding with the enc.MultiLayerEncoder",
                "was actually faster than the boiler-plate nn.Sequential.",
                "Since the enc.MultiLayerEncoder has some overhead,",
                "this is a measuring error.",
                "Still, this serves as an indicator that the overhead is small enough",
                "to be well within the measuring tolerance.",
            )
        )

Next up, we define the device we will be testing on as well as the input dimensions.

Note

We encourage the user to play with these parameters and see how the results change. In order to do that, you can use the download buttons at the bottom of this page.

device = torch.device("cpu")

batch_size = 32
num_channels = 3
height = width = 512

input = torch.rand((batch_size, num_channels, height, width), device=device)

As a toy example to showcase the MultiLayerEncoder capabilities, we will use a CNN with three layers.

conv = nn.Conv2d(num_channels, num_channels, 3, padding=1)
relu = nn.ReLU(inplace=False)
pool = nn.MaxPool2d(2)

modules = [("conv", conv), ("relu", relu), ("pool", pool)]

seq = nn.Sequential(OrderedDict(modules)).to(device)
mle = enc.MultiLayerEncoder(modules).to(device)
print(mle)

Out:

MultiLayerEncoder(
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

Before we dive into the additional functionalities of the MultiLayerEncoder, we perform a smoke test and assert that it indeed does the same as a torch.nn.Sequential with the same layers.

assert torch.allclose(mle(input), seq(input))
print(fdifftimeit(lambda: seq(input), lambda: mle(input)))

Out:

Due to the very rough timing method used here,
we detected a case where the encoding with the enc.MultiLayerEncoder
was actually faster than the boiler-plate nn.Sequential.
Since the enc.MultiLayerEncoder has some overhead,
this is a measuring error.
Still, this serves as an indicator that the overhead is small enough
to be well within the measuring tolerance.

As we saw, the MultiLayerEncoder produces the same output as a torch.nn.Sequential, but carries some overhead. In the following we will learn what other functionalities a MultiLayerEncoder has to offer that justify this overhead.

Intermediate feature maps

By calling the multi-layer encoder with a layer name in addition to the input, the encodings of intermediate layers can be accessed. This is helpful if one needs the feature maps from different layers of a model, as is often the case during an NST.

assert torch.allclose(mle(input, "conv"), conv(input))
assert torch.allclose(mle(input, "relu"), relu(conv(input)))
assert torch.allclose(mle(input, "pool"), pool(relu(conv(input))))

For convenience, one can extract a pystiche.enc.SingleLayerEncoder as an interface to the multi-layer encoder for a specific layer.

sle = mle.extract_encoder("conv")
assert torch.allclose(sle(input), conv(input))

Caching

If access to intermediate feature maps is necessary, as is usually the case in an NST, it is important to compute every layer only once.

A MultiLayerEncoder enables this by caching already computed feature maps. Thus, after an encoding is cached, retrieving it is a constant-time lookup.

In order to enable caching for a layer, it has to be registered first.

Note

The internal cache will be automatically cleared during the backward pass. Since we don’t perform that here, we need to clear it manually by calling clear_cache().

Note

extract_encoder() automatically registers the layer for caching.
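Since we already called extract_encoder("conv") above, that layer should be registered at this point. As a quick check (this sketch assumes that registration is tracked via the registered_layers attribute):

# "conv" was registered as a side effect of mle.extract_encoder("conv") above.
print(mle.registered_layers)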

shallow_layers = ("conv", "relu")
for layer in shallow_layers:
    mle.register_layer(layer)

mle(input)

for layer in shallow_layers:
    print(
        ftimeit(
            lambda: mle(input, layer),
            (
                f"After the forward pass was completed once for the input, "
                f"extracting the encoding of the intermediate layer '{layer}' "
                f"took {{seconds}}."
            ),
        )
    )

mle.clear_cache()

Out:

After the forward pass was completed once for the input, extracting the encoding of the intermediate layer 'conv' took 6.70 µs.
After the forward pass was completed once for the input, extracting the encoding of the intermediate layer 'relu' took 3.36 µs.

Due to this caching, it doesn’t matter in which order the feature maps are requested:

  1. If a shallow layer is requested before a deeper one, the encoding is later resumed from the feature map of the shallow layer.

  2. If a deep layer is requested before a more shallow one, the feature map of the shallow one is cached while computing the deep layer.

def fn(layers):
    for layer in layers:
        mle(input, layer)


for permutation in itertools.permutations(("conv", "relu", "pool")):
    order = f"""'{"' -> '".join(permutation)}'"""
    print(
        ftimeit(
            lambda: fn(permutation),
            f"The encoding of layers {order} took {{seconds}}.",
            cleanup=mle.clear_cache,
        )
    )

Out:

The encoding of layers 'conv' -> 'relu' -> 'pool' took 360 ms.
The encoding of layers 'conv' -> 'pool' -> 'relu' took 327 ms.
The encoding of layers 'relu' -> 'conv' -> 'pool' took 360 ms.
The encoding of layers 'relu' -> 'pool' -> 'conv' took 360 ms.
The encoding of layers 'pool' -> 'conv' -> 'relu' took 328 ms.
The encoding of layers 'pool' -> 'relu' -> 'conv' took 327 ms.
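As mentioned in the note above, the internal cache is also cleared automatically during the backward pass, so a real NST optimization loop needs no manual clear_cache() call. A minimal sketch of that pattern:

# The toy CNN has trainable parameters, so the encoding carries a grad_fn
# even though the input itself does not require gradients.
encoding = mle(input, "pool")
loss = encoding.sum()  # stand-in for a real NST loss
loss.backward()  # clears the internal cache as a side effect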

Real-world example

Up to this point we used a toy example to demonstrate the capabilities of a MultiLayerEncoder. In addition to the generic MultiLayerEncoder, pystiche has builtin implementations of some well-known CNN architectures that are commonly used in NST papers.

Note

By default, vgg19_multi_layer_encoder() loads weights provided by torchvision. We disable this here since we load the randomly initialized weights of the torchvision model to enable a comparison.

Note

By default, vgg19_multi_layer_encoder() adds internal preprocessing so that the user can simply pass the image as is, without worrying about it. We disable this here to enable a comparison.

Note

By default, vgg19_multi_layer_encoder() disallows in-place operations since after they are carried out, the previous encoding is no longer accessible. In order to enable a fair performance comparison, we allow them here, since they are also used in vgg19().

Note

The fully connected stage of the original VGG19 architecture requires the input to be exactly 224 pixels wide and high [SZ14]. Since this requirement can usually not be met in an NST, the builtin multi-layer encoder only comprises the size-invariant convolutional stage. Thus, we only use vgg19().features to enable a comparison.

seq = models.vgg19()
mle = enc.vgg19_multi_layer_encoder(
    pretrained=False, internal_preprocessing=False, allow_inplace=True
)
mle.load_state_dict(seq.state_dict())

input = torch.rand((4, 3, 256, 256), device=device)

assert torch.allclose(mle(input), seq.features(input))
print(fdifftimeit(lambda: seq.features(input), lambda: mle(input)))

Out:

Due to the very rough timing method used here,
we detected a case where the encoding with the enc.MultiLayerEncoder
was actually faster than the boiler-plate nn.Sequential.
Since the enc.MultiLayerEncoder has some overhead,
this is a measuring error.
Still, this serves as an indicator that the overhead is small enough
to be well within the measuring tolerance.
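In an actual NST, encodings are extracted from this multi-layer encoder by layer name, just like in the toy example above. A brief sketch; the layer names follow pystiche's VGG naming scheme (e.g. 'relu4_2'), and the particular selection here is merely illustrative:

# Content and style layers as commonly used in NST papers (illustrative choice).
content_layer = "relu4_2"
style_layers = ("relu1_1", "relu2_1", "relu3_1", "relu4_1")

content_encoding = mle(input, content_layer)
style_encodings = [mle(input, layer) for layer in style_layers]
mle.clear_cache()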
