Understanding Wasserstein GANs: A Mathematical Perspective

A paper review on Wasserstein GAN

Introduction

Generative Adversarial Networks (GANs) have revolutionized the field of generative modeling by enabling the creation of highly realistic images, videos, and other media. Despite their success, standard GANs frequently face significant training challenges, such as instability, mode collapse, and difficulty converging. These issues often stem from the divergence used to measure the difference between the real and generated data distributions.

In this blog post, we explore how Wasserstein GANs (WGANs) address these challenges by leveraging the Wasserstein distance, leading to more stable training and superior generative performance. We delve into the mathematical foundations of WGANs, discuss their improvements over standard GANs, and examine their practical implications.

Overcoming Challenges in Standard GANs

Standard GANs consist of two neural networks, the generator (G) and the discriminator (D), engaged in a minimax game. The generator aims to create samples indistinguishable from real data, while the discriminator strives to differentiate between real and generated samples. Both networks share a single value function, which the discriminator maximizes and the generator minimizes:

\[\label{eq:gan_loss} \begin{aligned} \min_G \max_D & \, \mathbb{E}_{x \sim P_r} [\log D(x)] + \mathbb{E}_{\hat{x} \sim P_g} [\log (1 - D(\hat{x}) )], \end{aligned}\]

where \(P_r\) denotes the real data distribution, and \(P_g\) is the distribution induced by the generator over the data space, that is, the distribution of \(\hat{x} = G(z)\) for noise \(z\) drawn from a fixed prior.
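To make the objective concrete, here is a minimal sketch that estimates the value function from a batch of discriminator outputs. The function name `gan_value` and the clamping constant `eps` are our own illustrative choices, not part of any library.

```python
import math

def gan_value(d_real, d_fake, eps=1e-12):
    """Monte Carlo estimate of E[log D(x)] + E[log(1 - D(x_hat))].

    d_real: discriminator outputs in (0, 1) on real samples.
    d_fake: discriminator outputs in (0, 1) on generated samples.
    eps is a hypothetical guard that keeps log() away from zero.
    """
    real_term = sum(math.log(max(d, eps)) for d in d_real) / len(d_real)
    fake_term = sum(math.log(max(1.0 - d, eps)) for d in d_fake) / len(d_fake)
    return real_term + fake_term

# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
undecided = gan_value([0.5, 0.5], [0.5, 0.5])
# A discriminator that separates the two batches almost perfectly.
confident = gan_value([0.99, 0.98], [0.01, 0.02])
print(undecided, confident)
```

The undecided discriminator yields \(-\log 4\), while the confident one pushes the value toward its upper bound of 0, which is exactly the regime in which the generator's gradient signal starts to fade.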

Common Issues in Standard GAN Training

  1. Mode Collapse: The generator produces samples of limited diversity, capturing only a few modes of the real data distribution.
  2. Training Instability: The adversarial nature of GAN training often leads to oscillations and a lack of convergence.
  3. Vanishing Gradients: When the discriminator becomes too effective, the generator receives negligible gradient updates, hindering its improvement.

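These problems largely arise from the Jensen-Shannon (JS) divergence that standard GANs implicitly minimize: for an optimal discriminator, the objective above equals \(2\,\mathrm{JS}(P_r \| P_g) - \log 4\). When the supports of \(P_r\) and \(P_g\) do not overlap, the JS divergence sits at its maximum value \(\log 2\) no matter how far apart the two distributions are, so it gives the generator no useful gradient.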
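A small numerical check illustrates this saturation. The sketch below, assuming discrete distributions stored as plain lists (the helper names `kl`, `js`, and `point_mass` are ours), computes the JS divergence between two point masses placed at different positions.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions, with the convention 0 * log(0 / q) = 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js(p, q):
    """Jensen-Shannon divergence: average KL to the mixture m = (p + q) / 2."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def point_mass(index, size=10):
    """Distribution that puts all probability on a single bin."""
    return [1.0 if i == index else 0.0 for i in range(size)]

real = point_mass(0)
for shift in (1, 5, 9):
    # Supports are disjoint for every shift, so each value is log 2 (about 0.693).
    print(shift, js(real, point_mass(shift)))
```

Moving the generated point mass closer to the real one does not change the divergence at all, so a generator trained on this signal has no way to tell whether it is improving.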

The Wasserstein Distance

Wasserstein GANs propose an elegant solution by replacing the JS divergence with the Wasserstein distance (also known as the Earth Mover’s distance), which measures the cost of transforming one distribution into another. This metric remains meaningful even when the distributions have non-overlapping supports.

Mathematical Formulation

The Wasserstein distance between two probability distributions \(P_r\) and \(P_g\) is defined as:

\[\label{eq:wasserstein} W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}[\|x - y\|],\]

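where \(\Pi(P_r, P_g)\) is the set of all joint distributions \(\gamma(x, y)\) whose marginals are \(P_r\) and \(P_g\), respectively. Each \(\gamma\) can be read as a transport plan: \(\gamma(x, y)\) says how much mass is moved from point \(x\) to point \(y\), and the cost of that move is the mass times the distance \(\|x - y\|\). Intuitively, the Wasserstein distance is the minimum amount of “work” needed to transform one distribution into the other.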
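In one dimension the optimal transport plan is easy to write down: for two samples of equal size with uniform weights, it matches the points in sorted order. The following sketch relies on that special case (the function name `wasserstein_1d` is ours) and is not a general-purpose solver.

```python
def wasserstein_1d(xs, ys):
    """Wasserstein-1 distance between two equal-size 1D empirical distributions.

    Assumes uniform weights; in 1D the optimal plan pairs the i-th smallest
    point of xs with the i-th smallest point of ys.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch only handles samples of equal size")
    xs, ys = sorted(xs), sorted(ys)
    return sum(abs(x - y) for x, y in zip(xs, ys)) / len(xs)

real = [0.0, 1.0, 2.0]
for theta in (1.0, 5.0, 9.0):
    fake = [x + theta for x in real]  # same shape, shifted by theta
    print(theta, wasserstein_1d(real, fake))
```

Unlike the JS divergence, the distance grows in step with the shift \(\theta\), so it keeps telling the generator how far it still has to move.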

Kantorovich-Rubinstein Duality

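In practice, directly computing the Wasserstein distance is intractable for high-dimensional data, because the infimum ranges over all joint distributions with the given marginals. Instead, the problem can be reformulated using Kantorovich-Rubinstein duality: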

\[\label{eq:kantorovich} W(P_r, P_g) = \sup_{\|f\|_L \leq 1} \left( \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)] \right),\]

where the supremum is over all 1-Lipschitz functions \(f\), that is, functions satisfying \(|f(x) - f(y)| \leq \|x - y\|\) for all \(x\) and \(y\).
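The dual form can be checked by hand on a toy case. In the sketch below (the helper `dual_objective` and the sample values are ours), the generated sample is the real sample shifted by 5, so the Wasserstein distance is 5. Every 1-Lipschitz \(f\) gives a lower bound on that distance, and \(f(x) = -x\) attains it.

```python
def dual_objective(f, real, fake):
    """Sample estimate of E_{x ~ P_r}[f(x)] - E_{x ~ P_g}[f(x)]."""
    mean_real = sum(f(x) for x in real) / len(real)
    mean_fake = sum(f(x) for x in fake) / len(fake)
    return mean_real - mean_fake

real = [0.0, 1.0, 2.0]
fake = [5.0, 6.0, 7.0]  # real shifted by 5, so W(P_r, P_g) = 5

# Three 1-Lipschitz candidates; only the first reaches the supremum.
best = dual_objective(lambda x: -x, real, fake)
weak = dual_objective(lambda x: -0.5 * x, real, fake)
wrong_way = dual_objective(lambda x: x, real, fake)
print(best, weak, wrong_way)
```

Training the discriminator in a WGAN amounts to searching for the best such \(f\) within a family of neural networks instead of picking it by hand.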

In the context of WGANs, the discriminator \(D\) (called the critic in the WGAN paper, since it outputs an unbounded score rather than a probability) is used to approximate the optimal 1-Lipschitz function \(f\). The discriminator maximizes the dual objective, while the generator, which only affects the second term, minimizes the resulting distance estimate:

\[\label{eq:wgan_loss} \begin{aligned} \max_D & \, \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{\hat{x} \sim P_g}[D(\hat{x})], \\ \min_G & \, -\mathbb{E}_{\hat{x} \sim P_g}[D(\hat{x})]. \end{aligned}\]
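Written as quantities to minimize, which is how most optimizers expect them, the two objectives look roughly as follows. This is a framework-free sketch operating on lists of scores; the function names are ours. It follows the sign convention of the original WGAN algorithm, in which the generator is updated to raise the discriminator's score on generated samples, that is, to minimize \(-\mathbb{E}[D(\hat{x})]\).

```python
def mean(values):
    return sum(values) / len(values)

def critic_loss(scores_real, scores_fake):
    """Negated dual objective: minimizing it maximizes E[D(x)] - E[D(x_hat)]."""
    return mean(scores_fake) - mean(scores_real)

def generator_loss(scores_fake):
    """The generator wants high scores on its samples, so it minimizes -E[D(x_hat)]."""
    return -mean(scores_fake)

# Hypothetical scores from a discriminator on one batch.
scores_real = [2.0, 3.0, 1.0]
scores_fake = [-1.0, 0.0, -2.0]

print(critic_loss(scores_real, scores_fake))   # negative of the distance estimate
print(generator_loss(scores_fake))
```

Note that the negated critic loss is the current estimate of the Wasserstein distance, which is why it is often plotted as a training-progress signal.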

Enforcing the 1-Lipschitz Constraint

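A critical requirement for the discriminator in WGANs is the 1-Lipschitz constraint. In the original WGAN formulation, this constraint was enforced by clipping every weight of the discriminator to a narrow range \([-c, c]\) after each update. Strictly speaking, clipping only guarantees that the discriminator is \(K\)-Lipschitz for some constant \(K\), which rescales the estimated distance by that constant. Clipping also causes practical problems: weights tend to pile up at the clipping boundaries, so the discriminator underuses its capacity and underfits, and a poorly chosen \(c\) can make gradients vanish or explode.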
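The clipping step itself is a one-liner. The sketch below treats the discriminator's parameters as a flat list of floats (a simplification of ours) and uses \(c = 0.01\), the default reported in the original WGAN paper.

```python
def clip_weights(weights, c=0.01):
    """Clamp every weight into [-c, c]; meant to run after each discriminator update."""
    return [max(-c, min(c, w)) for w in weights]

weights = [0.5, -0.3, 0.004]
print(clip_weights(weights))  # large weights are pinned to the boundary
```

Two of the three weights end up exactly at \(\pm c\), which hints at why a clipped discriminator tends to learn overly simple functions.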

An improvement to this approach is the gradient penalty method (WGAN-GP). Here, the 1-Lipschitz constraint is enforced softly by adding a penalty to the discriminator's loss that encourages the norm of the gradient of the discriminator with respect to its input to stay close to 1:

\[\label{eq:gradient_penalty} \mathcal{L}_{\text{GP}} = \lambda \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\right],\]

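where \(\lambda\) is the penalty coefficient (the WGAN-GP paper uses \(\lambda = 10\)) and \(P_{\hat{x}}\) is the distribution of points along straight lines between pairs of points sampled from \(P_r\) and \(P_g\). Note that in this equation \(\hat{x}\) denotes an interpolated point rather than a generated sample as in the earlier equations: \(\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}\), with \(x \sim P_r\), \(\tilde{x} \sim P_g\), and \(\epsilon\) drawn uniformly from \([0, 1]\).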
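The penalty can be sketched without any deep learning framework. In the code below, the gradient with respect to the input is estimated by central finite differences, which stands in for the automatic differentiation a real implementation would use; the helper names and the linear toy critic are hypothetical, chosen so that the gradient norm is known in advance.

```python
import math
import random

def numerical_gradient(critic, x, h=1e-5):
    """Central-difference estimate of the gradient of critic at the point x."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((critic(up) - critic(down)) / (2 * h))
    return grad

def gradient_penalty(critic, real_batch, fake_batch, lam=10.0, rng=random):
    """lam * mean over the batch of (||grad D(x_hat)||_2 - 1)^2."""
    total = 0.0
    for real, fake in zip(real_batch, fake_batch):
        eps = rng.random()  # interpolation coefficient, uniform in [0, 1)
        x_hat = [eps * r + (1 - eps) * f for r, f in zip(real, fake)]
        grad = numerical_gradient(critic, x_hat)
        norm = math.sqrt(sum(g * g for g in grad))
        total += (norm - 1.0) ** 2
    return lam * total / len(real_batch)

def linear_critic(weights):
    """Toy critic D(x) = w . x, whose input gradient is w everywhere."""
    return lambda x: sum(w * xi for w, xi in zip(weights, x))

real_batch = [[0.0, 0.0], [1.0, 1.0]]
fake_batch = [[3.0, 4.0], [2.0, 5.0]]
unit = gradient_penalty(linear_critic([0.6, 0.8]), real_batch, fake_batch)
steep = gradient_penalty(linear_critic([3.0, 4.0]), real_batch, fake_batch)
print(unit, steep)
```

The critic with gradient norm 1 incurs essentially no penalty, while the one with gradient norm 5 is penalized by about \(10 \cdot (5 - 1)^2 = 160\). In training, this term is added to the discriminator's loss.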

Conclusion

Wasserstein GANs mark a significant advancement in the field of generative modeling, especially for image synthesis. By replacing the JS divergence with the Wasserstein distance, WGANs address a fundamental shortcoming of standard GANs: they achieve more stable training dynamics and generate higher-quality samples.

The theoretical robustness rooted in optimal transport theory and practical improvements like gradient penalty make WGANs a critical tool in the arsenal of machine learning practitioners. As research in this area progresses, we can expect further refinements and hybrid models that continue to push the boundaries of what generative models can achieve.

In subsequent posts, we will explore other transformative advancements in AI and machine learning, providing a comprehensive view of the state-of-the-art technologies shaping our world.