Model optimization
This example showcases how an NST based on model optimization can be performed in pystiche. It closely follows the official PyTorch example, which in turn is based on [JAL16].
We start this example by importing everything we need and setting the device we will be working on.
import contextlib
import os
import time
from collections import OrderedDict
from os import path

import torch
from torch import hub, nn
from torch.nn.functional import interpolate

import pystiche
from pystiche import demo, enc, loss, optim
from pystiche.image import show_image
from pystiche.misc import get_device

print(f"I'm working with pystiche=={pystiche.__version__}")

device = get_device()
print(f"I'm working with {device}")
Out:
I'm working with pystiche==1.0.1
I'm working with cuda
Transformer
In contrast to image optimization, for model optimization we need to define a transformer that, after it is trained, performs the stylization. In general, different architectures are possible ([JAL16, ULVL16]). For this example we use an encoder-decoder architecture.
Before we define the transformer, we create some helper modules to reduce the clutter.
In the decoder we need to upsample the image. While it is possible to achieve this with a ConvTranspose2d, it was found that traditional upsampling followed by a standard convolution produces fewer artifacts. Thus, we create a module that wraps interpolate().
class Interpolate(nn.Module):
    def __init__(self, scale_factor=1.0, mode="nearest"):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, input):
        return interpolate(input, scale_factor=self.scale_factor, mode=self.mode)

    def extra_repr(self):
        extras = []
        if self.scale_factor:
            extras.append(f"scale_factor={self.scale_factor}")
        if self.mode != "nearest":
            extras.append(f"mode={self.mode}")
        return ", ".join(extras)
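As a quick sanity check, we can instantiate the module and verify that it doubles the spatial dimensions of an input tensor. The check itself is not part of the original example; the input is a hypothetical random tensor:

upsample = Interpolate(scale_factor=2.0)
print(upsample)
# Interpolate(scale_factor=2.0)
print(upsample(torch.rand(1, 3, 16, 16)).size())
# torch.Size([1, 3, 32, 32])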
For the transformer architecture we will be using, we need to define a convolution module with some additional capabilities. In particular, it needs to be able to

- optionally upsample the input,
- pad the input in order for the convolution to be size-preserving,
- optionally normalize the output, and
- optionally pass the output through an activation function.
Note

Instead of BatchNorm2d we use InstanceNorm2d to normalize the output since it gives better results for NST [UVL16].
class Conv(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        upsample=False,
        norm=True,
        activation=True,
    ):
        super().__init__()
        self.upsample = Interpolate(scale_factor=stride) if upsample else None
        self.pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=1 if upsample else stride
        )
        self.norm = nn.InstanceNorm2d(out_channels, affine=True) if norm else None
        self.activation = nn.ReLU() if activation else None

    def forward(self, input):
        if self.upsample:
            input = self.upsample(input)

        output = self.conv(self.pad(input))

        if self.norm:
            output = self.norm(output)
        if self.activation:
            output = self.activation(output)

        return output
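The reflection padding of kernel_size // 2 makes the convolution itself size-preserving for odd kernel sizes, so the spatial size is only changed by stride or upsample. A shape check with a hypothetical input tensor (not part of the original example) illustrates the three cases:

input = torch.rand(1, 3, 32, 32)
# stride=1: the spatial size is preserved
print(Conv(3, 8, kernel_size=9)(input).size())
# torch.Size([1, 8, 32, 32])
# stride=2: the spatial size is halved
print(Conv(3, 8, kernel_size=3, stride=2)(input).size())
# torch.Size([1, 8, 16, 16])
# upsample=True: the spatial size is doubled before the convolution
print(Conv(3, 8, kernel_size=3, stride=2, upsample=True)(input).size())
# torch.Size([1, 8, 64, 64])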
It is common practice to append a few residual blocks to the encoder after the initial convolutions to enable it to learn more descriptive features.
class Residual(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = Conv(channels, channels, kernel_size=3)
        self.conv2 = Conv(channels, channels, kernel_size=3, activation=False)

    def forward(self, input):
        output = self.conv2(self.conv1(input))
        return output + input
It can be useful for the training to transform the input into another value range, for example from [0, 1] to [0, 255].
class FloatToUint8Range(nn.Module):
    def forward(self, input):
        return input * 255.0


class Uint8ToFloatRange(nn.Module):
    def forward(self, input):
        return input / 255.0
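The two modules are exact inverses of each other, so wrapping the transformer in them leaves the value range of the final output unchanged. A quick check with a hypothetical tensor (not part of the original example):

image = torch.rand(1, 3, 4, 4)
round_trip = Uint8ToFloatRange()(FloatToUint8Range()(image))
print(torch.allclose(round_trip, image))
# True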
Finally, we can put all pieces together.

Note

You can access this transformer through pystiche.demo.transformer().
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            Conv(3, 32, kernel_size=9),
            Conv(32, 64, kernel_size=3, stride=2),
            Conv(64, 128, kernel_size=3, stride=2),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
            Residual(128),
        )
        self.decoder = nn.Sequential(
            Conv(128, 64, kernel_size=3, stride=2, upsample=True),
            Conv(64, 32, kernel_size=3, stride=2, upsample=True),
            Conv(32, 3, kernel_size=9, norm=False, activation=False),
        )

        self.preprocessor = FloatToUint8Range()
        self.postprocessor = Uint8ToFloatRange()

    def forward(self, input):
        input = self.preprocessor(input)
        output = self.decoder(self.encoder(input))
        return self.postprocessor(output)


transformer = Transformer().to(device)
print(transformer)
Out:
Transformer(
  (encoder): Sequential(
    (0): Conv(
      (pad): ReflectionPad2d((4, 4, 4, 4))
      (conv): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
      (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (1): Conv(
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (2): Conv(
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
      (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (3): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (4): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (5): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (6): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
    (7): Residual(
      (conv1): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (activation): ReLU()
      )
      (conv2): Conv(
        (pad): ReflectionPad2d((1, 1, 1, 1))
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      )
    )
  )
  (decoder): Sequential(
    (0): Conv(
      (upsample): Interpolate(scale_factor=2)
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
      (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (1): Conv(
      (upsample): Interpolate(scale_factor=2)
      (pad): ReflectionPad2d((1, 1, 1, 1))
      (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
      (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (activation): ReLU()
    )
    (2): Conv(
      (pad): ReflectionPad2d((4, 4, 4, 4))
      (conv): Conv2d(32, 3, kernel_size=(9, 9), stride=(1, 1))
    )
  )
  (preprocessor): FloatToUint8Range()
  (postprocessor): Uint8ToFloatRange()
)
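Since the two stride-2 convolutions in the encoder are mirrored by the two upsampling convolutions in the decoder, the transformer preserves the size of its input, at least for spatial dimensions divisible by four. A check with a hypothetical random input (not part of the original example):

with torch.no_grad():
    print(transformer(torch.rand(1, 3, 256, 256, device=device)).size())
# torch.Size([1, 3, 256, 256])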
Perceptual loss
Although model optimization is a different paradigm, the perceptual_loss is the same as for image optimization.
Note

In some implementations, such as the PyTorch example and [JAL16], one can observe that the gram_matrix(), used as style representation, is not only normalized by the height and width of the feature map, but also by the number of channels. If used together with a mse_loss(), the normalization is performed twice. While this is unintended, it affects the training. In order to keep the other hyperparameters on par with the PyTorch example, we also adopt this change here.
multi_layer_encoder = enc.vgg16_multi_layer_encoder()

content_layer = "relu2_2"
content_encoder = multi_layer_encoder.extract_encoder(content_layer)
content_weight = 1e5
content_loss = loss.FeatureReconstructionLoss(
    content_encoder, score_weight=content_weight
)


class GramOperator(loss.GramLoss):
    def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
        repr = super().enc_to_repr(enc)
        num_channels = repr.size()[1]
        return repr / num_channels


style_layers = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
style_weight = 1e10
style_loss = loss.MultiLayerEncodingLoss(
    multi_layer_encoder,
    style_layers,
    lambda encoder, layer_weight: GramOperator(encoder, score_weight=layer_weight),
    layer_weights="sum",
    score_weight=style_weight,
)

perceptual_loss = loss.PerceptualLoss(content_loss, style_loss)
perceptual_loss = perceptual_loss.to(device)
print(perceptual_loss)
Out:
PerceptualLoss(
  (content_loss): FeatureReconstructionLoss(
    score_weight=100000
    (encoder): VGGMultiLayerEncoder(layer=relu2_2, arch=vgg16, framework=torch)
  )
  (style_loss): MultiLayerEncodingLoss(
    encoder=VGGMultiLayerEncoder(arch=vgg16, framework=torch), score_weight=1e+10
    (relu1_2): GramOperator()
    (relu2_2): GramOperator()
    (relu3_3): GramOperator()
    (relu4_3): GramOperator()
  )
)
Training
As a first step, we load the style image that will be used to train the transformer.
images = demo.images()
size = 500

style_image = images["paint"].read(size=size, device=device)
show_image(style_image)
The training of the transformer is performed similarly to that of other models in PyTorch. In every optimization step a batch of content images is drawn from a dataset, which serves as input for the transformer as well as content_image for the perceptual_loss. While the style_image only has to be set once, the content_image has to be reset in every iteration step.

While this can be done with a boilerplate optimization loop, pystiche provides multi_epoch_model_optimization() that handles the above for you.
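To make explicit what is handled for you, such a boilerplate loop could look roughly like the following sketch. It is only an illustration under the assumptions that the image_loader yields batches of plain image tensors and that a default Adam optimizer is used; it does not reproduce the internals of multi_epoch_model_optimization() exactly:

def boilerplate_model_optimization(transformer, image_loader, perceptual_loss, epochs=2):
    optimizer = torch.optim.Adam(transformer.parameters())
    for _ in range(epochs):
        for content_image in image_loader:
            content_image = content_image.to(device)
            # the batch serves as input for the transformer
            # as well as content target for the perceptual loss
            perceptual_loss.set_content_image(content_image)

            optimizer.zero_grad()
            output_image = transformer(content_image)
            # aggregate the per-layer losses into a single scalar score
            score = perceptual_loss(output_image).total()
            score.backward()
            optimizer.step()
    return transformer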
Note

If the perceptual_loss is a PerceptualLoss, as is the case here, the update of the content_image is performed automatically. If that is not the case or you need more complex update behavior, you need to specify a criterion_update_fn.
Note

If you do not specify an optimizer, the default_model_optimizer(), i.e. Adam, is used.
def train(
    transformer, dataset, batch_size=4, epochs=2,
):
    if dataset is None:
        raise RuntimeError(
            "You forgot to define a dataset. For example, "
            "you can use any image dataset from torchvision.datasets."
        )

    from torch.utils.data import DataLoader

    image_loader = DataLoader(dataset, batch_size=batch_size)

    perceptual_loss.set_style_image(style_image)

    return optim.multi_epoch_model_optimization(
        image_loader, transformer.train(), perceptual_loss, epochs=epochs,
    )
Depending on the dataset and your setup, the training can take a couple of hours. To avoid this, we provide transformer weights that were trained with the scheme above.
Note

If you want to perform the training yourself, set use_pretrained_transformer=False. If you do, you also need to replace dataset = None below with the dataset you want to train on.
Note

The weights of the provided transformer were trained with the 2014 training images of the COCO dataset. The training was performed for epochs=2 and batch_size=4. Each image was center-cropped to 256 x 256 pixels.
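For reference, a minimal dataset along these lines could be sketched as follows. The FlatImageDataset name, the directory layout, and the use of torchvision transforms are illustrative assumptions, not part of the original example:

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class FlatImageDataset(Dataset):
    # Hypothetical helper: loads all images from a flat directory and
    # returns only the transformed image tensor, matching what the
    # training loop expects.
    def __init__(self, root, size=256):
        self.paths = sorted(
            path.join(root, file)
            for file in os.listdir(root)
            if file.lower().endswith((".jpg", ".jpeg", ".png"))
        )
        self.transform = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return self.transform(Image.open(self.paths[idx]).convert("RGB"))


# dataset = FlatImageDataset("./train2014")  # hypothetical path to the COCO images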
use_pretrained_transformer = True
checkpoint = "example_transformer.pth"

if use_pretrained_transformer:
    if path.exists(checkpoint):
        state_dict = torch.load(checkpoint)
    else:
        # Unfortunately, torch.hub.load_state_dict_from_url has no option to disable
        # printing the downloading process. Since this would clutter the output, we
        # suppress it completely.
        @contextlib.contextmanager
        def suppress_output():
            with open(os.devnull, "w") as devnull:
                with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(
                    devnull
                ):
                    yield

        url = "https://download.pystiche.org/models/example_transformer.pth"

        with suppress_output():
            state_dict = hub.load_state_dict_from_url(url)

    transformer.load_state_dict(state_dict)
else:
    dataset = None
    transformer = train(transformer, dataset)

    state_dict = OrderedDict(
        [
            (name, parameter.detach().cpu())
            for name, parameter in transformer.state_dict().items()
        ]
    )
    torch.save(state_dict, checkpoint)
Neural Style Transfer
In order to perform the NST, we load an image we want to stylize.
input_image = images["bird1"].read(size=size, device=device)
show_image(input_image)
Now that the transformer is trained, we can perform an NST with a single forward pass. To do this, the transformer is simply called with the input_image.
transformer.eval()

start = time.time()

with torch.no_grad():
    output_image = transformer(input_image)

stop = time.time()

show_image(output_image, title="Output image")
Compared to NST via image optimization, the stylization is performed multiple orders of magnitude faster. Given capable hardware, NST via model optimization enables real-time stylization, for example of a video feed.
print(f"The stylization took {(stop - start) * 1e3:.0f} milliseconds.")
Out:
The stylization took 3 milliseconds.
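To give a rough impression of what this could look like, the following sketch stylizes a stream of frames one by one. The random tensors merely stand in for frames grabbed from an actual video source and are not part of the original example:

def stylize_stream(transformer, frames):
    # stylize an iterable of image tensors one by one
    transformer.eval()
    with torch.no_grad():
        for frame in frames:
            yield transformer(frame)


# hypothetical stand-in for a video feed: 10 random frames
frames = (torch.rand(1, 3, 256, 256, device=device) for _ in range(10))
for stylized_frame in stylize_stream(transformer, frames):
    pass  # display or encode the stylized frame here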