Multi-layer Encoder
This example showcases the pystiche.enc.MultiLayerEncoder.
We start this example by importing everything we need.
```python
import itertools
import time
from collections import OrderedDict
from math import floor, log10

import torch
from torch import nn
from torchvision import models

import pystiche
from pystiche import enc

print(f"I'm working with pystiche=={pystiche.__version__}")
```
In a second preliminary step we define some helper functions to ease the performance analysis later on.
```python
SI_PREFIXES = {0: "", -3: "m", -6: "µ", -9: "n"}


def timeit(fn, times=10, cleanup=None):
    # Average the wall-clock time of fn() over `times` runs. cleanup() runs
    # after every repetition, outside of the timed region.
    total = 0.0
    for _ in range(times):
        start = time.perf_counter()
        fn()
        stop = time.perf_counter()
        total += stop - start
        if cleanup:
            cleanup()
    return total / times


def feng(num, unit, digits=3):
    # Format a positive number in engineering notation with `digits`
    # significant digits, e.g. feng(0.0123, "s") -> "12.3 ms".
    exp = int(floor(log10(num)))
    exp -= exp % 3
    sig = num * 10 ** -exp
    prec = digits - len(str(int(sig)))
    return f"{sig:.{prec}f} {SI_PREFIXES[exp]}{unit}"


def fsecs(seconds):
    return feng(seconds, "s")


def ftimeit(fn, msg="The execution took {seconds}.", **kwargs):
    return msg.format(seconds=fsecs(timeit(fn, **kwargs)))


def fdifftimeit(seq_fn, mle_fn, **kwargs):
    time_seq = timeit(seq_fn, **kwargs)
    time_mle = timeit(mle_fn, **kwargs)

    abs_diff = time_mle - time_seq
    rel_diff = abs_diff / time_seq

    if abs_diff >= 0:
        return (
            f"Encoding the input with the enc.MultiLayerEncoder was "
            f"{fsecs(abs_diff)} ({rel_diff:.0%}) slower."
        )
    else:
        return "\n".join(
            (
                "Due to the very rough timing method used here,",
                "we detected a case where the encoding with the enc.MultiLayerEncoder",
                "was actually faster than with the plain nn.Sequential.",
                "Since the enc.MultiLayerEncoder has some overhead,",
                "this is a measurement error.",
                "Still, this indicates that the overhead is small enough",
                "to be well within the measurement tolerance.",
            )
        )
```
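The exponent handling in `feng` relies on Python's `%` returning a non-negative remainder for a positive divisor. Because of that, negative exponents are also rounded down to a multiple of three. Tracing `feng(0.0123, "s")` by hand:

```latex
\begin{aligned}
\text{exp}  &= \lfloor \log_{10} 0.0123 \rfloor = \lfloor -1.91 \rfloor = -2 \\
\text{exp}  &\leftarrow -2 - (-2 \bmod 3) = -2 - 1 = -3 \\
\text{sig}  &= 0.0123 \cdot 10^{3} = 12.3 \\
\text{prec} &= 3 - \operatorname{len}(\texttt{"12"}) = 1
\end{aligned}
```

The result is therefore `12.3 ms`, since `SI_PREFIXES` maps `-3` to the prefix `m`.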
Next up, we define the device we will be testing on as well as the input dimensions.
Note
We encourage you to play with these parameters and see how the results change. To do so, download the full example script and run it locally.
```python
device = torch.device("cpu")

batch_size = 32
num_channels = 3
height = width = 512

input = torch.rand((batch_size, num_channels, height, width), device=device)
```
To showcase the capabilities of the MultiLayerEncoder, we use a toy CNN with three layers.
```python
conv = nn.Conv2d(num_channels, num_channels, 3, padding=1)
relu = nn.ReLU(inplace=False)
pool = nn.MaxPool2d(2)

modules = [("conv", conv), ("relu", relu), ("pool", pool)]

seq = nn.Sequential(OrderedDict(modules)).to(device)
mle = enc.MultiLayerEncoder(modules).to(device)
print(mle)
```
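To see what the toy CNN does to the spatial size of the input, you can evaluate the usual output-size formula for convolution and pooling layers by hand. `conv_out` is a hypothetical helper that assumes a dilation of 1. It is not part of PyTorch or pystiche.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a convolution or pooling layer with dilation 1:
    # floor((size + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


size = 512  # height and width of the toy input
after_conv = conv_out(size, kernel=3, padding=1)  # nn.Conv2d(..., 3, padding=1)
after_relu = after_conv  # ReLU is elementwise and keeps the shape
after_pool = conv_out(after_relu, kernel=2, stride=2)  # nn.MaxPool2d(2): stride defaults to 2
print(after_conv, after_pool)  # 512 256
```

The convolution keeps the 512 x 512 resolution thanks to `padding=1`, and only the pooling layer halves it.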
Before we dive into the additional functionalities of the MultiLayerEncoder, we perform a smoke test and assert that it indeed does the same as a torch.nn.Sequential with the same layers.
```python
assert torch.allclose(mle(input), seq(input))
print(fdifftimeit(lambda: seq(input), lambda: mle(input)))
```
As we saw, the MultiLayerEncoder produces the same output as a torch.nn.Sequential, but it is typically slower. In the following, we will learn which other functionalities a MultiLayerEncoder has to offer that justify this overhead.
Intermediate feature maps
By calling the multi-layer encoder with a layer name in addition to the input, you can access the feature maps of the intermediate layers of the MultiLayerEncoder.
This is helpful if one needs the feature maps from different layers of a model, as is often the case during a neural style transfer (NST).
```python
assert torch.allclose(mle(input, "conv"), conv(input))
assert torch.allclose(mle(input, "relu"), relu(conv(input)))
assert torch.allclose(mle(input, "pool"), pool(relu(conv(input))))
```
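Conceptually, requesting a layer means applying the layers in order and stopping after the named one. The following pure-Python sketch illustrates this idea with plain functions on numbers instead of modules on tensors. `encode_until` and the toy layers are hypothetical and make no claim about how pystiche implements this internally.

```python
def encode_until(layers, x, layer=None):
    # Hypothetical sketch: apply (name, fn) pairs in order and stop after `layer`.
    # With layer=None, all layers are applied, like a plain sequential model.
    names = [name for name, _ in layers]
    if layer is not None and layer not in names:
        raise ValueError(f"unknown layer {layer!r}")
    for name, fn in layers:
        x = fn(x)
        if name == layer:
            break
    return x


def double(v):
    return 2 * v


def shift(v):
    return v - 3


def square(v):
    return v * v


toy_layers = [("double", double), ("shift", shift), ("square", square)]
print(encode_until(toy_layers, 5, "double"))  # 10
print(encode_until(toy_layers, 5, "shift"))  # 7
print(encode_until(toy_layers, 5))  # 49
```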
For convenience, one can extract a pystiche.enc.SingleLayerEncoder as an interface to the multi-layer encoder for a specific layer.
```python
sle = mle.extract_encoder("conv")
assert torch.allclose(sle(input), conv(input))
```
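In spirit, a single-layer encoder binds the layer name once, so callers only pass the input. A minimal sketch of that idea uses `functools.partial` with a hypothetical `encode_until` helper working on plain numbers. pystiche's SingleLayerEncoder is its own class rather than a partial function, so this only illustrates the interface.

```python
from functools import partial


def encode_until(layers, x, layer):
    # Hypothetical helper: apply (name, fn) pairs in order and stop after `layer`.
    for name, fn in layers:
        x = fn(x)
        if name == layer:
            return x
    raise ValueError(f"unknown layer {layer!r}")


toy_layers = [("double", lambda v: 2 * v), ("shift", lambda v: v - 3)]

# Bind the layers and the layer name once; callers then only pass the input.
encode_double = partial(encode_until, toy_layers, layer="double")
print(encode_double(5))  # 10
```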
Caching
If access to intermediate feature maps is necessary, as is usually the case in an NST, it is important to compute every layer only once.
A MultiLayerEncoder enables this by caching already computed feature maps. Thus, once the encoding of an input is cached, retrieving it is a constant-time lookup.
To enable caching for a layer, it has to be registered first.
Note
The internal cache is automatically cleared during the backward pass. Since we don’t perform one here, we need to clear it manually by calling clear_cache().
Note
extract_encoder() automatically registers the layer for caching.
```python
shallow_layers = ("conv", "relu")
for layer in shallow_layers:
    mle.register_layer(layer)

# A complete forward pass for the input stores the feature maps of the registered layers.
mle(input)

for layer in shallow_layers:
    print(
        ftimeit(
            lambda: mle(input, layer),
            (
                f"After the forward pass was completed once for the input, "
                f"extracting the encoding of the intermediate layer '{layer}' "
                f"took {{seconds}}."
            ),
        )
    )

mle.clear_cache()
```
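You can mimic the effect of registering a layer in pure Python by counting how many layer functions actually run. In this sketch, `CachingToyEncoder` is hypothetical and the "feature maps" are plain numbers. The cache is a dict keyed by `(input, layer)`, which assumes hashable inputs. pystiche's own cache keys and invalidation are not modeled.

```python
class CachingToyEncoder:
    """Hypothetical sketch: caches the outputs of registered layers per input."""

    def __init__(self, layers):
        self.layers = list(layers)  # (name, fn) pairs applied in order
        self.registered = set()
        self.cache = {}  # maps (input, layer name) -> feature map
        self.num_calls = 0  # how many layer functions were actually evaluated

    def register_layer(self, name):
        self.registered.add(name)

    def clear_cache(self):
        self.cache.clear()

    def __call__(self, x, layer):
        if (x, layer) in self.cache:
            return self.cache[(x, layer)]  # constant-time dictionary lookup
        y = x
        for name, fn in self.layers:
            y = fn(y)
            self.num_calls += 1
            if name in self.registered:
                self.cache[(x, name)] = y
            if name == layer:
                return y
        raise ValueError(f"unknown layer {layer!r}")


toy = CachingToyEncoder(
    [("double", lambda v: 2 * v), ("shift", lambda v: v - 3), ("square", lambda v: v * v)]
)
toy.register_layer("double")

toy(5, "square")  # full pass: 3 evaluations, "double" is cached on the way
toy(5, "double")  # cache hit: no evaluations
toy(5, "shift")  # not registered, so recomputed from the input: 2 evaluations
print(toy.num_calls)  # 5
```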
Due to this caching, it doesn’t matter in which order the feature maps of registered layers are requested:
If a shallow layer is requested before a deeper one, the encoding of the deeper layer is later resumed from the cached feature map of the shallow layer.
If a deep layer is requested before a shallower one, the feature map of the shallower layer is cached while computing the deep layer.
```python
def fn(layers):
    for layer in layers:
        mle(input, layer)


for permutation in itertools.permutations(("conv", "relu", "pool")):
    # For example: 'conv' -> 'relu' -> 'pool'
    order = " -> ".join(f"'{layer}'" for layer in permutation)
    print(
        ftimeit(
            lambda: fn(permutation),
            f"The encoding of layers {order} took {{seconds}}.",
            cleanup=mle.clear_cache,
        )
    )
```
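The order independence can be checked with a small pure-Python sketch. Before computing a layer, it looks for the deepest cached feature map in front of that layer and resumes from there. As in the example above, the first two toy layers are registered and the last one is not. `ResumingToyEncoder` is hypothetical and works on plain, hashable numbers. It illustrates the described behavior and is not pystiche's implementation.

```python
import itertools


class ResumingToyEncoder:
    """Hypothetical sketch: resume from the deepest cached layer before the target."""

    def __init__(self, layers, registered):
        self.layers = dict(layers)  # insertion order = layer order
        self.names = list(self.layers)
        self.registered = set(registered)
        self.cache = {}  # maps (input, layer name) -> feature map
        self.num_calls = 0  # how many layer functions were actually evaluated

    def clear_cache(self):
        self.cache.clear()

    def __call__(self, x, layer):
        target = self.names.index(layer)  # raises ValueError for unknown layers
        if (x, layer) in self.cache:
            return self.cache[(x, layer)]
        # Start from the deepest cached feature map before the target, if any.
        start, y = 0, x
        for idx in range(target - 1, -1, -1):
            if (x, self.names[idx]) in self.cache:
                start, y = idx + 1, self.cache[(x, self.names[idx])]
                break
        # Evaluate only the remaining layers and cache registered ones on the way.
        for name in self.names[start : target + 1]:
            y = self.layers[name](y)
            self.num_calls += 1
            if name in self.registered:
                self.cache[(x, name)] = y
        return y


toy = ResumingToyEncoder(
    [("double", lambda v: 2 * v), ("shift", lambda v: v - 3), ("square", lambda v: v * v)],
    registered=("double", "shift"),
)

calls_per_order = {}
for layer_order in itertools.permutations(("double", "shift", "square")):
    toy.clear_cache()
    toy.num_calls = 0
    for name in layer_order:
        toy(5, name)
    calls_per_order[layer_order] = toy.num_calls

# Every order evaluates each of the three layers exactly once.
print(set(calls_per_order.values()))  # {3}
```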
Real-world example
Up to this point, we used a toy example to demonstrate the capabilities of a MultiLayerEncoder. In addition to the generic MultiLayerEncoder, pystiche has built-in implementations of some well-known CNN architectures that are commonly used in NST papers.
Note
By default, vgg19_multi_layer_encoder() loads weights provided by torchvision. We disable this here, since we load the randomly initialized weights of the torchvision model to enable a comparison.
Note
By default, vgg19_multi_layer_encoder() adds an internal preprocessing step (internal_preprocessing), so that the user can simply pass the image as is without preprocessing it manually. We disable this here to enable a comparison.
Note
By default, vgg19_multi_layer_encoder() disallows in-place operations, since after they are carried out, the previous encoding is no longer accessible. To enable a fair performance comparison, we allow them here, since they are also used in vgg19().
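You can see without tensors why in-place operations and cached encodings do not mix. If a layer overwrites its input, any reference kept to that input, such as a cached feature map, silently changes as well. This is a pure-Python analogy in which a list stands in for a tensor and `relu_inplace` is a hypothetical stand-in for `nn.ReLU(inplace=True)`.

```python
def relu_inplace(values):
    # Hypothetical in-place "ReLU" on a list: overwrites its argument.
    for i, v in enumerate(values):
        values[i] = max(v, 0.0)
    return values


conv_output = [-1.0, 2.0, -3.0]
saved = conv_output  # e.g. a cached encoding of the previous layer
relu_output = relu_inplace(conv_output)

print(saved)  # [0.0, 2.0, 0.0]: the pre-ReLU values are gone
print(saved is relu_output)  # True
```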
Note
The fully connected stage of the original VGG19 architecture requires the input to be exactly 224 pixels wide and high [SZ2014]. Since this requirement usually cannot be met in an NST, the built-in multi-layer encoder only comprises the size-invariant convolutional stage. Thus, we only use vgg19().features to enable a comparison.
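The size constraint comes from the arithmetic of the convolutional stage. In VGG19, the 3x3 convolutions use a padding of 1 and keep the spatial size, while each of the five 2x2 max-pooling layers with stride 2 halves it. The first fully connected layer therefore expects 512 x 7 x 7 = 25088 input features, which a 224 x 224 input produces. A short sketch compares this with the 256 x 256 input used below. `spatial_size_after_pools` is a hypothetical helper.

```python
def spatial_size_after_pools(size, num_pools=5):
    # VGG19: padded 3x3 convolutions keep the size, each 2x2 max pool
    # with stride 2 halves it (rounding down).
    for _ in range(num_pools):
        size //= 2
    return size


num_channels_last = 512  # channels after the final convolutional block
expected_fc_inputs = 512 * 7 * 7  # input features of the first fully connected layer

for image_size in (224, 256):
    side = spatial_size_after_pools(image_size)
    flattened = num_channels_last * side * side
    print(image_size, side, flattened, flattened == expected_fc_inputs)
# 224 7 25088 True
# 256 8 32768 False
```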
```python
seq = models.vgg19().to(device)
mle = enc.vgg19_multi_layer_encoder(
    pretrained=False, internal_preprocessing=False, allow_inplace=True
)
# Copy the randomly initialized torchvision weights into the multi-layer encoder
# and move it to the same device as the torchvision model.
mle.load_state_dict(seq.state_dict())
mle = mle.to(device)

input = torch.rand((4, 3, 256, 256), device=device)

assert torch.allclose(mle(input), seq.features(input))
print(fdifftimeit(lambda: seq.features(input), lambda: mle(input)))
```