Example of a GAN in PyTorch
# Minimal GAN training example (PyTorch).
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Define Generator Model
class Generator(nn.Module):
    def __init__(self, noise_dim=100, data_dim=28 * 28):
        super(Generator, self).__init__()
        # Layers for the generator: noise vector -> flattened image
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, data_dim),
            nn.Sigmoid(),  # Pixel values in [0, 1], matching ToTensor()
        )

    def forward(self, noise):
        # Forward pass: map a batch of noise vectors to generated images
        generated_data = self.model(noise)
        return generated_data
# Define Discriminator Model
class Discriminator(nn.Module):
    def __init__(self, data_dim=28 * 28):
        super(Discriminator, self).__init__()
        # Layers for the discriminator: flattened image -> validity score
        self.model = nn.Sequential(
            nn.Linear(data_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # Probability that the input is real (required by BCELoss)
        )

    def forward(self, data):
        # Forward pass: score how "real" each input sample looks
        validity_score = self.model(data)
        return validity_score
# Initialize models
generator = Generator()
discriminator = Discriminator()
# Define loss function (Binary Cross-Entropy)
loss_function = nn.BCELoss()
# Define optimizers
lr = 0.0002
gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
# Set up data loader
data_loader = DataLoader(
    datasets.MNIST('.', transform=transforms.ToTensor(), download=True),
    batch_size=64,
    shuffle=True,
)
# Training Loop
epochs = 100
for epoch in range(epochs):
    for real_data, _ in data_loader:
        # Prepare real data and labels
        real_data = real_data.view(-1, 28 * 28)  # Flatten MNIST images
        real_labels = torch.ones(real_data.size(0), 1)
        fake_labels = torch.zeros(real_data.size(0), 1)

        # ---- Train Discriminator ----
        disc_optimizer.zero_grad()

        # Compute loss on real data
        real_pred = discriminator(real_data)
        real_loss = loss_function(real_pred, real_labels)

        # Generate fake data
        noise = torch.randn(real_data.size(0), 100)  # 100: Noise vector size
        fake_data = generator(noise)

        # Compute loss on fake data
        fake_pred = discriminator(fake_data.detach())  # Detach to avoid generator gradient update
        fake_loss = loss_function(fake_pred, fake_labels)

        # Backprop and update discriminator
        disc_loss = real_loss + fake_loss
        disc_loss.backward()
        disc_optimizer.step()
        # ---- Train Generator ----
        gen_optimizer.zero_grad()

        # Re-score the fake data (no detach this time) so gradients flow back to the generator
        fake_pred = discriminator(fake_data)
        gen_loss = loss_function(fake_pred, real_labels)  # Want discriminator to classify fakes as real

        # Backprop and update generator
        gen_loss.backward()
        gen_optimizer.step()

    # Print progress once per epoch
    print(f"Epoch [{epoch + 1}/{epochs}] | D Loss: {disc_loss.item():.4f} | G Loss: {gen_loss.item():.4f}")