Model optimization

This example showcases how an NST based on model optimization can be performed in pystiche. It closely follows the official PyTorch example, which in turn is based on [JAL16].

We start this example by importing everything we need and setting the device we will be working on.

import contextlib
import os
import time
from collections import OrderedDict
from os import path

import torch
from torch import hub, nn
from torch.nn.functional import interpolate

import pystiche
from pystiche import demo, enc, loss, optim
from pystiche.image import show_image
from pystiche.misc import get_device

print(f"I'm working with pystiche=={pystiche.__version__}")

device = get_device()
print(f"I'm working with {device}")

Out:

I'm working with pystiche==1.1.0.dev44+gd9e3fd8
I'm working with cuda

Transformer

In contrast to image optimization, for model optimization we need to define a transformer that, after it is trained, performs the stylization. In general, different architectures are possible ([JAL16, ULVL16]). For this example, we use an encoder-decoder architecture.

Before we define the transformer, we create some helper modules to reduce the clutter.

In the decoder we need to upsample the image. While it is possible to achieve this with a ConvTranspose2d, it was found that traditional upsampling followed by a standard convolution produces fewer artifacts. Thus, we create a module that wraps torch.nn.functional.interpolate().

class Interpolate(nn.Module):
    def __init__(self, scale_factor=1.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)

    def extra_repr(self):
        extras = []
        if self.scale_factor:
            extras.append(f"scale_factor={self.scale_factor}")
        if self.mode != "nearest":
            extras.append(f"mode={self.mode}")
        return ", ".join(extras)

For the transformer architecture we will be using, we need to define a convolution module with some additional capabilities. In particular, it needs to be able to

  • optionally upsample the input,

  • pad the input in order for the convolution to be size-preserving,

  • optionally normalize the output, and

  • optionally pass the output through an activation function.

Note

Instead of BatchNorm2d we use InstanceNorm2d to normalize the output since it gives better results for NST [UVL16].

class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        upsample=False,
        norm=True,
        activation=True,
    ):
        super().__init__()
        self.upsample = Interpolate(scale_factor=stride) if upsample else None
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
        )
        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
        self.activation = nn.ReLU() if activation else None

    def forward(self, input):
        if self.upsample:
            input = self.upsample(input)

        output = self.conv(self.pad(input))

        if self.norm:
            output = self.norm(output)
        if self.activation:
            output = self.activation(output)

        return output

It is common practice to append a few residual blocks to the encoder after the initial convolutions to enable it to learn more descriptive features.

class Residual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv(channels, channels, kernel_size=3)
        self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)

    def forward(self, input):
        output = self.conv2(self.conv1(input))
        return output + input

It can be useful for the training to transform the input into another value range, for example from \([0, 1]\) to \([0, 255]\).

class FloatToUint8Range(nn.Module):
    def forward(self, input):
        return input * 255.0


class Uint8ToFloatRange(nn.Module):
    def forward(self, input):
        return input / 255.0

Finally, we can put all pieces together.

Note

You can access this transformer through pystiche.demo.transformer().
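In other words, instead of copying the class below, you could obtain the architecture via the demo module. The following is only a sketch of that accessor; it assumes nothing beyond the note above, i.e. that pystiche.demo.transformer() returns an instance of this transformer.

# Sketch: obtain the transformer from the demo module instead of defining it
# manually (see the note above).
from pystiche import demo

transformer = demo.transformer().to(device)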

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            Conv(3, 32, kernel_size=9),
            Conv(32, 64, kernel_size=3, stride=2),
            Conv(64, 128, kernel_size=3, stride=2),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
        )
        self.decoder = nn.Sequential(
            Conv(128, 64, kernel_size=3, stride=2, upsample=True),
            Conv(64, 32, kernel_size=3, stride=2, upsample=True),
            Conv(32, 3, kernel_size=9, norm=False, activation=False),
        )

        self.preprocessor = FloatToUint8Range()
        self.postprocessor = Uint8ToFloatRange()

    def forward(self, input):
        input = self.preprocessor(input)
        output = self.decoder(self.encoder(input))
        return self.postprocessor(output)


transformer = Transformer().to(device)
print(transformer)

Out:

Transformer(
  (encoder): Sequential(
    (0): Conv(
      (pad): ReflectionPad2d((4, 4, 4, 4))
      (conv): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
      (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (1): Conv(
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (2): Conv(
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (3): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (4): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (5): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (6): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (7): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
  )
  (decoder): Sequential(
    (0): Conv(
      (upsample): Interpolate(scale_factor=2)
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
      (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (1): Conv(
      (upsample): Interpolate(scale_factor=2)
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
      (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (2): Conv(
      (pad): ReflectionPad2d((4, 4, 4, 4))
      (conv): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1))
    )
  )
  (preprocessor): FloatToUint8Range()
  (postprocessor): Uint8ToFloatRange()
)

Perceptual loss

Although model optimization is a different paradigm, the perceptual_loss is the same as for image optimization.

Note

In some implementations, such as the PyTorch example and [JAL16], one can observe that the gram_matrix(), used as the style representation, is not only normalized by the height and width of the feature map, but also by the number of channels. If used together with a mse_loss(), the normalization is performed twice. While this is unintended, it affects the training. In order to keep the other hyperparameters on par with the PyTorch example, we also adopt this change here.
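To make the note concrete: if \(F\) denotes the \(C \times HW\) matrix of flattened feature maps of a layer, the style representation used below is \(\frac{1}{C} \cdot \frac{F F^{\top}}{HW} = \frac{F F^{\top}}{C \cdot HW}\). This is only a sketch of the bookkeeping; it assumes that the base GramLoss already normalizes the Gram matrix by the spatial size \(HW\), so that the extra division by the number of channels \(C\) in the GramOperator below yields the normalization described above.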

multi_layer_encoder = enc.vgg16_multi_layer_encoder()

content_layer = "relu2_2"
content_encoder = multi_layer_encoder.extract_encoder(content_layer)
content_weight = 1e5
content_loss = loss.FeatureReconstructionLoss(
    content_encoder, score_weight=content_weight
)


class GramOperator(loss.GramLoss):
    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
        repr = super().enc_to_repr(enc)
        num_channels = repr.size()[1]
        return repr / num_channels


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder,
    style_layers,
    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=style_weight,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
perceptual_loss = perceptual_loss.to(device)
print(perceptual_loss)

Out:

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/runnerx/.cache/torch/hub/checkpoints/vgg16-397923af.pth

100%|##########| 528M/528M [00:03<00:00, 164MB/s]
PerceptualLoss(
  (content_loss): FeatureReconstructionLoss(
    score_weight=100000
    (encoder): VGGMultiLayerEncoder(layer=relu2_2, arch=vgg16, framework=torch)
  )
  (style_loss): MultiLayerEncodingLoss(
    encoder=VGGMultiLayerEncoder(arch=vgg16, framework=torch), score_weight=1e+10
    (relu1_2): GramOperator()
    (relu2_2): GramOperator()
    (relu3_3): GramOperator()
    (relu4_3): GramOperator()
  )
)

Training

In a first step we load the style image that will be used to train the transformer.

images = demo.images()
size = 500

style_image = images["paint"].read(size=size, device=device)
show_image(style_image)

The training of the transformer is performed similarly to that of other models in PyTorch. In every optimization step, a batch of content images is drawn from a dataset; it serves as input for the transformer as well as the content_image for the perceptual_loss. While the style_image only has to be set once, the content_image has to be reset in every iteration.

While this can be done with a boilerplate optimization loop, pystiche provides multi_epoch_model_optimization(), which handles the above for you.
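For reference, such a boilerplate loop might look roughly like the following sketch. It is not what multi_epoch_model_optimization() executes internally, but it illustrates the steps that are handled for you; the choice of torch.optim.Adam and the set_content_image() call are assumptions on our part.

# Rough sketch of a manual training loop. multi_epoch_model_optimization()
# handles these steps for you; the optimizer choice and the set_content_image()
# call are assumptions, not pystiche internals.
def manual_train(transformer, image_loader, epochs=2):
    optimizer = torch.optim.Adam(transformer.parameters())
    perceptual_loss.set_style_image(style_image)

    transformer.train()
    for _ in range(epochs):
        for content_image in image_loader:  # loader is assumed to yield image batches
            content_image = content_image.to(device)

            # the content target has to be reset in every iteration
            perceptual_loss.set_content_image(content_image)

            optimizer.zero_grad()
            output_image = transformer(content_image)
            score = perceptual_loss(output_image)
            score.backward()  # assumes the returned loss object supports backward()
            optimizer.step()

    return transformer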

Note

If the perceptual_loss is a PerceptualLoss, as is the case here, the update of the content_image is performed automatically. If that is not the case or you need more complex update behavior, you need to specify a criterion_update_fn.
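If you do need to specify it, such an update function could look like the sketch below; the exact signature expected by multi_epoch_model_optimization() is an assumption here.

# Sketch of a possible criterion_update_fn: it receives the current input batch
# and the criterion and resets the content target (assumed signature).
def criterion_update_fn(input_image, criterion):
    criterion.set_content_image(input_image)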

Note

If you do not specify an optimizer, the default_model_optimizer(), i.e. Adam, is used.

def train(
    transformer, dataset, batch_size=4, epochs=2,
):
    if dataset is None:
        raise RuntimeError(
            "You forgot to define a dataset. For example, "
            "you can use any image dataset from torchvision.datasets."
        )

    from torch.utils.data import DataLoader

    image_loader = DataLoader(dataset, batch_size=batch_size)

    perceptual_loss.set_style_image(style_image)

    return optim.multi_epoch_model_optimization(
        image_loader, transformer.train(), perceptual_loss, epochs=epochs,
    )

Depending on the dataset and your setup, the training can take a couple of hours. To avoid this, we provide transformer weights that were trained with the scheme above.

Note

If you want to perform the training yourself, set use_pretrained_transformer=False. If you do, you also need to replace dataset = None below with the dataset you want to train on.

Note

The weights of the provided transformer were trained with the 2014 training images of the COCO dataset. The training was performed for num_epochs=2 and batch_size=4. Each image was center-cropped to 256 x 256 pixels.
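If you train the transformer yourself, the dataset could, for example, be built with torchvision. The sketch below mirrors the 256 x 256 center-crop mentioned above; the path is a placeholder, and it is an assumption that your images are stored in a layout that ImageFolder can read.

# Sketch: a torchvision dataset with the 256 x 256 center-crop described above.
# The path is a placeholder; ImageFolder expects images in subdirectories and
# returns (image, label) tuples, so the labels may need to be dropped depending
# on the optimization loop.
from torchvision import datasets, transforms

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]
)
dataset = datasets.ImageFolder("path/to/coco/train2014", transform=transform)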

use_pretrained_transformer = True
checkpoint = "example_transformer.pth"

if use_pretrained_transformer:
    if path.exists(checkpoint):
        state_dict = torch.load(checkpoint)
    else:
        # Unfortunately, torch.hub.load_state_dict_from_url has no option to disable
        # printing the downloading process. Since this would clutter the output, we
        # suppress it completely.
        @contextlib.contextmanager
        def suppress_output():
            with open(os.devnull, "w") as devnull:
                with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
                    devnull
                ):
                    yield

        url = "https://download.pystiche.org/models/example_transformer.pth"

        with suppress_output():
            state_dict = hub.load_state_dict_from_url(url)

    transformer.load_state_dict(state_dict)
else:
    dataset = None
    transformer = train(transformer, dataset)

    state_dict = OrderedDict(
        [
            (name, parameter.detach().cpu())
            for name, parameter in transformer.state_dict().items()
        ]
    )
    torch.save(state_dict, checkpoint)

Neural Style Transfer

In order to perform the NST, we load an image we want to stylize.

input_image = images["bird1"].read(size=size, device=device)
show_image(input_image)

After the transformer is trained, we can perform an NST with a single forward pass. To do this, the transformer is simply called with the input_image.

transformer.eval()

start = time.time()

with torch.no_grad():
    output_image = transformer(input_image)

stop = time.time()

show_image(output_image, title="Output image")

Compared to NST via image optimization, the stylization is performed multiple orders of magnitude faster. Given capable hardware, NST via model optimization enables real-time stylization of, for example, a video feed.

389 print(f"The stylization took {(stop - start) * 1e3:.0f} milliseconds.")

Out:

The stylization took 6 milliseconds.
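As a rough illustration of the real-time claim above, stylizing a video feed boils down to repeating the forward pass for every frame. The frames iterable below is a placeholder for whatever source, e.g. a video decoder, you use.

# Sketch: frame-by-frame stylization. `frames` is a placeholder for an iterable
# of image tensors with shape 1 x 3 x H x W in the value range [0, 1].
transformer.eval()
with torch.no_grad():
    for frame in frames:
        stylized_frame = transformer(frame.to(device))
        # display or encode stylized_frame here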

Total running time of the script: ( 0 minutes 6.538 seconds)

Estimated memory usage: 340 MB
