
GAN Trainer

Status: Supported runtime training surface

Module: artifex.generative_models.training.trainers.gan_trainer

Source: src/artifex/generative_models/training/trainers/gan_trainer.py

GANTrainer is a stateless step helper. It owns GAN-specific loss logic, but the caller owns generator/discriminator state, optimizers, and update cadence.

Quick Start

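```python
from flax import nnx
import jax
import optax

from artifex.generative_models.training.trainers import (
    GANTrainer,
    GANTrainingConfig,
)

# create_generator, create_discriminator, batch, and latent_dim are placeholders
# for your own model constructors, data batch, and latent dimensionality.
generator = create_generator(rngs=nnx.Rngs(0))
discriminator = create_discriminator(rngs=nnx.Rngs(1))

# The caller owns both optimizers and passes them to GANTrainer on every step.
g_optimizer = nnx.Optimizer(generator, optax.adam(1e-4, b1=0.5), wrt=nnx.Param)
d_optimizer = nnx.Optimizer(discriminator, optax.adam(1e-4, b1=0.5), wrt=nnx.Param)

config = GANTrainingConfig(loss_type="wasserstein", gp_weight=10.0)
trainer = GANTrainer(config)

key = jax.random.key(0)
key, d_key, g_key, z_key = jax.random.split(key, 4)
real = batch["image"]
z = jax.random.normal(z_key, (real.shape[0], latent_dim))

# Discriminator (critic) update; d_key is the PRNG key for any randomness in this step.
d_loss, d_metrics = trainer.discriminator_step(
    generator,
    discriminator,
    d_optimizer,
    real,
    z,
    d_key,
)

# Generator update on fresh latent noise drawn with g_key.
g_z = jax.random.normal(g_key, (real.shape[0], latent_dim))
g_loss, g_metrics = trainer.generator_step(
    generator,
    discriminator,
    g_optimizer,
    g_z,
)
```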

Configuration

artifex.generative_models.training.trainers.gan_trainer.GANTrainingConfig (dataclass)

GANTrainingConfig(loss_type: Literal['vanilla', 'wasserstein', 'hinge', 'lsgan'] = 'vanilla', gp_weight: float = 10.0, gp_target: float = 1.0, r1_weight: float = 0.0, label_smoothing: float = 0.0)

Configuration for GAN-specific training.

Attributes:

| Name | Type | Description |
|---|---|---|
| loss_type | Literal['vanilla', 'wasserstein', 'hinge', 'lsgan'] | GAN loss variant: "vanilla" (standard GAN loss, binary cross-entropy), "wasserstein" (Wasserstein distance; requires a gradient penalty), "hinge" (hinge loss, as used in SAGAN and BigGAN), "lsgan" (least-squares GAN). |
| gp_weight | float | Gradient-penalty weight (for WGAN-GP). |
| gp_target | float | Target gradient norm for the penalty (usually 1.0). |
| r1_weight | float | R1 regularization weight. |
| label_smoothing | float | One-sided label smoothing: real labels are smoothed to [1 - label_smoothing, 1]. |
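For orientation, the four loss_type variants correspond to well-known GAN objectives. The sketch below writes their textbook forms in plain JAX against raw discriminator outputs (logits for "vanilla", unbounded scores otherwise). It is an illustration under those assumptions, not Artifex's implementation: discriminator_loss and generator_loss are hypothetical names, the vanilla generator uses the common non-saturating form, and label smoothing is modeled as a fixed real target of 1 - label_smoothing (Artifex may instead sample targets from the [1 - label_smoothing, 1] range).

```python
import jax
import jax.numpy as jnp


def _bce_with_logits(logits, targets):
    # Numerically stable binary cross-entropy on raw logits.
    return -(targets * jax.nn.log_sigmoid(logits)
             + (1.0 - targets) * jax.nn.log_sigmoid(-logits))


def discriminator_loss(real_out, fake_out, loss_type="vanilla", label_smoothing=0.0):
    """Textbook discriminator/critic losses (hypothetical helper, not Artifex's API)."""
    if loss_type == "vanilla":
        # One-sided smoothing: only the real target is lowered (assumed fixed value).
        real_targets = jnp.full_like(real_out, 1.0 - label_smoothing)
        return (jnp.mean(_bce_with_logits(real_out, real_targets))
                + jnp.mean(_bce_with_logits(fake_out, jnp.zeros_like(fake_out))))
    if loss_type == "wasserstein":
        # The critic maximizes E[D(real)] - E[D(fake)], so we minimize the negation.
        return jnp.mean(fake_out) - jnp.mean(real_out)
    if loss_type == "hinge":
        return (jnp.mean(jax.nn.relu(1.0 - real_out))
                + jnp.mean(jax.nn.relu(1.0 + fake_out)))
    if loss_type == "lsgan":
        # Least squares with real target 1 and fake target 0.
        return 0.5 * (jnp.mean((real_out - 1.0) ** 2) + jnp.mean(fake_out ** 2))
    raise ValueError(f"unknown loss_type: {loss_type}")


def generator_loss(fake_out, loss_type="vanilla"):
    """Textbook generator losses matching discriminator_loss (hypothetical helper)."""
    if loss_type == "vanilla":
        # Non-saturating form: -log sigmoid(D(G(z))).
        return jnp.mean(_bce_with_logits(fake_out, jnp.ones_like(fake_out)))
    if loss_type in ("wasserstein", "hinge"):
        return -jnp.mean(fake_out)
    if loss_type == "lsgan":
        return 0.5 * jnp.mean((fake_out - 1.0) ** 2)
    raise ValueError(f"unknown loss_type: {loss_type}")
```

Keeping the discriminator and generator objectives in separate functions mirrors the split between discriminator_step and generator_step: each update only needs its own half of the objective.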


Runtime-Active Fields

| Parameter | Default | Description |
|---|---|---|
| loss_type | "vanilla" | GAN loss family |
| gp_weight | 10.0 | Gradient-penalty weight |
| gp_target | 1.0 | Target gradient norm for the gradient penalty |
| r1_weight | 0.0 | R1 regularization weight |
| label_smoothing | 0.0 | One-sided smoothing for real labels |
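The gp_weight and gp_target fields parameterize a WGAN-GP style gradient penalty, gp_weight * E[(||grad D(x_hat)|| - gp_target)^2], where x_hat is a random interpolate between a real and a generated sample. The following is a minimal sketch of that standard formulation in plain JAX, assuming a critic function that maps a batch to one score per example. gradient_penalty and critic_fn are hypothetical names, and Artifex's internal computation may differ in details such as how the interpolation key is supplied.

```python
import jax
import jax.numpy as jnp


def gradient_penalty(critic_fn, real, fake, key, gp_weight=10.0, gp_target=1.0):
    """WGAN-GP style penalty (hypothetical helper, not Artifex's implementation).

    critic_fn maps inputs of shape (B, ...) to scores of shape (B,).
    """
    batch = real.shape[0]
    # One interpolation coefficient per example, broadcast over the remaining axes.
    alpha = jax.random.uniform(key, (batch,) + (1,) * (real.ndim - 1))
    interp = alpha * real + (1.0 - alpha) * fake

    # Summing the scores yields per-example input gradients, valid as long as
    # example i's score depends only on example i (e.g. no train-mode BatchNorm).
    grads = jax.grad(lambda x: jnp.sum(critic_fn(x)))(interp)

    # Small epsilon keeps the sqrt differentiable when a gradient is exactly zero.
    norms = jnp.sqrt(jnp.sum(grads.reshape(batch, -1) ** 2, axis=1) + 1e-12)
    return gp_weight * jnp.mean((norms - gp_target) ** 2)
```

For a linear critic x @ w the input gradient is w everywhere, so the penalty is simply gp_weight * (||w|| - gp_target)^2; with w = [3, 4] and the defaults that is 10 * (5 - 1)^2 = 160.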

Update Cadence

Artifex does not hide GAN scheduling inside GANTrainingConfig. If you want multiple discriminator steps per generator step, make that explicit in the loop:

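```python
n_critic = 5  # discriminator (critic) updates per generator update

for _ in range(n_critic):
    # Fresh key and noise for every discriminator update; in practice you would
    # usually also draw a new real batch here.
    key, d_key, z_key = jax.random.split(key, 3)
    z = jax.random.normal(z_key, (real.shape[0], latent_dim))
    d_loss, d_metrics = trainer.discriminator_step(
        generator,
        discriminator,
        d_optimizer,
        real,
        z,
        d_key,
    )

# One generator update on its own fresh noise.
key, g_key = jax.random.split(key)
z = jax.random.normal(g_key, (real.shape[0], latent_dim))
g_loss, g_metrics = trainer.generator_step(
    generator,
    discriminator,
    g_optimizer,
    z,
)
```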

Supported Loss Families

  • vanilla
  • wasserstein
  • hinge
  • lsgan

Use gp_weight for Wasserstein-style gradient penalties and r1_weight when you want R1 regularization on real data.
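Unlike the gradient penalty, R1 regularization evaluates the discriminator's input gradient on real samples only. A minimal sketch of the usual formulation, (gamma / 2) * E[||grad D(x_real)||^2], is shown below in plain JAX. It assumes r1_weight plays the role of gamma, which is an assumption about Artifex's scaling; r1_penalty and disc_fn are hypothetical names.

```python
import jax
import jax.numpy as jnp


def r1_penalty(disc_fn, real, r1_weight):
    """R1 penalty on real data (hypothetical helper, not Artifex's implementation).

    disc_fn maps inputs of shape (B, ...) to scores of shape (B,).
    """
    batch = real.shape[0]
    # Per-example input gradients via the sum trick (examples must be independent).
    grads = jax.grad(lambda x: jnp.sum(disc_fn(x)))(real)
    # Squared norms need no sqrt, so no epsilon is required here.
    sq_norms = jnp.sum(grads.reshape(batch, -1) ** 2, axis=1)
    return 0.5 * r1_weight * jnp.mean(sq_norms)
```

Because no interpolation is involved, R1 needs no PRNG key; with r1_weight at its default of 0.0 the term vanishes.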