Model optimization
This example showcases how an NST based on model optimization can be performed in pystiche. It closely follows the official PyTorch example, which in turn is based on [JAL16].
We start this example by importing everything we need and setting the device we will be working on.
import contextlib
import os
import time
from collections import OrderedDict
from os import path

import torch
from torch import hub, nn
from torch.nn.functional import interpolate

import pystiche
from pystiche import demo, enc, loss, optim
from pystiche.image import show_image
from pystiche.misc import get_device

print(f"I'm working with pystiche=={pystiche.__version__}")

device = get_device()
print(f"I'm working with {device}")
Out:
I'm working with pystiche==1.1.0.dev44+gd9e3fd8
I'm working with cuda
Transformer
In contrast to image optimization, model optimization requires a transformer that, after it is trained, performs the stylization. In general, different architectures are possible ([JAL16, ULVL16]). For this example we use an encoder-decoder architecture.
Before we define the transformer, we create some helper modules to reduce the clutter.
In the decoder we need to upsample the image. While it is possible to achieve this with a ConvTranspose2d, it was found that traditional upsampling followed by a standard convolution produces fewer artifacts. Thus, we create a module that wraps torch.nn.functional.interpolate().
class Interpolate(nn.Module):
    def __init__(self, scale_factor=1.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)

    def extra_repr(self):
        extras = []
        if self.scale_factor:
            extras.append(f"scale_factor={self.scale_factor}")
        if self.mode != "nearest":
            extras.append(f"mode={self.mode}")
        return ", ".join(extras)
For the transformer architecture we will be using, we need to define a convolution module with some additional capabilities. In particular, it needs to be able to

- optionally upsample the input,
- pad the input in order for the convolution to be size-preserving,
- optionally normalize the output, and
- optionally pass the output through an activation function.
Note

Instead of BatchNorm2d we use InstanceNorm2d to normalize the output since it gives better results for NST [UVL16].
class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        upsample=False,
        norm=True,
        activation=True,
    ):
        super().__init__()
        self.upsample = Interpolate(scale_factor=stride) if upsample else None
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
        )
        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
        self.activation = nn.ReLU() if activation else None

    def forward(self, input):
        if self.upsample:
            input = self.upsample(input)

        output = self.conv(self.pad(input))

        if self.norm:
            output = self.norm(output)
        if self.activation:
            output = self.activation(output)

        return output
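To illustrate the size handling (a minimal sketch, not part of the original example): with the default stride=1 the reflection padding makes the convolution size-preserving, while stride=2 halves the spatial dimensions:

conv = Conv(3, 8, kernel_size=9)
input = torch.rand(1, 3, 32, 32)
print(conv(input).shape)  # torch.Size([1, 8, 32, 32]), size-preserving
down = Conv(8, 16, kernel_size=3, stride=2)
print(down(conv(input)).shape)  # torch.Size([1, 16, 16, 16]), spatially halved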
It is common practice to append a few residual blocks to the encoder after the initial convolutions to enable it to learn more descriptive features.
class Residual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv(channels, channels, kernel_size=3)
        self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)

    def forward(self, input):
        output = self.conv2(self.conv1(input))
        return output + input
It can be useful for the training to transform the input into another value range, for example from \([0, 1]\) to \([0, 255]\).
class FloatToUint8Range(nn.Module):
    def forward(self, input):
        return input * 255.0


class Uint8ToFloatRange(nn.Module):
    def forward(self, input):
        return input / 255.0
Finally, we can put all pieces together.
Note

You can access this transformer through pystiche.demo.transformer().
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            Conv(3, 32, kernel_size=9),
            Conv(32, 64, kernel_size=3, stride=2),
            Conv(64, 128, kernel_size=3, stride=2),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
        )
        self.decoder = nn.Sequential(
            Conv(128, 64, kernel_size=3, stride=2, upsample=True),
            Conv(64, 32, kernel_size=3, stride=2, upsample=True),
            Conv(32, 3, kernel_size=9, norm=False, activation=False),
        )

        self.preprocessor = FloatToUint8Range()
        self.postprocessor = Uint8ToFloatRange()

    def forward(self, input):
        input = self.preprocessor(input)
        output = self.decoder(self.encoder(input))
        return self.postprocessor(output)


transformer = Transformer().to(device)
print(transformer)
Out:
Transformer(
(encoder): Sequential(
(0): Conv(
(pad): ReflectionPad2d((4, 4, 4, 4))
(conv): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(1): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
(norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(2): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(3): Residual(
(conv1): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(conv2): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(4): Residual(
(conv1): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(conv2): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(5): Residual(
(conv1): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(conv2): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(6): Residual(
(conv1): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(conv2): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(7): Residual(
(conv1): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(conv2): Conv(
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
)
(decoder): Sequential(
(0): Conv(
(upsample): Interpolate(scale_factor=2)
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(1): Conv(
(upsample): Interpolate(scale_factor=2)
(pad): ReflectionPad2d((1, 1, 1, 1))
(conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(activation): ReLU()
)
(2): Conv(
(pad): ReflectionPad2d((4, 4, 4, 4))
(conv): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1))
)
)
(preprocessor): FloatToUint8Range()
(postprocessor): Uint8ToFloatRange()
)
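Since the encoder halves the spatial size twice and the decoder upsamples by the same factor twice, the transformer preserves the resolution of its input. As a quick sanity check (a minimal sketch, not part of the original example):

with torch.no_grad():
    output = transformer(torch.rand(1, 3, 256, 256, device=device))
print(output.shape)  # torch.Size([1, 3, 256, 256])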
Perceptual loss
Although model optimization is a different paradigm, the perceptual_loss is the same as for image optimization.
Note

In some implementations, such as the PyTorch example and [JAL16], one can observe that the gram_matrix(), used as style representation, is not only normalized by the height and width of the feature map, but also by the number of channels. If used together with a mse_loss(), the normalization is performed twice. While this is unintended, it affects the training. In order to keep the other hyperparameters on par with the PyTorch example, we also adopt this change here.
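Concretely, if the encoding of a layer is reshaped to a matrix \(F\) with \(C\) rows (one per channel) and \(H \cdot W\) columns, the normalized style representation is \(G = \frac{F F^{\top}}{H \cdot W}\). The GramOperator defined below additionally divides \(G\) by \(C\), reproducing the double normalization.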
multi_layer_encoder = enc.vgg16_multi_layer_encoder()

content_layer = "relu2_2"
content_encoder = multi_layer_encoder.extract_encoder(content_layer)
content_weight = 1e5
content_loss = loss.FeatureReconstructionLoss(
    content_encoder, score_weight=content_weight
)


class GramOperator(loss.GramLoss):
    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
        repr = super().enc_to_repr(enc)
        num_channels = repr.size()[1]
        return repr / num_channels


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder,
    style_layers,
    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=style_weight,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
perceptual_loss = perceptual_loss.to(device)
print(perceptual_loss)
Out:
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/runner/.cache/torch/hub/checkpoints/vgg16-397923af.pth
PerceptualLoss(
(content_loss): FeatureReconstructionLoss(
score_weight=100000
(encoder): VGGMultiLayerEncoder(layer=relu2_2, arch=vgg16, framework=torch)
)
(style_loss): MultiLayerEncodingLoss(
encoder=VGGMultiLayerEncoder(arch=vgg16, framework=torch), score_weight=1e+10
(relu1_2): GramOperator()
(relu2_2): GramOperator()
(relu3_3): GramOperator()
(relu4_3): GramOperator()
)
)
Training
As a first step we load the style image that will be used to train the transformer.
images = demo.images()
size = 500

style_image = images["paint"].read(size=size, device=device)
show_image(style_image)
The training of the transformer is performed similarly to other models in PyTorch. In every optimization step a batch of content images is drawn from a dataset, which serves as input for the transformer as well as content_image for the perceptual_loss. While the style_image only has to be set once, the content_image has to be reset in every iteration step.

While this can be done with a boilerplate optimization loop, pystiche provides multi_epoch_model_optimization() that handles the above for you.
Note

If the perceptual_loss is a PerceptualLoss, as is the case here, the update of the content_image is performed automatically. If that is not the case or you need more complex update behavior, you need to specify a criterion_update_fn.
Note

If you do not specify an optimizer, the default_model_optimizer(), i.e. Adam, is used.
def train(
    transformer, dataset, batch_size=4, epochs=2,
):
    if dataset is None:
        raise RuntimeError(
            "You forgot to define a dataset. For example, "
            "you can use any image dataset from torchvision.datasets."
        )

    from torch.utils.data import DataLoader

    image_loader = DataLoader(dataset, batch_size=batch_size)

    perceptual_loss.set_style_image(style_image)

    return optim.multi_epoch_model_optimization(
        image_loader, transformer.train(), perceptual_loss, epochs=epochs,
    )
Depending on the dataset and your setup the training can take a couple of hours. To avoid this, we provide transformer weights that were trained with the scheme above.
Note

If you want to perform the training yourself, set use_pretrained_transformer=False. If you do, you also need to replace dataset = None below with the dataset you want to train on.
Note

The weights of the provided transformer were trained with the 2014 training images of the COCO dataset. The training was performed for epochs=2 and batch_size=4. Each image was center-cropped to 256 x 256 pixels.
use_pretrained_transformer = True
checkpoint = "example_transformer.pth"

if use_pretrained_transformer:
    if path.exists(checkpoint):
        state_dict = torch.load(checkpoint)
    else:
        # Unfortunately, torch.hub.load_state_dict_from_url has no option to disable
        # printing the downloading process. Since this would clutter the output, we
        # suppress it completely.
        @contextlib.contextmanager
        def suppress_output():
            with open(os.devnull, "w") as devnull:
                with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
                    devnull
                ):
                    yield

        url = "https://download.pystiche.org/models/example_transformer.pth"

        with suppress_output():
            state_dict = hub.load_state_dict_from_url(url)

    transformer.load_state_dict(state_dict)
else:
    dataset = None
    transformer = train(transformer, dataset)

    state_dict = OrderedDict(
        [
            (name, parameter.detach().cpu())
            for name, parameter in transformer.state_dict().items()
        ]
    )
    torch.save(state_dict, checkpoint)
Neural Style Transfer
In order to perform the NST, we load an image we want to stylize.
input_image = images["bird1"].read(size=size, device=device)
show_image(input_image)
Now that the transformer is trained, we can perform an NST with a single forward pass. To do this, the transformer is simply called with the input_image.
transformer.eval()

start = time.time()

with torch.no_grad():
    output_image = transformer(input_image)

stop = time.time()

show_image(output_image, title="Output image")
Compared to NST via image optimization, the stylization is performed multiple orders of magnitude faster. Given capable hardware, NST via model optimization enables real-time stylization, for example of a video feed.
print(f"The stylization took {(stop - start) * 1e3:.0f} milliseconds.")
Out:
The stylization took 6 milliseconds.
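The same forward pass can be applied frame by frame, which is what enables stylizing a video feed. A rough sketch (stylize_frames and the frames iterable are hypothetical, not part of the original example):

def stylize_frames(transformer, frames):
    # frames is assumed to be an iterable of float image tensors in the [0, 1] range
    transformer.eval()
    with torch.no_grad():
        for frame in frames:
            yield transformer(frame.to(device))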
Total running time of the script: (0 minutes 6.538 seconds)
Estimated memory usage: 340 MB