In my previous article, I covered the ProgressiveGAN, or ProGAN. This article explores an extension of the ProGAN known as the StyleGAN. The StyleGAN is an interesting departure from some of the traditional techniques that modern generative AI algorithms use, and its promising results warrant a look at how it works.
The StyleGAN came roughly 5 years after the original GAN and builds heavily on the GAN’s fundamental backbone. If you’d like a review of the concepts behind the GAN, you can check out my earlier article here; in this article I’ll only provide a quick review before going through the theory. In those 5 years, a lot of research had been done on GANs, but the inner workings of the generator models were still largely unknown to researchers.
The StyleGAN is a unique attempt by researchers to introduce some structure to the workings of the generator model. It draws from traditional style transfer literature and deals with a learned concept of “style”, which is the focus of the operations within the generator model. Let’s see how this works in practice.
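As in the traditional GAN architecture, the StyleGAN consists of a generator model $G$ and a discriminator model $D$. The objective of the GAN is to find a satisfactory equilibrium for the joint value function

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$

In my original GAN article, we discussed how this can be simplified into computations using binary cross entropy.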
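As a quick refresher, here is a minimal NumPy sketch of that simplification. It assumes the discriminator outputs probabilities in $(0, 1)$. The function names are hypothetical, and the generator loss uses the common non-saturating form (maximizing $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$). That choice is an assumption about how the earlier article set things up.

```python
import numpy as np

def bce(pred, target, eps=1e-12):
    # Mean binary cross entropy between predicted probabilities and 0/1 targets
    pred = np.clip(pred, eps, 1 - eps)
    return float(-np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred)))

def discriminator_loss(d_real, d_fake):
    # D wants real samples labeled 1 and generated samples labeled 0
    return bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake))

def generator_loss(d_fake):
    # Non-saturating generator loss: G wants its samples labeled 1
    return bce(d_fake, np.ones_like(d_fake))

# A discriminator that outputs 0.5 everywhere cannot tell real from fake
d_real = np.full(4, 0.5)
d_fake = np.full(4, 0.5)
print(discriminator_loss(d_real, d_fake))  # 2 * ln 2, about 1.386
print(generator_loss(d_fake))              # ln 2, about 0.693
```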
For the StyleGAN, the discriminator and loss function remain unchanged; only the generator model will require changes. The original StyleGAN uses the discriminator architecture from the ProGAN. I’ll reuse that implementation from my last article.
The first major addition of the StyleGAN is a learned constant input, which replaces the latent as the starting point of the generator. One aspect of a GAN that is often overlooked is the effect of the input on the rest of the network. In the traditional setup, the input is only used once: at the very beginning of the network, in a feedforward layer.
In the StyleGAN, we’re going to include information from the latent at multiple stages of the network. In essence, we’ll try to learn an intermediate representation of “style” that will be used throughout the progressive growing stages of our network. To that end, our first addition is a small fully connected network that maps the input latent $\mathbf{z} \in \mathcal{Z}$ to an intermediate latent $\mathbf{w} \in \mathcal{W}$.
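We’ll implement the idea of equalized learning rates from the original ProGAN paper. As a review, this involves scaling the features of a layer at runtime by the per-layer constant from He initialization, $c = \sqrt{2 / (k^2 \cdot C_{\text{in}})}$, where $k$ is the kernel size and $C_{\text{in}}$ is the number of input channels. To create a fully connected layer based on this idea, we’ll simply scale the features by $c = \sqrt{2 / n_{\text{in}}}$, where $n_{\text{in}}$ is the number of input features. Since a linear layer doesn’t have kernels, we don’t need the extra $k^2$ factor.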
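To make the constant concrete, here is a small NumPy check. It is a sketch, not part of the model. It assumes zero-mean, unit-variance inputs and weights stored as $\mathcal{N}(0, 1)$, as in the `WSLinear` layer below. Under those assumptions, scaling the input by $\sqrt{2 / n_{\text{in}}}$ gives pre-activations with variance close to 2, the same as He initialization, regardless of the layer's width. The helper names are hypothetical.

```python
import numpy as np

def linear_scale(in_features):
    # Equalized learning-rate constant for a fully connected layer
    return (2 / in_features) ** 0.5

def conv_scale(in_channels, kernel_size=3):
    # Same constant for a convolution: fan-in includes the kernel area
    return (2 / (kernel_size ** 2 * in_channels)) ** 0.5

print(linear_scale(512))  # 0.0625
print(conv_scale(512))    # 1/48, about 0.0208

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 512))  # unit-variance inputs
W = rng.standard_normal((256, 512))   # weights stored as N(0, 1)
y = (x * linear_scale(512)) @ W.T     # runtime scaling of the input
print(y.var())                        # close to 2
```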
The next component of the StyleGAN is Adaptive Instance Normalization, or AdaIN. AdaIN is an idea developed separately from the StyleGAN for the purposes of style transfer in images. To understand AdaIN, it’ll be helpful to review what Instance Normalization itself is. Instance Normalization (IN) is based on Batch Normalization (BN). However, IN does not compute statistics over a whole batch of inputs. Instead, it normalizes each feature channel of an individual sample $x$ over that channel's spatial dimensions only. As a result, IN behaves the same way at training and inference time, whereas BN has to fall back on population statistics estimated from the training data at inference. Furthermore, the original IN paper showed that IN is better suited than BN for learning features associated with the style of images.
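The difference comes down to which axes the statistics are computed over. In this NumPy sketch, the helper names are hypothetical and the learnable affine parameters are left out. BN pools over the batch and spatial axes of an `(N, C, H, W)` tensor, while IN pools over the spatial axes only. That means one sample's output never depends on the rest of the batch.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    # Statistics per channel, shared across the whole batch
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def instance_norm(x, eps=1e-5):
    # Statistics per channel and per sample
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 8, 8)) * 5 + 2

# IN of the first sample is identical whether or not the other samples are present
print(np.allclose(instance_norm(x)[0], instance_norm(x[:1])[0]))  # True
# BN of the first sample changes with the batch it is in
print(np.allclose(batch_norm(x)[0], batch_norm(x[:1])[0]))        # False
```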
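In IN, we normalize each feature channel of an individual sample $x$ as follows:

$$\mathrm{IN}(x) = \gamma \left( \frac{x - \mu(x)}{\sigma(x)} \right) + \beta$$

Here $\mu(x)$ and $\sigma(x)$ are the mean and standard deviation computed over each feature channel of the sample $x$ and applied independently, and $\gamma$ and $\beta$ are learned affine parameters. Adaptive IN extends this idea by including an additional input $y$, meant to represent the style to be adapted to. Instead of learning the affine parameters $\gamma$ and $\beta$, we simply match the channel-wise statistics of $x$ to those of the style input $y$ as follows:

$$\mathrm{AdaIN}(x, y) = \sigma(y) \left( \frac{x - \mu(x)}{\sigma(x)} \right) + \mu(y)$$

Written this way, AdaIN requires no learnable affine parameters. We simply require a suitable style input $y$. The key idea behind the StyleGAN is that we can learn this input $y$ using the intermediate latent space $\mathcal{W}$ from earlier.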
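Here is a minimal NumPy sketch of the AdaIN equation. It assumes `x` and a style feature map `y` are both `(N, C, H, W)` arrays. The function name and the small `eps` for numerical stability are my own additions. After the transfer, each channel of the output carries the mean and standard deviation of the matching channel of `y`, while its spatial structure still comes from `x`.

```python
import numpy as np

def adain(x, y, eps=1e-8):
    # Channel-wise statistics over the spatial axes of each sample
    mu_x = x.mean(axis=(2, 3), keepdims=True)
    sigma_x = np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)
    mu_y = y.mean(axis=(2, 3), keepdims=True)
    sigma_y = np.sqrt(y.var(axis=(2, 3), keepdims=True) + eps)
    # Normalize x, then re-scale and re-shift it with the style statistics
    return sigma_y * (x - mu_x) / sigma_x + mu_y

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 16, 16))          # content features
y = rng.standard_normal((2, 3, 16, 16)) * 4 - 1  # style features
out = adain(x, y)

print(np.allclose(out.mean(axis=(2, 3)), y.mean(axis=(2, 3))))  # True
print(np.allclose(out.std(axis=(2, 3)), y.std(axis=(2, 3))))    # True
```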
Finally, the authors give the generator a more direct way to add fine details to its images by periodically injecting uncorrelated Gaussian noise into the network. This noise is scaled by a learned weight. It separates the network's handling of stochastic details, such as the exact placement of fine textures, from the features governed by the style inputs.
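In the paper, the noise is a single-channel image with one independent Gaussian sample per pixel. It is broadcast to every feature map and scaled by a learned per-channel factor. Below is a NumPy sketch of just that broadcasting step. The `inject_noise` helper is hypothetical, and the hand-picked weights stand in for learned ones.

```python
import numpy as np

def inject_noise(x, weight, rng):
    # x: (N, C, H, W) features; weight: (C,) per-channel noise strengths
    n, _, h, w = x.shape
    noise = rng.standard_normal((n, 1, h, w))  # one noise image per sample
    return x + weight.reshape(1, -1, 1, 1) * noise, noise

rng = np.random.default_rng(0)
x = np.zeros((2, 3, 4, 4))
weight = np.array([0.0, 0.5, 2.0])
out, noise = inject_noise(x, weight, rng)

print(np.allclose(out[:, 0], 0.0))              # zero weight: channel untouched
print(np.allclose(out[:, 2], 4.0 * out[:, 1]))  # same noise pattern, different scale
```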
First, we’ll implement the equalized learning rate we discussed earlier for the linear layer. This will look very similar to our implementation of the `WSConv2d` layer from the last article.
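```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WSLinear(nn.Module):
    """Linear layer with an equalized learning rate."""

    def __init__(self, in_features, out_features):
        super(WSLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        # He constant, applied at runtime instead of at initialization
        self.scale = (2 / in_features) ** 0.5
        # Move the bias out of nn.Linear so it is not affected by the scaling
        self.bias = self.linear.bias
        self.linear.bias = None
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        return self.linear(x * self.scale) + self.bias
```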
Our network to map $\mathbf{z}$ to $\mathbf{w}$ will consist of eight such layers, as per the paper. The full `NoiseMappingNetwork` is then pretty simple:
```python
class NoiseMappingNetwork(nn.Module):
    """8-layer MLP mapping the input latent z to the intermediate latent w."""

    def __init__(self, args):
        super(NoiseMappingNetwork, self).__init__()
        self.args = args
        self.model = nn.Sequential(
            PixelNorm(),  # reused from the ProGAN article
            WSLinear(args.latent, args.w_latent),
            nn.ReLU(),
            WSLinear(args.w_latent, args.w_latent),
            nn.ReLU(),
            WSLinear(args.w_latent, args.w_latent),
            nn.ReLU(),
            WSLinear(args.w_latent, args.w_latent),
            nn.ReLU(),
            WSLinear(args.w_latent, args.w_latent),
            nn.ReLU(),
            WSLinear(args.w_latent, args.w_latent),
            nn.ReLU(),
            WSLinear(args.w_latent, args.w_latent),
            nn.ReLU(),
            WSLinear(args.w_latent, args.w_latent),
        )

    def forward(self, x):
        return self.model(x)
```
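`PixelNorm` comes from the ProGAN article and isn't redefined here. To show what the mapping network computes end to end, here is a NumPy sketch of the same forward pass. It makes three assumptions:

- Pixel norm uses the usual ProGAN definition, $x / \sqrt{\operatorname{mean}(x^2) + 10^{-8}}$.
- Random $\mathcal{N}(0, 1)$ weights stand in for trained ones.
- Biases are omitted, since `WSLinear` initializes them to zero.

The helper names are hypothetical.

```python
import numpy as np

def pixel_norm(x, eps=1e-8):
    # Normalize each latent vector to unit average squared magnitude
    return x / np.sqrt(np.mean(x ** 2, axis=1, keepdims=True) + eps)

def mapping_forward(z, weights):
    # PixelNorm, then equalized-LR linear layers with ReLU between them
    h = pixel_norm(z)
    for i, W in enumerate(weights):
        h = (h * (2 / W.shape[1]) ** 0.5) @ W.T  # W has shape (out, in)
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)  # no activation after the last layer
    return h

rng = np.random.default_rng(0)
latent = w_latent = 512
weights = [rng.standard_normal((w_latent, latent))]
weights += [rng.standard_normal((w_latent, w_latent)) for _ in range(7)]

z = rng.standard_normal((4, latent)) * 10
w = mapping_forward(z, weights)
print(w.shape)  # (4, 512)
```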
Second, we have to implement the Adaptive Instance Normalization layer. You may recall that AdaIN requires a style input whose statistics we match our actual input to. Instead of generating a style input $y$ and then computing $\mu(y)$ and $\sigma(y)$ over each of its feature maps, we’ll borrow a trick from the VAE. We simply add layers that map the intermediate latent $\mathbf{w}$ to $\sigma(y)$ and $\mu(y)$ directly. This is similar to how a VAE’s encoder directly predicts the mean and standard deviation of a normal distribution, which the reparameterization trick then uses to sample from it.
PyTorch already has a very convenient `nn.InstanceNorm2d` layer, so we’ll simply build on top of it for the `AdaptiveInstanceNormalization` layer. Our `AdaptiveInstanceNormalization` layer won't take $y$ directly. Instead, it will take $\mathbf{w}$ as input and compute the required statistics for AdaIN. `nn.InstanceNorm2d` also lets us disable IN’s learned affine parameters (they are off by default). That means we can simply scale the IN output by the style standard deviation we generate and then add the generated style mean.
```python
class AdaptiveInstanceNormalization(nn.Module):
    def __init__(self, in_channels, w_dim):
        super().__init__()
        # affine=False by default, so IN has no learned scale/shift of its own
        self.instance_norm = nn.InstanceNorm2d(in_channels)
        # Learned affine maps from w to per-channel style statistics
        self.style_sigma = WSLinear(w_dim, in_channels)
        self.style_mu = WSLinear(w_dim, in_channels)

    def forward(self, x, w):
        x = self.instance_norm(x)
        # (N, C) -> (N, C, 1, 1) so the styles broadcast over H and W
        style_sigma = self.style_sigma(w).unsqueeze(2).unsqueeze(3)
        style_mu = self.style_mu(w).unsqueeze(2).unsqueeze(3)
        return style_sigma * x + style_mu
```
The last additional module we need is a way to add noise periodically into the network. Since we also need a learned parameter to scale the noise with, we’ll implement this as a separate module instead of in the `forward` method of the `Generator` itself.
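```python
class NoiseInput(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Per-channel noise strength, initialized to zero
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        # One single-channel noise image per sample, broadcast across all channels
        noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
        return x + self.weight * noise
```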
The final piece is to put together the Generator and Discriminator models. To build the Generator, let’s first create a `GeneratorBlock` module to serve as a building block for the full network. The `GeneratorBlock` will add the noise inputs as well as the adaptive instance normalization with respect to the style input $\mathbf{w}$.
```python
class GeneratorBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GeneratorBlock, self).__init__()
        # WSConv2d is the equalized-LR convolution from the ProGAN article
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.noise1 = NoiseInput(out_channels)
        self.noise2 = NoiseInput(out_channels)
        self.adain1 = AdaptiveInstanceNormalization(out_channels, w_dim)
        self.adain2 = AdaptiveInstanceNormalization(out_channels, w_dim)

    def forward(self, x, w):
        # conv -> noise -> LeakyReLU -> AdaIN, twice
        x = self.adain1(F.leaky_relu(self.noise1(self.conv1(x)), 0.2), w)
        x = self.adain2(F.leaky_relu(self.noise2(self.conv2(x)), 0.2), w)
        return x
```
With the `GeneratorBlock` out of the way, we can now build the full network.
```python
class Generator(nn.Module):
    def __init__(self, args, dim_mults=(1, 1, 1, 2, 4, 4)):
        super(Generator, self).__init__()
        self.args = args
        self.map = NoiseMappingNetwork(args)
        # Learned constant input at the 8x8 starting resolution, initialized to ones
        self.starting_constant = nn.Parameter(torch.ones((1, args.latent, 8, 8)))
        self.init_block = InitGeneratorBlock(args.latent, args.w_latent)
        # Channel count at each resolution: latent / mult
        hidden_dims = [int(args.latent / mult) for mult in dim_mults]
        # One block per resolution doubling
        self.progressive_blocks = nn.ModuleList([
            GeneratorBlock(in_f, out_f, args.w_latent)
            for in_f, out_f in zip(hidden_dims[:-1], hidden_dims[1:])
        ])
        # 1x1 "to RGB" convolutions, one per resolution
        self.out_blocks = nn.ModuleList([
            WSConv2d(out_f, args.channel_size, 1, 1, 0) for out_f in hidden_dims
        ])
```
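To see what `dim_mults` produces, here is a quick pure-Python sketch that reproduces the `hidden_dims` computation. It assumes `args.latent = 512` (the value used in the experiments) and the 8×8 starting resolution. Each progressive block doubles the resolution and maps between consecutive entries of `hidden_dims`, and each resolution gets its own `out_blocks` entry.

```python
latent = 512
dim_mults = (1, 1, 1, 2, 4, 4)

hidden_dims = [int(latent / mult) for mult in dim_mults]
resolutions = [8 * 2 ** i for i in range(len(hidden_dims))]
block_shapes = list(zip(hidden_dims[:-1], hidden_dims[1:]))

for res, ch in zip(resolutions, hidden_dims):
    print(f"{res}x{res}: {ch} channels")
print(block_shapes)  # (in, out) channels of each progressive block
```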
One important thing to note here is the additional `starting_constant` parameter. Our new generator doesn’t take the latent $\mathbf{z}$ directly as input. It doesn’t use $\mathbf{w}$ as a feedforward input either, only in the adaptive normalization layers. So we need a placeholder input to actually pass through the network. This parameter is a learnable $8 \times 8$ tensor with `args.latent` channels, initialized to ones, which the network modifies in accordance with our inputs.
We also need to pass this constant through some initialization layers before it goes through the main network. Since the structure of this initialization is a little different from the `GeneratorBlock`, we’ll put it in a separate `InitGeneratorBlock`.
```python
class InitGeneratorBlock(nn.Module):
    def __init__(self, in_channels, w_dim):
        super(InitGeneratorBlock, self).__init__()
        self.conv = WSConv2d(in_channels, in_channels)
        self.noise1 = NoiseInput(in_channels)
        self.noise2 = NoiseInput(in_channels)
        self.adain1 = AdaptiveInstanceNormalization(in_channels, w_dim)
        self.adain2 = AdaptiveInstanceNormalization(in_channels, w_dim)

    def forward(self, x, w):
        # The constant input gets noise + AdaIN before the first convolution
        x = self.adain1(F.leaky_relu(self.noise1(x), 0.2), w)
        x = self.adain2(F.leaky_relu(self.noise2(self.conv(x)), 0.2), w)
        return x
```
With that out of the way, we can write the `forward` method for the `Generator`. This is very similar to the Generator for the ProgressiveGAN, so I’ll skip the main details for the sake of brevity.
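```python
    # Methods of the Generator class
    def fade(self, upsampled, generated, alpha):
        # ProGAN fade-in: blend the upsampled previous-resolution image with the
        # new block's output (assumed to match the ProGAN article's version)
        return alpha * generated + (1 - alpha) * upsampled

    def forward(self, z, p, alpha):
        w = self.map(F.normalize(z, dim=1))
        # Expand the constant to the batch size so each sample gets its own noise
        x = self.starting_constant.expand(z.shape[0], -1, -1, -1)
        out = self.init_block(x, w)
        if p == 0:
            return torch.tanh(self.out_blocks[0](out))
        for i in range(p):
            upsampled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.progressive_blocks[i](upsampled, w)
        final_upsampled = self.out_blocks[p - 1](upsampled)
        final_out = self.out_blocks[p](out)
        return torch.tanh(self.fade(final_upsampled, final_out, alpha))
```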
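The fade-in keeps each resolution increase from being too abrupt. Below is a NumPy sketch with a hypothetical `fade` function mirroring the `self.fade` call in `forward`. It uses nearest-neighbour upsampling via `np.repeat` as a simpler stand-in for the bilinear `F.interpolate` call. With `alpha = 0`, it reproduces the upsampled lower-resolution image; with `alpha = 1`, it hands over entirely to the new block. Because the 1×1 "to RGB" convolution and the upsampling are both linear, upsampling the RGB image here is equivalent to upsampling the features first.

```python
import numpy as np

def upsample2x(img):
    # Nearest-neighbour 2x upsampling of an (N, C, H, W) array
    return img.repeat(2, axis=2).repeat(2, axis=3)

def fade(upsampled, generated, alpha):
    # alpha ramps from 0 to 1 while a new resolution is being faded in
    return alpha * generated + (1 - alpha) * upsampled

rng = np.random.default_rng(0)
low_res = rng.uniform(-1, 1, (1, 3, 8, 8))    # previous resolution's RGB output
new_res = rng.uniform(-1, 1, (1, 3, 16, 16))  # new block's RGB output
old = upsample2x(low_res)

print(np.allclose(fade(old, new_res, 0.0), old))      # True
print(np.allclose(fade(old, new_res, 1.0), new_res))  # True
mid = fade(old, new_res, 0.5)
print(np.allclose(mid, (old + new_res) / 2))          # True
```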
That’s the Generator! The Discriminator is actually just going to be the same one we used for the ProgressiveGAN, so with that we’re free to move on to actually training our network!
Now that we have every module written, we can put together a quick training loop and dataset to generate some cool new images!
As always, all the experiments below use images from the Japanese Woodblock Print database to better emulate the challenge of handling real datasets in the wild.
I’ll also skip the steps of defining the dataset and dataloader since they’re usually pretty simple. If you’d like to see that code, you can check out the full repo here.
The training loop in this case is also pretty much the same as the ProGAN’s, so I’ll skip most of that as well. The only major change we’ll make is to start from a resolution of 8×8 instead of 4×4. The authors found that this can sometimes work better for more detailed datasets. To do this, we’ll just change the starting resolution used by the networks and by the training loop, and we should be good.
One minor change we’ll need to make is to lower the learning rate for the mapping network. The authors found training to be more stable when the learning rate for the mapping network was two orders of magnitude lower than the learning rate for the main network, i.e. $\lambda' = 0.01 \cdot \lambda$. To do so, we’ll have to manually construct the parameter groups for the generator when we create the optimizer.
```python
import torch.optim as optim

# Inside the trainer, where self.args holds the hyperparameters
map_params = list(g_net.map.parameters())
base_params = list(g_net.init_block.parameters())
base_params += list(g_net.progressive_blocks.parameters())
base_params += list(g_net.out_blocks.parameters())
base_params += [g_net.starting_constant]

g_optimizer = optim.Adam([
    {'params': base_params},
    # Mapping network: learning rate two orders of magnitude lower
    {'params': map_params, 'lr': self.args.lr * 0.01},
], lr=self.args.lr, betas=(0.5, 0.999))
```
As per usual, I also included some code to track the outputs of the generator throughout training on the same fixed latent. This time, we’ll have different resolutions of outputs as well.
With that, we’re ready to train! In my experiments, I ran the training for 20 epochs at each resolution, scaling up from a resolution of 8×8 to 256×256. The paper suggests using a learning rate of 1e-3 and a batch size of 16 to start, and I use a latent size and intermediate latent size of 512.
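For reference, these settings can be written down as a small pure-Python sketch. The `args` fields `lr`, `latent`, `w_latent`, and `channel_size` are the names the code above reads. `batch_size` is a hypothetical name, and `channel_size=3` assumes RGB images. The `(resolution, p, epochs)` layout is also hypothetical rather than the repo’s actual config. Here `p` is the step index passed to `Generator.forward`, counted from the 8×8 starting resolution.

```python
import math
from types import SimpleNamespace

# Hyperparameters from the text; batch_size and channel_size=3 are assumptions
args = SimpleNamespace(lr=1e-3, batch_size=16, latent=512, w_latent=512, channel_size=3)

start_res, final_res, epochs_per_res = 8, 256, 20

schedule = []
res = start_res
while res <= final_res:
    p = int(math.log2(res // start_res))  # 0 at 8x8, 5 at 256x256
    schedule.append((res, p, epochs_per_res))
    res *= 2

for res, p, epochs in schedule:
    print(f"{res}x{res}: p={p}, {epochs} epochs")
```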
Let’s see how it did! To start, I’ve created an animation showing how the StyleGAN increases resolution and improves over the fixed latent. You can tell when the resolution increases because the image quality briefly gets worse.
Overall, it appears to be doing pretty well at the start! We can see how it begins with lower-level features and then slowly builds on them. However, we can also see that the output quality drops drastically as we get to higher resolutions.
I’ve put some examples of images generated by our StyleGAN below.
So . . . the images aren’t much better than what we got with the ProGAN. Arguably, they're worse 🙃.
I suspect this issue is related to the dataset I’ve been using, so make sure you train on a more curated or consistent dataset when you’re running your own experiments. In the images we got, we can see similarities between some of the dataset images and the generated art above, but the StyleGAN seems to have trouble with the sheer complexity of the input data. The most likely reason for this with this particular dataset is image quality: the larger a network gets, the more important it generally becomes to have access to high-quality data.
However, it’s also quite likely that the outputs of the model would have improved with more training. Like the ProGAN, the StyleGAN takes an incredibly long time to train. I only trained for 100 epochs in total on an NVIDIA A10 GPU, while the original authors trained for 7 days in total with a distributed training setup over 8 Tesla V100 GPUs. If you’re trying this code out on your own, I’d highly recommend using a more curated dataset whose samples are more similar to each other, or simply allowing more time for the model to converge.
Since the StyleGAN builds on the ProGAN, there are some implementation details that we didn’t have to deal with directly, which made our job a bit easier. However, the StyleGAN generator itself is also quite complex to implement properly, so I’d recommend modularizing your code as much as possible. Furthermore, as we’ve seen, it’s possible that for your dataset the StyleGAN is simply not good enough at image generation. In this case, I’d recommend choosing datasets that make the task easier for your GANs, with consistent samples from similar angles / perspectives that are of the same general subject.
If you’ve made it this far, thanks for reading, as always!
If you’d like to check out the full code for this series, you can visit my public repository.