VAE Trainer¤

Status: Supported runtime training surface

Module: artifex.generative_models.training.trainers.vae_trainer

Source: src/artifex/generative_models/training/trainers/vae_trainer.py

The VAE Trainer provides specialized training utilities for Variational Autoencoders, including KL divergence annealing schedules, beta-VAE weighting for disentanglement, and free bits constraints to prevent posterior collapse.

Overview¤

Training VAEs requires balancing reconstruction quality against latent space regularization. The VAE Trainer handles this balance through:

  • KL Annealing: Gradual increase of KL weight to prevent posterior collapse
  • Beta-VAE Weighting: Control disentanglement vs reconstruction trade-off
  • Free Bits Constraint: Minimum KL per dimension to ensure information flow

Quick Start¤

from artifex.generative_models.training.trainers import (
    VAETrainer,
    VAETrainingConfig,
)
from flax import nnx
import optax

# Create model and optimizer
model = create_vae_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Configure VAE-specific training
config = VAETrainingConfig(
    kl_annealing="cyclical",
    kl_warmup_steps=5000,
    beta=4.0,
    free_bits=0.5,
)

trainer = VAETrainer(config)

# Training loop
for step, batch in enumerate(train_loader):
    loss, metrics = trainer.train_step(model, optimizer, batch, step=step, loss_type="bce")
    if step % 100 == 0:
        print(f"Step {step}: loss={metrics['loss']:.4f}, "
              f"recon={metrics['reconstruction_loss']:.4f}, kl={metrics['kl_loss']:.4f}")

Configuration¤

artifex.generative_models.training.trainers.vae_trainer.VAETrainingConfig dataclass ¤

VAETrainingConfig(kl_annealing: Literal['none', 'linear', 'sigmoid', 'cyclical'] = 'linear', kl_warmup_steps: int = 10000, beta: float = 1.0, free_bits: float = 0.0, cyclical_period: int = 10000)

Configuration for VAE-specific training.

Attributes:

Name Type Description
kl_annealing Literal['none', 'linear', 'sigmoid', 'cyclical']

Type of KL annealing schedule:

  • "none": No annealing, use full beta from start
  • "linear": Linear warmup from 0 to beta
  • "sigmoid": Sigmoid-shaped warmup
  • "cyclical": Cyclical annealing with periodic resets

kl_warmup_steps int

Number of steps to reach full KL weight.

beta float

Final beta weight for KL term (beta-VAE). Higher values encourage disentanglement but may hurt reconstruction.

free_bits float

Minimum KL per latent dimension (0 = disabled). Prevents posterior collapse by ensuring minimum information flow.

cyclical_period int

Period for cyclical annealing (if used).

kl_annealing class-attribute instance-attribute ¤

kl_annealing: Literal['none', 'linear', 'sigmoid', 'cyclical'] = 'linear'

kl_warmup_steps class-attribute instance-attribute ¤

kl_warmup_steps: int = 10000

beta class-attribute instance-attribute ¤

beta: float = 1.0

free_bits class-attribute instance-attribute ¤

free_bits: float = 0.0

cyclical_period class-attribute instance-attribute ¤

cyclical_period: int = 10000

Configuration Options¤

| Parameter | Type | Default | Description |
|---|---|---|---|
| kl_annealing | str | "linear" | KL weight schedule: "none", "linear", "sigmoid", "cyclical" |
| kl_warmup_steps | int | 10000 | Steps to reach full KL weight |
| beta | float | 1.0 | Final KL weight (beta-VAE parameter) |
| free_bits | float | 0.0 | Minimum KL per latent dimension (0 = disabled) |
| cyclical_period | int | 10000 | Period for cyclical annealing |

KL Annealing Schedules¤

None (Constant)¤

No annealing - use full beta weight from the start:

config = VAETrainingConfig(kl_annealing="none", beta=1.0)
# KL weight = 1.0 at all steps

Linear Warmup¤

Linearly increase KL weight from 0 to beta:

config = VAETrainingConfig(
    kl_annealing="linear",
    kl_warmup_steps=10000,
    beta=1.0,
)
# KL weight = beta * min(1.0, step / warmup_steps)
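
The formula in the comment above can be checked with a few lines of plain Python. This is only a sketch of the schedule's shape: `linear_kl_weight` is a hypothetical helper, not part of the trainer, and the trainer's own `get_kl_weight` is documented as using JAX operations rather than Python builtins.

```python
def linear_kl_weight(step: int, warmup_steps: int = 10_000, beta: float = 1.0) -> float:
    """Ramp linearly from 0 to beta over warmup_steps, then hold at beta."""
    fraction = min(1.0, step / warmup_steps)
    return beta * fraction


# Print the weight at a few points along the schedule.
for step in (0, 2_500, 5_000, 10_000, 20_000):
    print(step, linear_kl_weight(step))
```

After `warmup_steps` the weight stays at `beta`, so the KL term is fully active for the rest of training.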

Sigmoid Warmup¤

S-shaped warmup curve centered at half the warmup steps:

config = VAETrainingConfig(
    kl_annealing="sigmoid",
    kl_warmup_steps=10000,
    beta=1.0,
)
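
One way to picture an S-shaped warmup centered at half the warmup steps is a logistic curve. The sketch below is an assumption about the shape, not the trainer's actual code: `sigmoid_kl_weight` and its `steepness` parameter are hypothetical, and the exact slope used by the library is not stated here.

```python
import math


def sigmoid_kl_weight(
    step: int,
    warmup_steps: int = 10_000,
    beta: float = 1.0,
    steepness: float = 10.0,  # hypothetical knob controlling how sharp the S-curve is
) -> float:
    """Logistic warmup that reaches beta / 2 at warmup_steps / 2."""
    midpoint = warmup_steps / 2
    x = steepness * (step - midpoint) / warmup_steps
    return beta / (1.0 + math.exp(-x))


# The weight starts near 0, passes beta / 2 at the midpoint, and approaches beta.
for step in (0, 5_000, 10_000):
    print(step, round(sigmoid_kl_weight(step), 4))
```

Compared with the linear schedule, the weight stays close to 0 for longer at the start and rises fastest around the midpoint.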

Cyclical Annealing¤

Periodically reset KL weight to encourage information flow:

config = VAETrainingConfig(
    kl_annealing="cyclical",
    cyclical_period=5000,
    beta=4.0,
)
# KL weight cycles: 0 -> beta -> 0 -> beta -> ...

Cyclical annealing helps prevent posterior collapse: each reset lowers the KL penalty, giving the encoder another chance to encode information in the latent variables before the weight ramps back up.
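
A minimal sketch of a cyclical schedule follows. It assumes a common variant in which the weight ramps linearly during the first part of each cycle and then holds at beta until the next reset. `cyclical_kl_weight` and `ramp_fraction` are hypothetical names, and the trainer's actual within-cycle shape may differ.

```python
def cyclical_kl_weight(
    step: int,
    cyclical_period: int = 5_000,
    beta: float = 4.0,
    ramp_fraction: float = 0.5,  # hypothetical: share of each cycle spent ramping up
) -> float:
    """Ramp from 0 to beta, hold at beta, then reset at the start of the next cycle."""
    position = (step % cyclical_period) / cyclical_period  # position in cycle, in [0, 1)
    return beta * min(1.0, position / ramp_fraction)


# Two full cycles: the weight returns to 0 at steps 0 and 5000.
for step in (0, 1_250, 2_500, 4_999, 5_000, 6_250):
    print(step, cyclical_kl_weight(step))
```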

Beta-VAE Training¤

Higher beta values encourage disentangled representations at the cost of reconstruction quality:

# Standard VAE (beta=1)
standard_config = VAETrainingConfig(beta=1.0)

# Beta-VAE for disentanglement (beta=4)
disentangled_config = VAETrainingConfig(beta=4.0)

# Strong regularization (beta=10)
strong_reg_config = VAETrainingConfig(beta=10.0)
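
To see why a larger beta shifts the balance, it helps to work through the arithmetic with fixed loss values. The numbers below are made up for illustration, `weighted_loss` is a hypothetical helper, and the KL weight is taken to be exactly beta, as it would be once any warmup has finished.

```python
def weighted_loss(reconstruction_loss: float, kl_loss: float, kl_weight: float) -> float:
    """Total loss as reconstruction plus weighted KL."""
    return reconstruction_loss + kl_weight * kl_loss


recon, kl = 80.0, 12.0  # illustrative values, not measured results

for beta in (1.0, 4.0, 10.0):
    total = weighted_loss(recon, kl, beta)
    kl_share = beta * kl / total  # fraction of the total loss contributed by the KL term
    print(f"beta={beta:>4}: total={total:.1f}, KL share={kl_share:.1%}")
```

As beta grows, the KL term makes up a larger share of the total, so gradient updates favor matching the prior over reconstruction accuracy.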

Free Bits Constraint¤

Prevent posterior collapse by ensuring minimum KL per latent dimension:

config = VAETrainingConfig(
    free_bits=0.5,  # Minimum 0.5 nats per dimension
    beta=1.0,
)

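The free bits constraint sets a floor on the KL term of each latent dimension, so the optimizer gains nothing by pushing a dimension's KL below the specified value. This keeps each dimension carrying some information.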
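
The usual formulation of free bits clamps each dimension's KL from below. The sketch below assumes that formulation. `free_bits_clamp` is a hypothetical stand-in written in plain Python, and how the trainer's `apply_free_bits` reduces over the batch is not specified here.

```python
def free_bits_clamp(kl_per_dim: list[float], free_bits: float) -> list[float]:
    """Replace any per-dimension KL below the floor with the floor itself."""
    return [max(kl, free_bits) for kl in kl_per_dim]


kl_per_dim = [0.02, 0.3, 1.7, 0.0]  # illustrative KL values in nats, one per latent dimension
clamped = free_bits_clamp(kl_per_dim, free_bits=0.5)
print(clamped)
```

Dimensions already below the floor contribute a constant to the loss, so they produce no gradient that would shrink their KL further. Only the third dimension, at 1.7 nats, is still penalized.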

API Reference¤

artifex.generative_models.training.trainers.vae_trainer.VAETrainer ¤

VAETrainer(config: VAETrainingConfig | None = None)

VAE-specific trainer with KL annealing and beta-VAE support.

This trainer provides a JIT-compatible interface for training VAEs with:

  • KL annealing schedules (linear, sigmoid, cyclical)
  • Beta-VAE weighting for disentanglement
  • Free bits constraint to prevent posterior collapse

The train_step method takes model and optimizer as explicit arguments, allowing it to be wrapped with nnx.jit for performance.

The trainer computes the ELBO loss with configurable KL weighting:

L = reconstruction_loss + beta * kl_weight(step) * kl_loss

Example (non-JIT):

from flax import nnx
import optax

from artifex.generative_models.training.trainers import (
    VAETrainer,
    VAETrainingConfig,
)

config = VAETrainingConfig(
    kl_annealing="cyclical",
    beta=4.0,
    free_bits=0.5,
)
trainer = VAETrainer(config)

# Create model and optimizer separately
model = VAEModel(config, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Training loop
for step, batch in enumerate(data):
    loss, metrics = trainer.train_step(model, optimizer, batch, step=step)

Example (JIT-compiled):

trainer = VAETrainer(config)
jit_step = nnx.jit(trainer.train_step)

for step, batch in enumerate(data):
    loss, metrics = jit_step(model, optimizer, batch, step=step)

Note

The model is expected to return a canonical dict with reconstructed, mean, and log_var keys from its forward pass. The trainer handles ELBO loss computation and KL annealing.

Parameters:

Name Type Description Default
config VAETrainingConfig | None

VAE training configuration. Uses defaults if not provided.

None

config instance-attribute ¤

config = config or VAETrainingConfig()

get_kl_weight ¤

get_kl_weight(step: int | Array) -> Array

Compute KL weight based on annealing schedule.

This method is JIT-compatible - uses JAX operations instead of Python builtins.

Parameters:

Name Type Description Default
step int | Array

Current training step (can be traced array for JIT).

required

Returns:

Type Description
Array

KL weight multiplier (0.0 to beta).

apply_free_bits ¤

apply_free_bits(kl_per_dim: Array) -> Array

Apply free bits constraint to KL divergence.

Ensures minimum KL per latent dimension to prevent posterior collapse.

Parameters:

Name Type Description Default
kl_per_dim Array

KL divergence per latent dimension.

required

Returns:

Type Description
Array

KL divergence with free bits applied.

compute_kl_loss ¤

compute_kl_loss(mean: Array, logvar: Array) -> tuple[Array, Array]

Compute KL divergence from standard normal.

Parameters:

Name Type Description Default
mean Array

Latent mean, shape (batch, latent_dim).

required
logvar Array

Latent log-variance, shape (batch, latent_dim).

required

Returns:

Type Description
tuple[Array, Array]

Tuple of (total_kl_loss, kl_per_sample) where:

  • total_kl_loss: Scalar mean KL loss
  • kl_per_sample: KL loss per sample, shape (batch,)
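
For a diagonal Gaussian posterior and a standard normal prior, the KL divergence has the closed form -0.5 * sum(1 + logvar - mean² - exp(logvar)) over latent dimensions. The sketch below evaluates it in plain Python on nested lists. `gaussian_kl` is a hypothetical helper that mirrors the documented return shape, not the trainer's JAX implementation.

```python
import math


def gaussian_kl(
    mean: list[list[float]], logvar: list[list[float]]
) -> tuple[float, list[float]]:
    """KL(N(mean, exp(logvar)) || N(0, I)) for inputs of shape (batch, latent_dim)."""
    kl_per_sample = [
        -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(m_row, lv_row))
        for m_row, lv_row in zip(mean, logvar)
    ]
    total_kl = sum(kl_per_sample) / len(kl_per_sample)  # mean over the batch
    return total_kl, kl_per_sample


# First sample matches the prior exactly; second has one mean shifted to 1.0.
total_kl, kl_per_sample = gaussian_kl(
    mean=[[0.0, 0.0], [1.0, 0.0]],
    logvar=[[0.0, 0.0], [0.0, 0.0]],
)
print(total_kl, kl_per_sample)
```

A sample whose posterior equals the prior has zero KL. Shifting one mean to 1.0 at unit variance costs 0.5 nats.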

compute_reconstruction_loss ¤

compute_reconstruction_loss(x: Array, recon_x: Array, loss_type: Literal['mse', 'mae', 'bce'] = 'mse') -> Array

Compute reconstruction loss.

Parameters:

Name Type Description Default
x Array

Original input, shape (batch, ...).

required
recon_x Array

Reconstructed output, shape (batch, ...).

required
loss_type Literal['mse', 'mae', 'bce']

Type of reconstruction loss.

'mse'

Returns:

Type Description
Array

Scalar reconstruction loss.

compute_loss ¤

compute_loss(model: Module, batch: dict[str, Any], step: int | Array, loss_type: Literal['mse', 'mae', 'bce'] = 'mse') -> tuple[Array, dict[str, Any]]

Compute VAE loss with KL annealing.

Parameters:

Name Type Description Default
model Module

VAE model to evaluate.

required
batch dict[str, Any]

Batch dictionary with "image" or "data" key.

required
step int | Array

Current training step for annealing.

required
loss_type Literal['mse', 'mae', 'bce']

Type of reconstruction loss.

'mse'

Returns:

Type Description
tuple[Array, dict[str, Any]]

Tuple of (total_loss, metrics_dict).

train_step ¤

train_step(model: Module, optimizer: Optimizer, batch: dict[str, Any], step: int = 0, loss_type: Literal['mse', 'mae', 'bce'] = 'mse') -> tuple[Array, dict[str, Any]]

Execute a single training step.

This method can be wrapped with nnx.jit for performance:

jit_step = nnx.jit(trainer.train_step)
loss, metrics = jit_step(model, optimizer, batch, step=step)

Parameters:

Name Type Description Default
model Module

VAE model to train.

required
optimizer Optimizer

NNX optimizer for parameter updates.

required
batch dict[str, Any]

Batch dictionary with "image" or "data" key.

required
step int

Current training step for annealing.

0
loss_type Literal['mse', 'mae', 'bce']

Type of reconstruction loss.

'mse'

Returns:

Type Description
tuple[Array, dict[str, Any]]

Tuple of (loss, metrics_dict).

create_loss_fn ¤

create_loss_fn(loss_type: Literal['mse', 'bce'] = 'mse') -> Callable[[Module, dict[str, Any], Array, Array], tuple[Array, dict[str, Any]]]

Create loss function compatible with train_epoch_staged.

This enables DRY integration - VAE-specific training logic can be used with the staged training infrastructure.

Parameters:

Name Type Description Default
loss_type Literal['mse', 'bce']

Type of reconstruction loss.

'mse'

Returns:

Type Description
Callable[[Module, dict[str, Any], Array, Array], tuple[Array, dict[str, Any]]]

Function with signature: (model, batch, rng, step) -> (loss, metrics). The step parameter is passed dynamically by train_epoch_staged, enabling proper KL annealing inside a JIT-compiled fori_loop.
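
The closure pattern behind a step-aware loss function can be sketched without the library. The example below fixes `loss_type` when the closure is created and leaves `step` as a call-time argument, which is what allows the caller to vary it between iterations. `make_step_aware_loss_fn` and `fake_compute_loss` are hypothetical stand-ins used only to exercise the pattern. They are not the trainer's code.

```python
from typing import Any, Callable

LossOutput = tuple[float, dict[str, Any]]


def make_step_aware_loss_fn(
    compute_loss: Callable[..., LossOutput],
    loss_type: str = "mse",
) -> Callable[[Any, dict[str, Any], Any, int], LossOutput]:
    """Bind loss_type once; keep step as an argument supplied on every call."""

    def loss_fn(model: Any, batch: dict[str, Any], rng: Any, step: int) -> LossOutput:
        del rng  # unused in this sketch; kept to match the (model, batch, rng, step) signature
        return compute_loss(model, batch, step=step, loss_type=loss_type)

    return loss_fn


def fake_compute_loss(model, batch, step, loss_type="mse"):
    """Stand-in loss: returns a linear KL weight as the 'loss' for demonstration."""
    kl_weight = min(1.0, step / 100)
    return kl_weight, {"kl_weight": kl_weight, "loss_type": loss_type}


loss_fn = make_step_aware_loss_fn(fake_compute_loss, loss_type="bce")
loss, metrics = loss_fn(model=None, batch={"data": []}, rng=None, step=50)
print(loss, metrics)
```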

Integration with Base Trainer¤

The VAE Trainer provides a step-aware create_loss_fn() closure for integration with the base Trainer's callbacks, checkpointing, and logging infrastructure:

from artifex.generative_models.training import Trainer
from artifex.generative_models.training.callbacks import (
    CallbackList,
    CheckpointConfig,
    EarlyStopping,
    EarlyStoppingConfig,
    ModelCheckpoint,
)
from artifex.generative_models.training.trainers import VAETrainer, VAETrainingConfig

# Create VAE-specific trainer
vae_config = VAETrainingConfig(kl_annealing="cyclical", beta=4.0)
vae_trainer = VAETrainer(vae_config)

# Create an explicit step-aware objective closure
loss_fn = vae_trainer.create_loss_fn()

# Use with base Trainer for callbacks
callbacks = CallbackList([
    EarlyStopping(EarlyStoppingConfig(monitor="val_loss", patience=10)),
    ModelCheckpoint(CheckpointConfig(dirpath="checkpoints", monitor="val_loss")),
])

# `model` and `training_config` are assumed to have been created earlier
trainer = Trainer(
    model=model,
    training_config=training_config,
    loss_fn=loss_fn,
    callbacks=callbacks,
)

Model Requirements¤

The VAE Trainer expects models with the following interface:

import jax
from flax import nnx

class VAEModel(nnx.Module):
    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        """Forward pass returning (reconstruction, mean, logvar).

        Args:
            x: Input data, shape (batch, ...).

        Returns:
            Tuple of:
                - recon_x: Reconstructed data, shape (batch, ...)
                - mean: Latent mean, shape (batch, latent_dim)
                - logvar: Latent log-variance, shape (batch, latent_dim)
        """
        ...

Reconstruction Loss Types¤

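The trainer supports MSE, MAE, and BCE reconstruction losses, selected with the loss_type argument ("mse", "mae", or "bce"). Note that create_loss_fn accepts only "mse" and "bce". The two most common choices: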

# Mean Squared Error (default, for continuous data)
loss, metrics = trainer.train_step(model, optimizer, batch, step=100, loss_type="mse")

# Binary Cross-Entropy (for images normalized to [0, 1])
loss, metrics = trainer.train_step(model, optimizer, batch, step=100, loss_type="bce")
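
For reference, the three loss types correspond to the standard element-wise formulas below. This sketch makes two assumptions that may not match the trainer: losses are averaged over all elements, and for "bce" the reconstruction holds probabilities in [0, 1] rather than logits. `reconstruction_loss` and `eps` are hypothetical names.

```python
import math


def reconstruction_loss(
    x: list[float],
    recon_x: list[float],
    loss_type: str = "mse",
    eps: float = 1e-7,  # hypothetical clipping constant to keep log() finite
) -> float:
    """Mean element-wise reconstruction loss over flat lists of values."""
    if loss_type == "mse":
        terms = [(a - b) ** 2 for a, b in zip(x, recon_x)]
    elif loss_type == "mae":
        terms = [abs(a - b) for a, b in zip(x, recon_x)]
    elif loss_type == "bce":
        terms = []
        for a, b in zip(x, recon_x):
            p = min(max(b, eps), 1.0 - eps)  # clip the predicted probability
            terms.append(-(a * math.log(p) + (1.0 - a) * math.log(1.0 - p)))
    else:
        raise ValueError(f"Unknown loss_type: {loss_type!r}")
    return sum(terms) / len(terms)


x, recon_x = [0.0, 1.0], [0.0, 0.5]
for loss_type in ("mse", "mae", "bce"):
    print(loss_type, round(reconstruction_loss(x, recon_x, loss_type), 4))
```

BCE treats each element as a Bernoulli probability, which is why it suits inputs normalized to [0, 1]. MSE and MAE make no such assumption about the data range.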

Training Metrics¤

The trainer returns detailed metrics for monitoring:

| Metric | Description |
|---|---|
| loss | Total ELBO loss |
| reconstruction_loss | Reconstruction loss |
| kl_loss | KL divergence (unweighted) |
| kl_weight | Current KL weight from annealing |

References¤