Model optimization

This example showcases how an NST based on model optimization can be performed in pystiche. It closely follows the official PyTorch example, which in turn is based on [JAL2016].

We start this example by importing everything we need and setting the device we will be working on.

import contextlib
import os
import time
from collections import OrderedDict
from os import path

import torch
from torch import hub, nn
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torchvision import transforms

import pystiche
from pystiche import demo, enc, loss, optim
from pystiche.data import ImageFolderDataset
from pystiche.image import show_image
from pystiche.misc import get_device

print(f"I'm working with pystiche=={pystiche.__version__}")

device = get_device()
print(f"I'm working with {device}")

Transformer

In contrast to image optimization, for model optimization we need to define a transformer that, after it is trained, performs the stylization. In general, different architectures are possible ([JAL2016], [ULVL2016]). For this example we use an encoder-decoder architecture.

Before we define the transformer, we create some helper modules to reduce the clutter.

In the decoder we need to upsample the image. While it is possible to achieve this with a ConvTranspose2d, it was found that traditional upsampling followed by a standard convolution produces fewer artifacts. Thus, we create a module that wraps torch.nn.functional.interpolate().

class Interpolate(nn.Module):
    def __init__(self, scale_factor=1.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)

    def extra_repr(self):
        extras = []
        if self.scale_factor:
            extras.append(f"scale_factor={self.scale_factor}")
        if self.mode != "nearest":
            extras.append(f"mode={self.mode}")
        return ", ".join(extras)
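A quick sanity check of this helper (not part of the original example): with scale_factor=2.0 the spatial size of the input should double.

# Hypothetical sanity check: upsampling a 1x3x32x32 tensor by a factor of 2
# should yield a 1x3x64x64 tensor.
upsample = Interpolate(scale_factor=2.0)
print(upsample(torch.rand(1, 3, 32, 32)).size())  # torch.Size([1, 3, 64, 64])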

For the transformer architecture we will be using, we need to define a convolution module with some additional capabilities. In particular, it needs to be able to

  • optionally upsample the input,

  • pad the input in order for the convolution to be size-preserving,

  • optionally normalize the output, and

  • optionally pass the output through an activation function.

Note

Instead of BatchNorm2d we use InstanceNorm2d to normalize the output since it gives better results for NST [UVL2016].

class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        upsample=False,
        norm=True,
        activation=True,
    ):
        super().__init__()
        self.upsample = Interpolate(scale_factor=stride) if upsample else None
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
        )
        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
        self.activation = nn.ReLU() if activation else None

    def forward(self, input):
        if self.upsample:
            input = self.upsample(input)

        output = self.conv(self.pad(input))

        if self.norm:
            output = self.norm(output)
        if self.activation:
            output = self.activation(output)

        return output
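As a quick check of the size-preserving padding and the optional upsampling described above, one can verify the output shapes on a random input (a sketch, not part of the original example):

# Sketch: verify the shape behavior of Conv on a random input.
x = torch.rand(1, 3, 64, 64)
# Size-preserving: reflection padding of kernel_size // 2 keeps 64x64.
print(Conv(3, 8, kernel_size=9)(x).size())  # torch.Size([1, 8, 64, 64])
# stride=2 without upsampling halves the spatial size ...
print(Conv(3, 8, kernel_size=3, stride=2)(x).size())  # torch.Size([1, 8, 32, 32])
# ... while stride=2 with upsample=True doubles it instead.
print(Conv(3, 8, kernel_size=3, stride=2, upsample=True)(x).size())  # torch.Size([1, 8, 128, 128])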

It is common practice to append a few residual blocks to the encoder after the initial convolutions to enable it to learn more descriptive features.

class Residual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv(channels, channels, kernel_size=3)
        self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)

    def forward(self, input):
        output = self.conv2(self.conv1(input))
        return output + input

It can be useful for the training to transform the input into another value range, for example from \([0, 1]\) to \([0, 255]\).

class FloatToUint8Range(nn.Module):
    def forward(self, input):
        return input * 255.0


class Uint8ToFloatRange(nn.Module):
    def forward(self, input):
        return input / 255.0

Finally, we can put all the pieces together.

Note

You can access this transformer through pystiche.demo.transformer().

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            Conv(3, 32, kernel_size=9),
            Conv(32, 64, kernel_size=3, stride=2),
            Conv(64, 128, kernel_size=3, stride=2),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
        )
        self.decoder = nn.Sequential(
            Conv(128, 64, kernel_size=3, stride=2, upsample=True),
            Conv(64, 32, kernel_size=3, stride=2, upsample=True),
            Conv(32, 3, kernel_size=9, norm=False, activation=False),
        )

        self.preprocessor = FloatToUint8Range()
        self.postprocessor = Uint8ToFloatRange()

    def forward(self, input):
        input = self.preprocessor(input)
        output = self.decoder(self.encoder(input))
        return self.postprocessor(output)


transformer = Transformer().to(device)
print(transformer)
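As mentioned in the note above the class definition, the same architecture can also be obtained directly from pystiche's demo module instead of being defined by hand:

# Shortcut, per the note above: load the same architecture from the demo module.
transformer = demo.transformer().to(device)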

Perceptual loss

Although model optimization is a different paradigm, the perceptual_loss is the same as for image optimization.

Note

In some implementations, such as the PyTorch example and [JAL2016], one can observe that the gram_matrix(), used as the style representation, is not only normalized by the height and width of the feature map, but also by the number of channels. If used together with a mse_loss(), the normalization is performed twice. While this is unintended, it affects the training. In order to keep the other hyperparameters on par with the PyTorch example, we also adopt this change here.
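Roughly speaking, if \(G\) denotes the default Gram representation and \(C\) the number of channels of the encoded feature map, the GramOperator defined below uses

\[G' = \frac{G}{C},\]

and the mse_loss() with its default mean reduction averages over the \(C \times C\) entries of \(G'\) again, which is why the channel count effectively enters the normalization twice.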

multi_layer_encoder = enc.vgg16_multi_layer_encoder()

content_layer = "relu2_2"
content_encoder = multi_layer_encoder.extract_encoder(content_layer)
content_weight = 1e5
content_loss = loss.FeatureReconstructionLoss(
    content_encoder, score_weight=content_weight
)


class GramOperator(loss.GramLoss):
    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
        repr = super().enc_to_repr(enc)
        num_channels = repr.size()[1]
        return repr / num_channels


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder,
    style_layers,
    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=style_weight,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
perceptual_loss = perceptual_loss.to(device)
print(perceptual_loss)

Training

As a first step, we load the style image that will be used to train the transformer.

images = demo.images()
size = 500

style_image = images["paint"].read(size=size, device=device)
show_image(style_image)

The transformer is trained similarly to other models in PyTorch: in every optimization step a batch of content images is drawn from a dataset and serves both as input for the transformer and as content_image for the perceptual_loss. While the style_image only has to be set once, the content_image has to be reset in every iteration step.

While this can be done with a boilerplate optimization loop, pystiche provides multi_epoch_model_optimization(), which handles all of the above for you.
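For illustration, such a manual loop could look roughly like the sketch below; image_loader, optimizer, and num_epochs are hypothetical placeholders here, and everything shown is handled by multi_epoch_model_optimization().

# Rough sketch of the manual training loop that
# optim.multi_epoch_model_optimization() replaces; image_loader, optimizer,
# and num_epochs are placeholders assumed to be set up elsewhere.
transformer.train()
perceptual_loss.set_style_image(style_image)

for epoch in range(num_epochs):
    for content_image in image_loader:
        content_image = content_image.to(device)
        # the content target changes with every batch ...
        perceptual_loss.set_content_image(content_image)

        optimizer.zero_grad()
        output_image = transformer(content_image)
        # ... while the style target set above stays fixed
        score = perceptual_loss(output_image)
        score.backward()
        optimizer.step()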

Note

If the perceptual_loss is a PerceptualLoss, as is the case here, the update of the content_image is performed automatically. If that is not the case or you need more complex update behavior, you need to specify a criterion_update_fn.
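For such cases, a minimal sketch of a criterion_update_fn, assuming the callback receives the current batch of content images together with the criterion (check the documentation of multi_epoch_model_optimization() for the exact signature):

# Hedged sketch: manually wiring each batch of content images into the criterion.
# The callback signature (input_image, criterion) is an assumption here.
def criterion_update_fn(input_image, criterion):
    criterion.set_content_image(input_image)
# passed as optim.multi_epoch_model_optimization(..., criterion_update_fn=criterion_update_fn)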

Note

If you do not specify an optimizer, the default_model_optimizer(), i.e. Adam, is used.

def train(*, transformer, root, batch_size, epochs, image_size):
    if root is None:
        raise RuntimeError("You forgot to define a root image directory.")

    transform = nn.Sequential(
        transforms.Resize(image_size), transforms.CenterCrop(image_size),
    )
    dataset = ImageFolderDataset(root, transform=transform)
    image_loader = DataLoader(dataset, batch_size=batch_size)

    perceptual_loss.set_style_image(style_image)

    return optim.multi_epoch_model_optimization(
        image_loader, transformer.train(), perceptual_loss, epochs=epochs,
    )

Depending on the dataset and your setup, the training can take a couple of hours. To avoid this, we provide transformer weights that were trained with the scheme above.

def download():
    # Unfortunately, torch.hub.load_state_dict_from_url has no option to disable
    # printing the downloading process. Since this would clutter the output, we
    # suppress it completely.
    @contextlib.contextmanager
    def suppress_output():
        with open(os.devnull, "w") as devnull:
            with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
                devnull
            ):
                yield

    with suppress_output():
        return hub.load_state_dict_from_url(
            "https://download.pystiche.org/models/example_transformer.pth"
        )

Note

The weights of the provided transformer were trained with the 2014 training images of the COCO dataset. The training was performed for epochs=2 and batch_size=4. Each image was center-cropped to 256 x 256 pixels.

Note

If you want to perform the training yourself, set root to the location of a folder with images.

root = None
checkpoint = "example_transformer.pth"

if root is None:
    state_dict = torch.load(checkpoint) if path.exists(checkpoint) else download()
    transformer.load_state_dict(state_dict)
else:
    transformer = train(
        transformer=transformer, root=root, batch_size=4, epochs=2, image_size=256,
    )
    state_dict = OrderedDict(
        [
            (name, parameter.detach().cpu())
            for name, parameter in transformer.state_dict().items()
        ]
    )
    torch.save(state_dict, checkpoint)

Neural Style Transfer

In order to perform the NST, we load an image we want to stylize.

input_image = images["bird1"].read(size=size, device=device)
show_image(input_image)

After the transformer is trained, we can perform an NST with a single forward pass. To do this, the transformer is simply called with the input_image.

transformer.eval()

start = time.time()

with torch.no_grad():
    output_image = transformer(input_image)

stop = time.time()

show_image(output_image, title="Output image")

Compared to NST via image optimization, the stylization is performed multiple orders of magnitude faster. Given capable hardware, NST via model optimization enables real-time stylization, for example of a video feed.

print(f"The stylization took {(stop - start) * 1e3:.0f} milliseconds.")
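A rough sketch of such a real-time use case, assuming a hypothetical frames iterable that yields image tensors of shape 1 x 3 x H x W in the \([0, 1]\) range:

# Hypothetical sketch: stylizing a stream of frames with the trained transformer.
transformer.eval()
with torch.no_grad():
    for frame in frames:  # `frames` is a placeholder for your video source
        stylized = transformer(frame.to(device))
        # ... display `stylized` or write it to an output video here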
