Model optimization
This example showcases how an NST based on model optimization can be performed in pystiche. It closely follows the official PyTorch example, which in turn is based on [JAL2016].
We start this example by importing everything we need and setting the device we will be working on.
import contextlib
import os
import time
from collections import OrderedDict
from os import path

import torch
from torch import hub, nn
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torchvision import transforms

import pystiche
from pystiche import demo, enc, loss, optim
from pystiche.data import ImageFolderDataset
from pystiche.image import show_image
from pystiche.misc import get_device

print(f"I'm working with pystiche=={pystiche.__version__}")

device = get_device()
print(f"I'm working with {device}")
Transformer
In contrast to image optimization, for model optimization we need to define a transformer that, after it is trained, performs the stylization. In general, different architectures are possible ([JAL2016], [ULVL2016]). For this example we use an encoder-decoder architecture.
Before we define the transformer, we create some helper modules to reduce the clutter.
In the decoder we need to upsample the image. While it is possible to achieve this with a ConvTranspose2d, it was found that traditional upsampling followed by a standard convolution produces fewer artifacts. Thus, we create a module that wraps torch.nn.functional.interpolate().
class Interpolate(nn.Module):
    def __init__(self, scale_factor=1.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)

    def extra_repr(self):
        extras = []
        if self.scale_factor:
            extras.append(f"scale_factor={self.scale_factor}")
        if self.mode != "nearest":
            extras.append(f"mode={self.mode}")
        return ", ".join(extras)
For the transformer architecture we will be using, we need to define a convolution module with some additional capabilities. In particular, it needs to be able to

- optionally upsample the input,
- pad the input in order for the convolution to be size-preserving,
- optionally normalize the output, and
- optionally pass the output through an activation function.
Note
Instead of BatchNorm2d we use InstanceNorm2d to normalize the output since it gives better results for NST [UVL2016].
class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        upsample=False,
        norm=True,
        activation=True,
    ):
        super().__init__()
        self.upsample = Interpolate(scale_factor=stride) if upsample else None
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
        )
        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
        self.activation = nn.ReLU() if activation else None

    def forward(self, input):
        if self.upsample:
            input = self.upsample(input)

        output = self.conv(self.pad(input))

        if self.norm:
            output = self.norm(output)
        if self.activation:
            output = self.activation(output)

        return output
It is common practice to append a few residual blocks to the encoder after the initial convolutions to enable it to learn more descriptive features.
class Residual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv(channels, channels, kernel_size=3)
        self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)

    def forward(self, input):
        output = self.conv2(self.conv1(input))
        return output + input
It can be useful for the training to transform the input into another value range, for example from \([0, 1]\) to \([0, 255]\).
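The Transformer defined below uses two small modules, FloatToUint8Range and Uint8ToFloatRange, for exactly this purpose. Their definitions are not shown above; a minimal sketch, assuming they simply scale the values by 255, could look as follows:

class FloatToUint8Range(nn.Module):
    def forward(self, input):
        # map values from [0, 1] to [0, 255]
        return input * 255.0


class Uint8ToFloatRange(nn.Module):
    def forward(self, input):
        # map values from [0, 255] back to [0, 1]
        return input / 255.0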
Finally, we can put all pieces together.
Note
You can access this transformer through pystiche.demo.transformer().
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            Conv(3, 32, kernel_size=9),
            Conv(32, 64, kernel_size=3, stride=2),
            Conv(64, 128, kernel_size=3, stride=2),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
        )
        self.decoder = nn.Sequential(
            Conv(128, 64, kernel_size=3, stride=2, upsample=True),
            Conv(64, 32, kernel_size=3, stride=2, upsample=True),
            Conv(32, 3, kernel_size=9, norm=False, activation=False),
        )

        self.preprocessor = FloatToUint8Range()
        self.postprocessor = Uint8ToFloatRange()

    def forward(self, input):
        input = self.preprocessor(input)
        output = self.decoder(self.encoder(input))
        return self.postprocessor(output)


transformer = Transformer().to(device)
print(transformer)
Perceptual loss
Although model optimization is a different paradigm, the perceptual_loss is the same as for image optimization.
Note
In some implementations, such as the PyTorch example and [JAL2016], one can observe that the gram_matrix(), used as style representation, is not only normalized by the height and width of the feature map, but also by the number of channels. If used together with a mse_loss(), the normalization is performed twice. While this is unintended, it affects the training. In order to keep the other hyperparameters on par with the PyTorch example, we also adopt this change here.
multi_layer_encoder = enc.vgg16_multi_layer_encoder()

content_layer = "relu2_2"
content_encoder = multi_layer_encoder.extract_encoder(content_layer)
content_weight = 1e5
content_loss = loss.FeatureReconstructionLoss(
    content_encoder, score_weight=content_weight
)


class GramOperator(loss.GramLoss):
    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
        repr = super().enc_to_repr(enc)
        num_channels = repr.size()[1]
        return repr / num_channels


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder,
    style_layers,
    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=style_weight,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
perceptual_loss = perceptual_loss.to(device)
print(perceptual_loss)
Training
As a first step, we load the style image that will be used to train the transformer.
images = demo.images()
size = 500

style_image = images["paint"].read(size=size, device=device)
show_image(style_image)
The training of the transformer is performed similarly to other models in PyTorch. In every optimization step a batch of content images is drawn from a dataset, which serves as input for the transformer as well as content_image for the perceptual_loss. While the style_image only has to be set once, the content_image has to be reset in every iteration step.

While this can be done with a boilerplate optimization loop, pystiche provides multi_epoch_model_optimization() that handles the above for you.
Note
If the perceptual_loss is a PerceptualLoss, as is the case here, the update of the content_image is performed automatically. If that is not the case or you need more complex update behavior, you need to specify a criterion_update_fn.
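For illustration, a minimal sketch of such a criterion_update_fn, assuming it is called with the current batch of content images and the criterion, and merely replicating what PerceptualLoss does automatically:

def criterion_update_fn(input_image, criterion):
    # reset the content image before every optimization step
    criterion.set_content_image(input_image)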
Note
If you do not specify an optimizer, the default_model_optimizer(), i.e. Adam, is used.
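If you want to deviate from the default, a custom optimizer can be passed explicitly. A minimal sketch, assuming multi_epoch_model_optimization() accepts an optimizer keyword as the note above implies; the learning rate is an arbitrary example value and image_loader refers to a DataLoader like the one constructed in train() below:

# hypothetical: pass a custom optimizer instead of default_model_optimizer()
custom_optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-3)
optim.multi_epoch_model_optimization(
    image_loader,
    transformer.train(),
    perceptual_loss,
    optimizer=custom_optimizer,
    epochs=2,
)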
def train(*, transformer, root, batch_size, epochs, image_size):
    if root is None:
        raise RuntimeError("You forgot to define a root image directory.")

    transform = nn.Sequential(
        transforms.Resize(image_size), transforms.CenterCrop(image_size),
    )
    dataset = ImageFolderDataset(root, transform=transform)
    image_loader = DataLoader(dataset, batch_size=batch_size)

    perceptual_loss.set_style_image(style_image)

    return optim.multi_epoch_model_optimization(
        image_loader, transformer.train(), perceptual_loss, epochs=epochs,
    )
Depending on the dataset and your setup, the training can take a couple of hours. To avoid this, we provide transformer weights that were trained with the scheme above.
def download():
    # Unfortunately, torch.hub.load_state_dict_from_url has no option to disable
    # printing the downloading process. Since this would clutter the output, we
    # suppress it completely.
    @contextlib.contextmanager
    def suppress_output():
        with open(os.devnull, "w") as devnull:
            with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
                devnull
            ):
                yield

    with suppress_output():
        return hub.load_state_dict_from_url(
            "https://download.pystiche.org/models/example_transformer.pth"
        )
Note
The weights of the provided transformer were trained with the 2014 training images of the COCO dataset. The training was performed for epochs=2 and batch_size=4. Each image was center-cropped to 256 x 256 pixels.
Note
If you want to perform the training yourself, set root to the location of a folder of images.
root = None
checkpoint = "example_transformer.pth"

if root is None:
    state_dict = torch.load(checkpoint) if path.exists(checkpoint) else download()
    transformer.load_state_dict(state_dict)
else:
    transformer = train(
        transformer=transformer, root=root, batch_size=4, epochs=2, image_size=256,
    )
    state_dict = OrderedDict(
        [
            (name, parameter.detach().cpu())
            for name, parameter in transformer.state_dict().items()
        ]
    )
    torch.save(state_dict, checkpoint)
Neural Style Transfer
In order to perform the NST, we load an image we want to stylize.
input_image = images["bird1"].read(size=size, device=device)
show_image(input_image)
After the transformer is trained, we can perform an NST with a single forward pass. To do this, the transformer is simply called with the input_image.
transformer.eval()

start = time.time()

with torch.no_grad():
    output_image = transformer(input_image)

stop = time.time()

show_image(output_image, title="Output image")
Compared to NST via image optimization, the stylization is performed multiple orders of magnitude faster. Given capable hardware, NST via model optimization enables real-time stylization, for example of a video feed.
print(f"The stylization took {(stop - start) * 1e3:.0f} milliseconds.")
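As a rough sketch of such a real-time use case, assuming frames is any iterable of image tensors that are already on the correct device, each frame can be stylized with one forward pass:

# hypothetical: stylize a stream of frames one by one
transformer.eval()
with torch.no_grad():
    for frame in frames:
        stylized_frame = transformer(frame)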