# Example of a GAN in PyTorch

# Pseudocode for GAN (PyTorch).

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define Generator Model
class Generator(nn.Module):
    """Map a batch of noise vectors to flattened, MNIST-sized samples.

    A small fully connected network: three LeakyReLU hidden layers of widths
    ``hidden_dim``, ``2 * hidden_dim`` and ``4 * hidden_dim``, followed by a
    linear projection to ``data_dim`` and a sigmoid.

    Args:
        noise_dim: Length of each input noise vector. Defaults to 100, the
            size the training loop samples with ``torch.randn(batch, 100)``.
        hidden_dim: Width of the first hidden layer.
        data_dim: Length of each generated sample. Defaults to ``28 * 28``,
            the flattened MNIST image size the training loop uses.
    """

    def __init__(self, noise_dim=100, hidden_dim=256, data_dim=28 * 28):
        super().__init__()
        self.noise_dim = noise_dim
        self.data_dim = data_dim
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim * 4, data_dim),
            # Sigmoid keeps outputs in [0, 1], the same range that
            # transforms.ToTensor() produces for the real MNIST images, so the
            # discriminator cannot separate real from fake by value range alone.
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Generate samples from noise.

        Args:
            noise: Tensor of shape ``(batch, noise_dim)``.

        Returns:
            Tensor of shape ``(batch, data_dim)`` with values in [0, 1].
        """
        generated_data = self.net(noise)
        return generated_data

# Define Discriminator Model
class Discriminator(nn.Module):
    """Score how likely each sample in a batch is to be real.

    A small fully connected network: two LeakyReLU hidden layers of widths
    ``hidden_dim`` and ``hidden_dim // 2``, then a single sigmoid output unit.

    Args:
        data_dim: Number of values per (flattened) sample. Defaults to
            ``28 * 28``, the flattened MNIST image size.
        hidden_dim: Width of the first hidden layer.
    """

    def __init__(self, data_dim=28 * 28, hidden_dim=512):
        super().__init__()
        self.data_dim = data_dim
        self.net = nn.Sequential(
            # Flatten everything after the batch axis: a no-op for already
            # flattened (batch, data_dim) input, and it also accepts
            # image-shaped batches such as (batch, 1, 28, 28).
            nn.Flatten(start_dim=1),
            nn.Linear(data_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 2, 1),
            # nn.BCELoss expects probabilities, so squash the logit to [0, 1].
            nn.Sigmoid(),
        )

    def forward(self, data):
        """Return the probability that each sample is real.

        Args:
            data: Tensor whose trailing dimensions flatten to ``data_dim``
                values per sample, e.g. ``(batch, data_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        validity_score = self.net(data)
        return validity_score

# Initialize models
generator = Generator()
discriminator = Discriminator()

# Binary cross-entropy on the discriminator's probability outputs
# (nn.BCELoss takes probabilities, not logits).
loss_function = nn.BCELoss()

# Define optimizers.
# lr=2e-4 with beta1=0.5 follows the DCGAN recommendations (Radford et al.,
# 2015), which report that Adam's default beta1=0.9 led to training
# oscillation and instability in adversarial training.
lr = 0.0002
betas = (0.5, 0.999)
gen_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

# MNIST training split. ToTensor() converts images to float tensors in [0, 1];
# the training loop flattens each one to 28 * 28 values.
data_loader = DataLoader(
    datasets.MNIST('.', train=True, transform=transforms.ToTensor(), download=True),
    batch_size=64,
    shuffle=True,
)

# Training Loop
epochs = 100
noise_dim = 100  # length of each noise vector fed to the generator
for epoch in range(epochs):
    # Running sums so the epoch summary reports mean losses rather than the
    # (noisy) values from whichever batch happened to come last.
    disc_loss_sum = 0.0
    gen_loss_sum = 0.0
    num_batches = 0

    for real_data, _ in data_loader:
        # Prepare real data and labels
        real_data = real_data.view(-1, 28 * 28)  # Flatten MNIST images
        batch_size = real_data.size(0)  # the final batch may be smaller
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # ---- Train Discriminator ----
        # Also clears the discriminator gradients left over from the
        # previous iteration's generator backward pass.
        disc_optimizer.zero_grad()

        # Compute loss on real data
        real_pred = discriminator(real_data)
        real_loss = loss_function(real_pred, real_labels)

        # Generate fake data
        noise = torch.randn(batch_size, noise_dim)
        fake_data = generator(noise)

        # Compute loss on fake data. detach() stops this loss from
        # backpropagating into the generator.
        fake_pred = discriminator(fake_data.detach())
        fake_loss = loss_function(fake_pred, fake_labels)

        # Backprop and update discriminator
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()

        # ---- Train Generator ----
        gen_optimizer.zero_grad()

        # Re-score the same fake batch with the just-updated discriminator.
        # fake_data is NOT detached here, so gradients reach the generator.
        fake_pred = discriminator(fake_data)
        gen_loss = loss_function(fake_pred, real_labels)  # Want discriminator to classify as real

        # Backprop and update generator
        gen_loss.backward()
        gen_optimizer.step()

        disc_loss_sum += disc_loss.item()
        gen_loss_sum += gen_loss.item()
        num_batches += 1

    # Print progress. max(..., 1) avoids dividing by zero (and the former
    # NameError on disc_loss/gen_loss) if the loader yields no batches.
    denom = max(num_batches, 1)
    print(f"Epoch [{epoch + 1}/{epochs}] | D Loss: {disc_loss_sum / denom:.4f} | G Loss: {gen_loss_sum / denom:.4f}")