Understanding Generative Adversarial Networks (GANs)
Updated: Dec 13, 2021
Introduced by Goodfellow et al. at NeurIPS in 2014, generative adversarial networks (GANs) are an exciting recent innovation in machine learning. GANs are generative models: they create new data instances that resemble your training data. For example, GANs can create images that look like photographs of human faces, even though the faces don’t belong to any real person.
The following images of four photorealistic faces were generated by a GAN developed by NVIDIA:
GANs achieve this level of realism by pairing a generator, which learns to produce the target output, with a discriminator, which learns to distinguish true data from the output of the generator. The generator tries to fool the discriminator, and the discriminator tries to keep from being fooled.
A generative model could generate new photos of animals that look like real animals, while a discriminative model could tell a dog from a cat. GANs are just one kind of generative model.
More formally, given a set of data instances X and a set of labels Y:
Generative models capture the joint probability P(X, Y), or just P(X) if there are no labels.
Discriminative models capture the conditional probability P(Y | X).
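To make the two quantities concrete, here is a minimal sketch that estimates both from counts. The toy dataset and the function names are illustrative assumptions, not part of the original text.

```python
from collections import Counter

# Hypothetical toy dataset of (x, y) pairs: x is a feature value, y is a label.
data = [
    ("furry", "cat"), ("furry", "cat"), ("furry", "dog"),
    ("barks", "dog"), ("barks", "dog"), ("furry", "dog"),
]

def joint_probability(pairs):
    """Estimate P(X, Y) as the relative frequency of each (x, y) pair."""
    counts = Counter(pairs)
    total = len(pairs)
    return {pair: n / total for pair, n in counts.items()}

def conditional_probability(pairs):
    """Estimate P(Y | X) by normalising the pair counts within each x."""
    pair_counts = Counter(pairs)
    x_counts = Counter(x for x, _ in pairs)
    return {(x, y): n / x_counts[x] for (x, y), n in pair_counts.items()}

joint = joint_probability(data)              # what a generative model captures
conditional = conditional_probability(data)  # what a discriminative model captures
```

The joint table sums to 1 over all (x, y) pairs and therefore says how common each instance is. The conditional table sums to 1 separately for each x and says nothing about how common that x is.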
A generative model includes the distribution of the data itself, and tells you how likely a given example is. For example, models that predict the next word in a sequence are typically generative models (usually much simpler than GANs) because they can assign a probability to a sequence of words.
A discriminative model ignores the question of whether a given instance is likely, and just tells you how likely a label is to apply to the instance.
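The next-word example can be sketched with a bigram model, one of the simplest models that assigns a probability to a sequence of words. The corpus and helper names below are hypothetical, and the estimate uses raw counts with no smoothing.

```python
from collections import Counter

# Hypothetical toy corpus; a real language model would be trained on far more text.
corpus = "the cat sat . the dog sat . the cat ran .".split()

# Count each word that has a successor, and each adjacent word pair.
unigram_counts = Counter(corpus[:-1])
bigram_counts = Counter(zip(corpus, corpus[1:]))

def next_word_probability(prev, word):
    """P(word | prev), estimated from bigram counts (no smoothing)."""
    if unigram_counts[prev] == 0:
        return 0.0
    return bigram_counts[(prev, word)] / unigram_counts[prev]

def sequence_probability(words):
    """Probability of the remaining words, conditioned on the first word."""
    p = 1.0
    for prev, word in zip(words, words[1:]):
        p *= next_word_probability(prev, word)
    return p

likely = sequence_probability(["the", "cat", "sat"])
unlikely = sequence_probability(["the", "sat", "cat"])
```

A word order that appears in the corpus receives a higher probability than one that never does, which is the sense in which such a model tells you how likely a given example is.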
Neither kind of model has to return a number representing a probability. You can model the distribution of data by imitating that distribution.
For example, a discriminative classifier like a decision tree can label an instance without assigning a probability to that label. Such a classifier would still be a model because the distribution of all predicted labels would model the real distribution of labels in the data.
Similarly, a generative model can model a distribution by producing convincing “fake” data that looks like it’s drawn from that distribution.
Generative models tackle a more difficult task than analogous discriminative models. Generative models have to model more.
A generative model for images might capture correlations like “things that look like boats are probably going to appear near things that look like water” and “eyes are unlikely to appear on foreheads.” These are very complicated distributions.
In contrast, a discriminative model might learn the difference between “sailboat” and “not sailboat” by just looking for a few tell-tale patterns. It could ignore many of the correlations that the generative model must get right.
Discriminative models try to draw boundaries in the data space, while generative models try to model how data is placed throughout the space. For example, the following diagram shows discriminative and generative models of handwritten digits:
The discriminative model tries to tell the difference between handwritten 0’s and 1’s by drawing a line in the data space. If it gets the line right, it can distinguish 0’s from 1’s without ever having to model exactly where the instances are placed in the data space on either side of the line.
In contrast, the generative model tries to produce convincing 1’s and 0’s by generating digits that fall close to their real counterparts in the data space. It has to model the distribution throughout the data space.
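The contrast can be illustrated in one dimension. In this sketch each digit is reduced to a single made-up feature value, the discriminative side is a threshold chosen to minimise training errors, and the generative side is a Gaussian fitted to one class. All of these choices are simplifying assumptions, far simpler than a real digit model.

```python
import random
import statistics

rng = random.Random(0)

# Hypothetical 1-D "data space": one feature value per handwritten digit.
zeros = [rng.gauss(-2.0, 0.5) for _ in range(200)]
ones = [rng.gauss(2.0, 0.5) for _ in range(200)]

def training_errors(threshold):
    """Count mistakes made by the rule 'predict 1 if x > threshold'."""
    return sum(x > threshold for x in zeros) + sum(x <= threshold for x in ones)

# Discriminative view: find a boundary, without modelling either class.
boundary = min(sorted(zeros + ones), key=training_errors)

def classify(x):
    return 1 if x > boundary else 0

# Generative view: model where the 1's sit in the space, then sample new ones.
one_mean = statistics.mean(ones)
one_std = statistics.stdev(ones)

def generate_one():
    return rng.gauss(one_mean, one_std)

new_ones = [generate_one() for _ in range(5)]
```

The classifier stores only `boundary`, while the generator has to store enough about the class (here a mean and a standard deviation) to place new samples where real ones fall.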
GANs offer an effective way to train such rich models to resemble a real distribution.
A generative adversarial network (GAN) has two parts:
The generator learns to generate plausible data. The generated instances become negative training examples for the discriminator.
The discriminator learns to distinguish the generator’s fake data from real data. The discriminator penalizes the generator for producing implausible results.
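The two roles can be expressed as loss functions. The text above does not name a loss, so this sketch assumes binary cross-entropy for the discriminator and the commonly used non-saturating loss for the generator. The function names and the `eps` guard are illustrative.

```python
import math

EPS = 1e-12  # guards against log(0); an implementation detail of this sketch

def discriminator_loss(real_scores, fake_scores):
    """Binary cross-entropy where real data is the positive class (label 1)
    and generated data is the negative class (label 0).
    Each score is the discriminator's estimated probability that the input is real."""
    real_term = -sum(math.log(s + EPS) for s in real_scores) / len(real_scores)
    fake_term = -sum(math.log(1.0 - s + EPS) for s in fake_scores) / len(fake_scores)
    return real_term + fake_term

def generator_loss(fake_scores):
    """Penalty on the generator: large when the discriminator rejects its output."""
    return -sum(math.log(s + EPS) for s in fake_scores) / len(fake_scores)

# A confident, correct discriminator has low loss and penalises the generator heavily.
d_loss_confident = discriminator_loss([0.9, 0.95], [0.1, 0.05])
g_loss_rejected = generator_loss([0.1, 0.05])
g_loss_accepted = generator_loss([0.9, 0.95])
```

With these definitions the generator's loss drops as the discriminator assigns higher "real" scores to generated samples, which is the penalty described above.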
When training begins, the generator produces obviously fake data, and the discriminator quickly learns to tell that it’s fake.
As training progresses, the generator gets closer to producing output that can fool the discriminator.
Finally, if generator training goes well, the discriminator gets worse at telling the difference between real and fake. It starts to classify fake data as real, and its accuracy decreases.
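The drop in accuracy can be quantified in a simplified setting. Assume, purely for illustration, that real and generated data are 1-D Gaussians with the same standard deviation and that the discriminator sees equal numbers of each. The best possible discriminator is then a threshold halfway between the two means, and its accuracy has a closed form.

```python
import math

def best_discriminator_accuracy(real_mean, fake_mean, std=1.0):
    """Accuracy of the best threshold classifier between two equal-variance
    Gaussians, with equal numbers of real and fake samples."""
    gap = abs(real_mean - fake_mean) / (2.0 * std)
    # Standard normal CDF evaluated at `gap`.
    return 0.5 * (1.0 + math.erf(gap / math.sqrt(2.0)))

# Hypothetical generator means moving toward a real mean of 4.0.
accuracies = [best_discriminator_accuracy(4.0, m) for m in (0.0, 2.0, 3.5, 4.0)]
```

As the generated mean approaches the real mean, the values in `accuracies` fall toward 0.5, which is no better than guessing.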
Both the generator and the discriminator are neural networks. The generator output is connected directly to the discriminator input. Through backpropagation, the discriminator’s classification provides a signal that the generator uses to update its weights.
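The gradient path from the discriminator's output back to the generator's weights can be written out by hand for a very small case. This sketch assumes a one-parameter-pair linear generator, a logistic discriminator, and the non-saturating generator loss. None of these choices come from the text above, and real GANs use deep networks with automatic differentiation.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def generator(z, a, b):
    """Hypothetical generator: maps noise z to a fake sample."""
    return a * z + b

def discriminator(x, w, c):
    """Hypothetical discriminator: probability that x is real."""
    return sigmoid(w * x + c)

def sample_generator_loss(z, a, b, w, c):
    return -math.log(discriminator(generator(z, a, b), w, c))

def generator_gradients(z, a, b, w, c):
    """Backpropagate the loss through the discriminator into the generator."""
    score = discriminator(generator(z, a, b), w, c)
    dloss_dlogit = score - 1.0       # derivative of -log(sigmoid(t)) w.r.t. t
    dloss_dfake = dloss_dlogit * w   # pass back through the discriminator
    return dloss_dfake * z, dloss_dfake  # (dloss/da, dloss/db)

# One generator update with the discriminator's weights (w, c) held fixed.
z, a, b, w, c, lr = 0.5, 1.0, 0.0, 2.0, -1.0, 0.1
loss_before = sample_generator_loss(z, a, b, w, c)
grad_a, grad_b = generator_gradients(z, a, b, w, c)
a, b = a - lr * grad_a, b - lr * grad_b
loss_after = sample_generator_loss(z, a, b, w, c)
```

Only the generator's parameters `a` and `b` change in this step. The discriminator's weights appear in the gradient because the signal passes through them, but they are not updated here.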