Energy Trainer

Status: Supported runtime training surface

Module: artifex.generative_models.training.trainers.energy_trainer

Source: src/artifex/generative_models/training/trainers/energy_trainer.py

EnergyTrainer implements Contrastive Divergence, Persistent Contrastive Divergence, and score-matching training with Langevin negative-sample updates.

Quick Start

from flax import nnx
import jax
import optax

from artifex.generative_models.training.trainers import (
    EnergyTrainer,
    EnergyTrainingConfig,
)

# create_energy_model is not defined in this snippet; it stands in for
# your own factory that returns an nnx energy model.
model = create_energy_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)
trainer = EnergyTrainer(
    EnergyTrainingConfig(
        training_method="pcd",  # Persistent Contrastive Divergence
        mcmc_steps=20,
        step_size=0.01,
        noise_scale=0.005,
    )
)

# batch is likewise a placeholder for one batch of training data.
key = jax.random.key(0)
loss, metrics = trainer.train_step(model, optimizer, batch, key)

Configuration

artifex.generative_models.training.trainers.energy_trainer.EnergyTrainingConfig (dataclass)

EnergyTrainingConfig(
    training_method: Literal['cd', 'pcd', 'score_matching'] = 'cd',
    mcmc_steps: int = 20,
    step_size: float = 0.01,
    noise_scale: float = 0.005,
    gradient_clipping: float = 1.0,
    replay_buffer_size: int = 10000,
    replay_buffer_init_prob: float = 0.95,
    energy_regularization: float = 0.0,
    gradient_penalty_weight: float = 0.0,
)

Configuration for energy-based model training.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| training_method | Literal['cd', 'pcd', 'score_matching'] | Training method for the energy model: "cd" is Contrastive Divergence (initialize chains from data), "pcd" is Persistent Contrastive Divergence (persistent chains), "score_matching" is denoising score matching (gradient-based). |
| mcmc_steps | int | Number of MCMC steps for sampling negatives. |
| step_size | float | Step size for MCMC updates. |
| noise_scale | float | Scale of noise injection in Langevin dynamics. |
| gradient_clipping | float | Max gradient norm for MCMC updates. |
| replay_buffer_size | int | Size of replay buffer for PCD (0 = no buffer). |
| replay_buffer_init_prob | float | Probability of initializing from buffer vs. noise. |
| energy_regularization | float | L2 regularization on energy values. |
| gradient_penalty_weight | float | Weight for gradient penalty regularization. |

training_method class-attribute instance-attribute

training_method: Literal['cd', 'pcd', 'score_matching'] = 'cd'

mcmc_steps class-attribute instance-attribute

mcmc_steps: int = 20

step_size class-attribute instance-attribute

step_size: float = 0.01

noise_scale class-attribute instance-attribute

noise_scale: float = 0.005

gradient_clipping class-attribute instance-attribute

gradient_clipping: float = 1.0

replay_buffer_size class-attribute instance-attribute

replay_buffer_size: int = 10000

replay_buffer_init_prob class-attribute instance-attribute

replay_buffer_init_prob: float = 0.95

energy_regularization class-attribute instance-attribute

energy_regularization: float = 0.0

gradient_penalty_weight class-attribute instance-attribute

gradient_penalty_weight: float = 0.0

Runtime-Active Fields

| Parameter | Default | Description |
| --- | --- | --- |
| training_method | "cd" | One of cd, pcd, or score_matching |
| mcmc_steps | 20 | Number of Langevin updates for negative samples |
| step_size | 0.01 | Langevin step size |
| noise_scale | 0.005 | Noise multiplier in Langevin updates |
| gradient_clipping | 1.0 | Clip norm applied to Langevin input gradients |
| replay_buffer_size | 10000 | Replay-buffer capacity for pcd |
| replay_buffer_init_prob | 0.95 | Probability of drawing PCD starts from the replay buffer |
| energy_regularization | 0.0 | Optional energy-value penalty |
| gradient_penalty_weight | 0.0 | Optional gradient penalty weight |
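
To make the two optional penalties concrete, here is a minimal JAX sketch of how an energy-value penalty and a gradient penalty are commonly added to a contrastive loss. The function `regularized_cd_loss`, the toy energy, and the exact form of each penalty are assumptions for illustration; the trainer's actual formulas may differ.

```python
import jax
import jax.numpy as jnp


def regularized_cd_loss(energy_fn, x_data, x_neg,
                        energy_regularization=0.0,
                        gradient_penalty_weight=0.0):
    """Contrastive loss plus two optional penalties (illustrative forms only)."""
    e_pos, e_neg = energy_fn(x_data), energy_fn(x_neg)
    loss = jnp.mean(e_pos) - jnp.mean(e_neg)
    # L2 penalty on energy values discourages unbounded energy magnitudes.
    loss += energy_regularization * (jnp.mean(e_pos**2) + jnp.mean(e_neg**2))
    # Penalize the squared norm of the energy's input gradient on the data.
    grad_x = jax.grad(lambda x: jnp.sum(energy_fn(x)))(x_data)
    loss += gradient_penalty_weight * jnp.mean(jnp.sum(grad_x**2, axis=-1))
    return loss


def toy_energy(x):
    # Hypothetical energy used only for this example.
    return jnp.sum(x**2, axis=-1)


x_data = jnp.ones((4, 2))
x_neg = jnp.zeros((4, 2))
with_energy_penalty = regularized_cd_loss(
    toy_energy, x_data, x_neg, energy_regularization=0.1)
with_gradient_penalty = regularized_cd_loss(
    toy_energy, x_data, x_neg, gradient_penalty_weight=0.5)
```

With both weights at their default of 0.0, the sketch reduces to the plain difference of mean energies.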

Langevin Sampling

The negative-sample path is always Langevin-based:

negatives = trainer.run_mcmc_chain(model, x_init, key, num_steps=20)

The trainer does not expose alternate sampler families in its public config.
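
For readers unfamiliar with Langevin updates, the following is a minimal standalone sketch of the kind of chain that `mcmc_steps`, `step_size`, `noise_scale`, and `gradient_clipping` parameterize. It is not the trainer's implementation: `langevin_chain`, `toy_energy`, and the per-sample norm clipping are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def toy_energy(x):
    # Hypothetical quadratic energy with its minimum at the origin.
    return 0.5 * jnp.sum(x**2, axis=-1)


def langevin_chain(energy_fn, x_init, key, num_steps=20, step_size=0.01,
                   noise_scale=0.005, max_grad_norm=1.0):
    # Gradient of the summed energy gives one input gradient per sample.
    grad_fn = jax.grad(lambda x: jnp.sum(energy_fn(x)))

    def step(x, step_key):
        g = grad_fn(x)
        # Rescale each sample's gradient so its norm is at most max_grad_norm.
        norm = jnp.linalg.norm(g, axis=-1, keepdims=True)
        g = g * jnp.minimum(1.0, max_grad_norm / (norm + 1e-12))
        noise = jax.random.normal(step_key, x.shape)
        # Move downhill in energy, then inject Gaussian noise.
        return x - step_size * g + noise_scale * noise, None

    keys = jax.random.split(key, num_steps)
    x_final, _ = jax.lax.scan(step, x_init, keys)
    return x_final


init_key, chain_key = jax.random.split(jax.random.key(0))
x_init = 3.0 * jax.random.normal(init_key, (8, 2))
negatives = langevin_chain(toy_energy, x_init, chain_key,
                           num_steps=200, step_size=0.05)
```

On this toy energy the chain drifts toward the origin, so the mean energy of `negatives` ends up below that of `x_init`.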

Training Methods

Contrastive Divergence

EnergyTrainingConfig(training_method="cd", mcmc_steps=20)
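
As background, a contrastive-divergence objective is usually written as the mean energy of data minus the mean energy of negatives, where the negatives come from short chains started at the data. The sketch below illustrates that textbook form only; `cd_loss`, `energy_fn`, and the `centre` parameter are hypothetical and not part of the artifex API.

```python
import jax
import jax.numpy as jnp


def cd_loss(energy_fn, params, x_data, x_neg):
    # Negatives are constants here: no gradient flows back through the sampler.
    x_neg = jax.lax.stop_gradient(x_neg)
    # Lower the energy of data, raise the energy of sampled negatives.
    return jnp.mean(energy_fn(params, x_data)) - jnp.mean(energy_fn(params, x_neg))


def energy_fn(params, x):
    # Hypothetical energy: squared distance to a learnable centre.
    return jnp.sum((x - params["centre"]) ** 2, axis=-1)


params = {"centre": jnp.zeros(2)}
x_data = jnp.ones((4, 2))
# Stand-in for samples produced by a Langevin chain started from x_data.
x_neg = -jnp.ones((4, 2))

loss, grads = jax.value_and_grad(
    lambda p: cd_loss(energy_fn, p, x_data, x_neg))(params)
```

A gradient-descent step on `params` moves the centre toward the data and away from the negatives, which is the intended effect of the objective.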

Persistent Contrastive Divergence

EnergyTrainingConfig(
    training_method="pcd",
    mcmc_steps=20,
    replay_buffer_size=10000,
)
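
The role of `replay_buffer_size` and `replay_buffer_init_prob` can be illustrated with a small standalone sketch: each chain start is taken from the buffer with a given probability and from fresh noise otherwise, and finished chains are written back. The helper names, the standard-normal noise, and the write-back policy are assumptions, not the trainer's internals.

```python
import jax
import jax.numpy as jnp


def sample_chain_starts(key, buffer, batch_size, init_prob=0.95):
    idx_key, noise_key, mask_key = jax.random.split(key, 3)
    # Candidate starts drawn uniformly from the stored persistent chains.
    idx = jax.random.randint(idx_key, (batch_size,), 0, buffer.shape[0])
    from_buffer = buffer[idx]
    # Candidate starts drawn fresh from standard normal noise (an assumption).
    from_noise = jax.random.normal(noise_key, from_buffer.shape)
    # Each sample independently uses the buffer with probability init_prob.
    use_buffer = jax.random.bernoulli(mask_key, init_prob, (batch_size, 1))
    return jnp.where(use_buffer, from_buffer, from_noise), idx


def write_back(buffer, idx, x_final):
    # Store finished chains so a later step can resume from them.
    return buffer.at[idx].set(x_final)


buffer = jnp.full((100, 2), 5.0)
starts, idx = sample_chain_starts(jax.random.key(0), buffer, batch_size=16)
# In real training x_final would be the output of the Langevin chain.
buffer = write_back(buffer, idx, jnp.zeros_like(starts))
```

Keeping a small probability of noise restarts is a common way to stop the persistent chains from collapsing onto a few modes.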

Score Matching

EnergyTrainingConfig(training_method="score_matching")
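
The configuration table describes this mode as denoising score matching. As a rough guide to what that objective looks like, the sketch below perturbs the data with Gaussian noise and regresses the model score (the negative input gradient of the energy) onto the score of the noise kernel. The function `dsm_loss`, the `sigma` argument, and the toy energy are assumptions; how the trainer chooses its noise level is not documented here.

```python
import jax
import jax.numpy as jnp


def dsm_loss(energy_fn, params, x, key, sigma=0.1):
    eps = jax.random.normal(key, x.shape)
    x_noisy = x + sigma * eps
    # Model score is the negative input gradient of the energy.
    grad_x = jax.grad(lambda xn: jnp.sum(energy_fn(params, xn)))(x_noisy)
    model_score = -grad_x
    # Score of the Gaussian corruption kernel: -(x_noisy - x) / sigma**2.
    target_score = -eps / sigma
    return 0.5 * jnp.mean(jnp.sum((model_score - target_score) ** 2, axis=-1))


def gaussian_energy(params, x):
    # Hypothetical energy of a zero-mean isotropic Gaussian with variance "var".
    return 0.5 * jnp.sum(x**2, axis=-1) / params["var"]


x = jnp.zeros((32, 2))  # All data at the origin, so noisy data is N(0, sigma^2).
key = jax.random.key(0)
loss_matched = dsm_loss(gaussian_energy, {"var": 0.01}, x, key, sigma=0.1)
loss_mismatched = dsm_loss(gaussian_energy, {"var": 1.0}, x, key, sigma=0.1)
```

When the energy's variance equals `sigma**2`, the model score matches the target exactly and the loss is essentially zero; the mismatched variance gives a much larger loss.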