Model optimization

This example showcases how an NST based on model optimization can be performed in pystiche. It closely follows the official PyTorch example, which in turn is based on [JAL16].

We start this example by importing everything we need and setting the device we will be working on.

import contextlib
import os
import time
from collections import OrderedDict
from os import path

import torch
from torch import hub, nn
from torch.nn.functional import interpolate

import pystiche
from pystiche import demo, enc, loss, optim
from pystiche.image import show_image
from pystiche.misc import get_device

print(f"I'm working with pystiche=={pystiche.__version__}")

device = get_device()
print(f"I'm working with {device}")

Out:

I'm working with pystiche==1.0.1
I'm working with cuda

Transformer

In contrast to image optimization, model optimization requires a transformer that, once trained, performs the stylization. In general, different architectures are possible ([JAL16, ULVL16]). For this example we use an encoder-decoder architecture.

Before we define the transformer, we create some helper modules to reduce the clutter.

In the decoder we need to upsample the image. While it is possible to achieve this with a ConvTranspose2d, it was found that traditional upsampling followed by a standard convolution produces fewer artifacts. Thus, we create a module that wraps interpolate().

class Interpolate(nn.Module):
    def __init__(self, scale_factor=1.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)

    def extra_repr(self):
        extras = []
        if self.scale_factor:
            extras.append(f"scale_factor={self.scale_factor}")
        if self.mode != "nearest":
            extras.append(f"mode={self.mode}")
        return ", ".join(extras)
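With mode="nearest" and an integer scale_factor, upsampling simply copies every pixel into a block of neighbors. The NumPy sketch below reproduces that behavior for illustration only; nearest_upsample() is a hypothetical helper, not what interpolate() runs internally.

```python
import numpy as np


def nearest_upsample(image: np.ndarray, factor: int) -> np.ndarray:
    """Repeat every pixel `factor` times along the last two (spatial) axes."""
    return np.repeat(np.repeat(image, factor, axis=-2), factor, axis=-1)


channel = np.array([[1, 2],
                    [3, 4]])
print(nearest_upsample(channel, 2))
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
```

Since every output pixel is a plain copy, there is no uneven kernel overlap of the kind commonly blamed for the checkerboard patterns of ConvTranspose2d; the convolution that follows is left to smooth the blocky result.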

For the transformer architecture we will be using, we need to define a convolution module with some additional capabilities. In particular, it needs to be able to

- optionally upsample the input,
- pad the input in order for the convolution to be size-preserving,
- optionally normalize the output, and
- optionally pass the output through an activation function.

Note

Instead of BatchNorm2d, we use InstanceNorm2d to normalize the output, since it gives better results for NST [UVL16].
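The difference between the two lies in which axes the statistics are computed over. The NumPy sketch below (hypothetical helpers, without the learnable per-channel scale and shift that affine=True adds) normalizes per image and per channel for instance normalization, but per channel across the whole batch for batch normalization, so only the former makes each image's result independent of the other images in the batch.

```python
import numpy as np


def instance_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Statistics per sample AND per channel: reduce over height and width only.
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def batch_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Statistics per channel, shared across the batch: reduce over N, H and W.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
# Two (channels, height, width) images with very different contrast.
x = np.stack([rng.normal(0.0, 1.0, (3, 8, 8)), rng.normal(5.0, 10.0, (3, 8, 8))])

print(np.round(instance_norm(x).std(axis=(2, 3)), 3))  # close to 1 for every image
print(np.round(batch_norm(x).std(axis=(2, 3)), 3))     # depends on the other image
```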

class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        upsample=False,
        norm=True,
        activation=True,
    ):
        super().__init__()
        self.upsample = Interpolate(scale_factor=stride) if upsample else None
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
        )
        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
        self.activation = nn.ReLU() if activation else None

    def forward(self, input):
        if self.upsample:
            input = self.upsample(input)

        output = self.conv(self.pad(input))

        if self.norm:
            output = self.norm(output)
        if self.activation:
            output = self.activation(output)

        return output
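Conv pads by kernel_size // 2 on every side with nn.ReflectionPad2d and then applies an unpadded convolution, which keeps the spatial size for odd kernel sizes and stride 1. The single-channel NumPy sketch below illustrates both steps; reflect_pad() and valid_conv2d() are hypothetical helpers, and the claim that NumPy's "reflect" mode mirrors the border the same way as ReflectionPad2d (without repeating the edge pixel) is the assumption it relies on.

```python
import numpy as np


def reflect_pad(channel: np.ndarray, pad: int) -> np.ndarray:
    # "reflect" mirrors the content without repeating the border pixel,
    # so no artificial zero frame is added around the image.
    return np.pad(channel, pad, mode="reflect")


def valid_conv2d(channel: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Unpadded, stride-1 cross-correlation (nn.Conv2d with padding=0)."""
    kh, kw = kernel.shape
    out_h = channel.shape[0] - kh + 1
    out_w = channel.shape[1] - kw + 1
    out = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(channel[i:i + kh, j:j + kw] * kernel)
    return out


channel = np.arange(20.0).reshape(4, 5)
kernel = np.ones((3, 3)) / 9.0  # 3x3 box filter
padded = reflect_pad(channel, kernel.shape[0] // 2)
print(padded.shape)                        # (6, 7)
print(valid_conv2d(padded, kernel).shape)  # (4, 5), the same as the input
```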

It is common practice to append a few residual blocks to the encoder after the initial convolutions, which enables it to learn more descriptive features.

class Residual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv(channels, channels, kernel_size=3)
        self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)

    def forward(self, input):
        output = self.conv2(self.conv1(input))
        return output + input
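Each Residual block adds the output of its convolutions back onto its input, so the branch only has to learn a correction; if the branch outputs zero, the block is the identity. The addition also requires the branch to preserve the shape, which is why both convolutions keep the channel count and use size-preserving 3x3 kernels with stride 1. The residual() helper below is a hypothetical, framework-agnostic NumPy sketch with an explicit shape check added for illustration.

```python
import numpy as np


def residual(x: np.ndarray, branch) -> np.ndarray:
    """Generic residual connection: add the branch output to its unchanged input."""
    out = branch(x)
    # The sum is only defined if the branch keeps the shape of its input.
    if out.shape != x.shape:
        raise ValueError(f"branch changed the shape from {x.shape} to {out.shape}")
    return out + x


x = np.arange(6.0).reshape(1, 1, 2, 3)
print(residual(x, lambda t: np.zeros_like(t)))  # a zero branch leaves x unchanged
```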

It can be useful for the training to transform the input into another value range, for example from \([0, 1]\) to \([0, 255]\).

class FloatToUint8Range(nn.Module):
    def forward(self, input):
        return input * 255.0


class Uint8ToFloatRange(nn.Module):
    def forward(self, input):
        return input / 255.0
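These modules only rescale; they do not clamp. Since the last decoder Conv below uses norm=False and activation=False, nothing bounds the transformer output to \([0, 1]\), so it can be worth clamping before converting to 8-bit values. The to_uint8_image() helper is a hypothetical NumPy sketch of that conversion; whether a particular display or saving routine already clamps for you is not covered here.

```python
import numpy as np


def to_uint8_image(output: np.ndarray) -> np.ndarray:
    """Convert a float image, nominally in [0, 1], to 8-bit values in [0, 255]."""
    # Clamp first so out-of-range values do not produce invalid uint8 casts.
    clamped = np.clip(output, 0.0, 1.0)
    return np.round(clamped * 255.0).astype(np.uint8)


output = np.array([-0.1, 0.0, 0.5, 1.0, 1.2])
print(to_uint8_image(output))  # [  0   0 128 255 255]
```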

Finally, we can put all pieces together.

Note

You can access this transformer through pystiche.demo.transformer().

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            Conv(3, 32, kernel_size=9),
            Conv(32, 64, kernel_size=3, stride=2),
            Conv(64, 128, kernel_size=3, stride=2),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
        )
        self.decoder = nn.Sequential(
            Conv(128, 64, kernel_size=3, stride=2, upsample=True),
            Conv(64, 32, kernel_size=3, stride=2, upsample=True),
            Conv(32, 3, kernel_size=9, norm=False, activation=False),
        )

        self.preprocessor = FloatToUint8Range()
        self.postprocessor = Uint8ToFloatRange()

    def forward(self, input):
        input = self.preprocessor(input)
        output = self.decoder(self.encoder(input))
        return self.postprocessor(output)


transformer = Transformer().to(device)
print(transformer)

Out:

Transformer(
  (encoder): Sequential(
    (0): Conv(
      (pad): ReflectionPad2d((4, 4, 4, 4))
      (conv): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
      (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (1): Conv(
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (2): Conv(
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (3): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (4): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (5): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (6): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (7): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
  )
  (decoder): Sequential(
    (0): Conv(
      (upsample): Interpolate(scale_factor=2)
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
      (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (1): Conv(
      (upsample): Interpolate(scale_factor=2)
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
      (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (2): Conv(
      (pad): ReflectionPad2d((4, 4, 4, 4))
      (conv): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1))
    )
  )
  (preprocessor): FloatToUint8Range()
  (postprocessor): Uint8ToFloatRange()
)
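The two strided encoder convolutions each halve the spatial size, rounding up, while the decoder doubles it exactly twice, so the output only matches the input when the side length is a multiple of 4; otherwise it is rounded up to the next multiple of 4. The pure-Python trace below assumes the standard nn.Conv2d size formula with dilation 1; conv_out() and transformer_output_size() are hypothetical helpers.

```python
def conv_out(size: int, kernel_size: int, stride: int = 1) -> int:
    # Reflection padding of kernel_size // 2 on each side, then an unpadded conv:
    # floor((size + 2 * pad - kernel_size) / stride) + 1
    pad = kernel_size // 2
    return (size + 2 * pad - kernel_size) // stride + 1


def transformer_output_size(size: int) -> int:
    """Trace one spatial dimension through the Transformer defined above."""
    size = conv_out(size, 9)            # Conv(3, 32, kernel_size=9)
    size = conv_out(size, 3, stride=2)  # Conv(32, 64, kernel_size=3, stride=2)
    size = conv_out(size, 3, stride=2)  # Conv(64, 128, kernel_size=3, stride=2)
    # The five Residual blocks are size-preserving (kernel 3, stride 1).
    size = size * 2                     # upsample by 2 + Conv(128, 64, ...)
    size = size * 2                     # upsample by 2 + Conv(64, 32, ...)
    return conv_out(size, 9)            # Conv(32, 3, kernel_size=9)


for size in (256, 500, 375):
    print(size, "->", transformer_output_size(size))
# 256 -> 256
# 500 -> 500
# 375 -> 376
```

If the exact input size matters, you may want to crop the output back to the input size.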

Perceptual loss

Although model optimization is a different paradigm, the perceptual_loss is the same as for image optimization.

Note

In some implementations, such as the PyTorch example and [JAL16], one can observe that the gram_matrix(), used as style representation, is normalized not only by the height and width of the feature map, but also by the number of channels. If used together with a mse_loss(), the normalization is performed twice. Although this is unintended, it affects the training. In order to keep the other hyperparameters on par with the PyTorch example, we adopt this change here as well.
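Dividing a C by C Gram matrix by C scales every entry by 1/C, so every squared difference, and therefore a mean-reduced MSE, shrinks by 1/C². The NumPy sketch below checks this numerically; gram_matrix() and mse() are hypothetical helpers that normalize by height times width as described in the note, not pystiche's functions.

```python
import numpy as np


def gram_matrix(features: np.ndarray) -> np.ndarray:
    """Gram matrix of a (channels, height, width) feature map, normalized by H * W."""
    c, h, w = features.shape
    flat = features.reshape(c, h * w)
    return flat @ flat.T / (h * w)


def mse(a: np.ndarray, b: np.ndarray) -> float:
    # Mean over all C * C entries, like a mean-reduced MSE loss.
    return float(np.mean((a - b) ** 2))


rng = np.random.default_rng(0)
c = 64
input_repr = gram_matrix(rng.random((c, 16, 16)))
target_repr = gram_matrix(rng.random((c, 16, 16)))

plain = mse(input_repr, target_repr)
# Extra division by the channel count, as done in GramOperator.enc_to_repr.
extra = mse(input_repr / c, target_repr / c)
print(extra / plain)  # equals 1 / c**2
```

Presumably this is why the style_weight below is so large: it is tuned around the extra 1/C² factor per layer.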

multi_layer_encoder = enc.vgg16_multi_layer_encoder()

content_layer = "relu2_2"
content_encoder = multi_layer_encoder.extract_encoder(content_layer)
content_weight = 1e5
content_loss = loss.FeatureReconstructionLoss(
    content_encoder, score_weight=content_weight
)


class GramOperator(loss.GramLoss):
    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
        repr = super().enc_to_repr(enc)
        num_channels = repr.size()[1]
        return repr / num_channels


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder,
    style_layers,
    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=style_weight,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
perceptual_loss = perceptual_loss.to(device)
print(perceptual_loss)

Out:

PerceptualLoss(
  (content_loss): FeatureReconstructionLoss(
    score_weight=100000
    (encoder): VGGMultiLayerEncoder(layer=relu2_2, arch=vgg16, framework=torch)
  )
  (style_loss): MultiLayerEncodingLoss(
    encoder=VGGMultiLayerEncoder(arch=vgg16, framework=torch), score_weight=1e+10
    (relu1_2): GramOperator()
    (relu2_2): GramOperator()
    (relu3_3): GramOperator()
    (relu4_3): GramOperator()
  )
)

Training

First, we load the style image that will be used to train the transformer.

images = demo.images()
size = 500

style_image = images["paint"].read(size=size, device=device)
show_image(style_image)

Training the transformer works similarly to training other models in PyTorch. In every optimization step, a batch of content images is drawn from a dataset; it serves as the input for the transformer as well as the content_image for the perceptual_loss. While the style_image only has to be set once, the content_image has to be reset in every iteration step.

While this can be done with a boilerplate optimization loop, pystiche provides multi_epoch_model_optimization(), which handles the above for you.
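For comparison, the following is a minimal sketch of such a boilerplate loop in plain PyTorch; it is not pystiche's implementation. To keep it self-contained, the hypothetical ContentMSELoss stands in for perceptual_loss (plain MSE to the content image, no style term) and only mirrors the idea of a set_content_image() update, while a 1x1 convolution stands in for the transformer.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ContentMSELoss(nn.Module):
    """Hypothetical stand-in for perceptual_loss: MSE to the current content image."""

    def __init__(self):
        super().__init__()
        self.content_image = None

    def set_content_image(self, image):
        self.content_image = image

    def forward(self, output):
        return nn.functional.mse_loss(output, self.content_image)


def train_epochs(image_loader, model, criterion, epochs=2, lr=1e-2):
    """Hypothetical multi-epoch loop with Adam as the optimizer."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for (batch,) in image_loader:
            # The content target changes with every batch, so reset it each step.
            criterion.set_content_image(batch)
            optimizer.zero_grad()
            batch_loss = criterion(model(batch))
            batch_loss.backward()
            optimizer.step()
    return model


def evaluate(model, criterion, images):
    criterion.set_content_image(images)
    with torch.no_grad():
        return criterion(model(images)).item()


torch.manual_seed(0)
toy_images = torch.rand(16, 3, 8, 8)    # toy "dataset" of random images
loader = DataLoader(TensorDataset(toy_images), batch_size=4)
model = nn.Conv2d(3, 3, kernel_size=1)  # toy stand-in for the transformer
criterion = ContentMSELoss()

initial_loss = evaluate(model, criterion, toy_images)
train_epochs(loader, model, criterion, epochs=25)
final_loss = evaluate(model, criterion, toy_images)
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```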

Note

If the perceptual_loss is a PerceptualLoss, as is the case here, the update of the content_image is performed automatically. If that is not the case or you need more complex update behavior, you need to specify a criterion_update_fn.

Note

If you do not specify an optimizer, default_model_optimizer(), i.e. Adam, is used.

def train(transformer, dataset, batch_size=4, epochs=2):
    if dataset is None:
        raise RuntimeError(
            "You forgot to define a dataset. For example, "
            "you can use any image dataset from torchvision.datasets."
        )

    from torch.utils.data import DataLoader

    image_loader = DataLoader(dataset, batch_size=batch_size)

    perceptual_loss.set_style_image(style_image)

    return optim.multi_epoch_model_optimization(
        image_loader, transformer.train(), perceptual_loss, epochs=epochs
    )

Depending on the dataset and your setup, the training can take a couple of hours. To avoid this, we provide transformer weights that were trained with the scheme above.

Note

If you want to perform the training yourself, set use_pretrained_transformer=False. If you do, you also need to replace dataset = None below with the dataset you want to train on.

Note

The weights of the provided transformer were trained on the 2014 training images of the COCO dataset. The training was performed for epochs=2 with batch_size=4. Each image was center-cropped to 256 x 256 pixels.
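The note does not say whether the images were resized before cropping, so the sketch below covers only the crop itself. In a torchvision pipeline, torchvision.transforms.CenterCrop(256) is the usual tool; the center_crop() helper here is a hypothetical NumPy version for a (height, width, channels) array. It raises an error for images smaller than the crop (torchvision pads instead), and its floor rounding of the offset may differ by a pixel from other implementations.

```python
import numpy as np


def center_crop(image: np.ndarray, size: int) -> np.ndarray:
    """Crop the central size x size region of a (height, width, channels) array."""
    height, width = image.shape[:2]
    if height < size or width < size:
        raise ValueError(f"image of size {height}x{width} is smaller than {size}x{size}")
    top = (height - size) // 2
    left = (width - size) // 2
    return image[top:top + size, left:left + size]


image = np.zeros((480, 640, 3), dtype=np.uint8)  # e.g. a 480 x 640 RGB image
print(center_crop(image, 256).shape)  # (256, 256, 3)
```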

use_pretrained_transformer = True
checkpoint = "example_transformer.pth"

if use_pretrained_transformer:
    if path.exists(checkpoint):
        state_dict = torch.load(checkpoint)
    else:
        # Unfortunately, torch.hub.load_state_dict_from_url has no option to disable
        # printing the downloading process. Since this would clutter the output, we
        # suppress it completely.
        @contextlib.contextmanager
        def suppress_output():
            with open(os.devnull, "w") as devnull:
                with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
                    devnull
                ):
                    yield

        url = "https://download.pystiche.org/models/example_transformer.pth"

        with suppress_output():
            state_dict = hub.load_state_dict_from_url(url)

    transformer.load_state_dict(state_dict)
else:
    dataset = None
    transformer = train(transformer, dataset)

    state_dict = OrderedDict(
        [
            (name, parameter.detach().cpu())
            for name, parameter in transformer.state_dict().items()
        ]
    )
    torch.save(state_dict, checkpoint)

Neural Style Transfer

In order to perform the NST, we load an image we want to stylize.

input_image = images["bird1"].read(size=size, device=device)
show_image(input_image)

Once the transformer is trained, we can perform an NST with a single forward pass. To do this, we simply call the transformer with the input_image.

transformer.eval()

start = time.time()

with torch.no_grad():
    output_image = transformer(input_image)

stop = time.time()

show_image(output_image, title="Output image")
Output image

Compared to NST via image optimization, the stylization is performed several orders of magnitude faster. Given capable hardware, NST via model optimization enables real-time stylization, for example of a video feed.

print(f"The stylization took {(stop - start) * 1e3:.0f} milliseconds.")

Out:

The stylization took 3 milliseconds.
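On a GPU, PyTorch launches CUDA kernels asynchronously, so wrapping a single call in time.time() may mostly measure the launch and can understate the actual runtime unless the device is synchronized, for example with torch.cuda.synchronize(). The timed() helper below is a hypothetical pure-Python sketch that adds a warm-up run, averages several repeats, and accepts an optional synchronization callable. For a sense of scale, a 30 frames-per-second video feed leaves roughly 1000 / 30 ≈ 33 ms per frame.

```python
import time
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def timed(
    fn: Callable[[], T], sync: Optional[Callable[[], None]] = None, repeats: int = 10
) -> Tuple[T, float]:
    """Run fn `repeats` times and return its last result and the mean time in ms."""
    # Warm-up run so one-time costs (e.g. lazy initialization) are not measured.
    result = fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(repeats):
        result = fn()
    # Asynchronous backends may return before the work is done; wait for it.
    if sync is not None:
        sync()
    elapsed_ms = (time.perf_counter() - start) * 1e3 / repeats
    return result, elapsed_ms


# Usage with the objects from above (not executed here):
#     sync = torch.cuda.synchronize if device.type == "cuda" else None
#     with torch.no_grad():
#         output_image, ms = timed(lambda: transformer(input_image), sync=sync)
#     print(f"{ms:.1f} ms per image")
```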

Total running time of the script: ( 0 minutes 1.600 seconds)

Estimated memory usage: 340 MB
