Energy Trainer
Status: Supported runtime training surface
Module: artifex.generative_models.training.trainers.energy_trainer
Source: src/artifex/generative_models/training/trainers/energy_trainer.py
EnergyTrainer implements Contrastive Divergence, Persistent Contrastive
Divergence, and score-matching training with Langevin negative-sample updates.
Quick Start
from flax import nnx
import jax
import optax

from artifex.generative_models.training.trainers import (
    EnergyTrainer,
    EnergyTrainingConfig,
)

# create_energy_model is a user-supplied factory; it is not defined in this snippet.
model = create_energy_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

trainer = EnergyTrainer(
    EnergyTrainingConfig(
        training_method="pcd",
        mcmc_steps=20,
        step_size=0.01,
        noise_scale=0.005,
    )
)

# batch is one batch of training data supplied by the caller.
key = jax.random.key(0)
loss, metrics = trainer.train_step(model, optimizer, batch, key)
Configuration
artifex.generative_models.training.trainers.energy_trainer.EnergyTrainingConfig (dataclass)
EnergyTrainingConfig(training_method: Literal['cd', 'pcd', 'score_matching'] = 'cd', mcmc_steps: int = 20, step_size: float = 0.01, noise_scale: float = 0.005, gradient_clipping: float = 1.0, replay_buffer_size: int = 10000, replay_buffer_init_prob: float = 0.95, energy_regularization: float = 0.0, gradient_penalty_weight: float = 0.0)
Configuration for energy-based model training.
Attributes:
| Name | Type | Description |
|---|---|---|
| `training_method` | `Literal['cd', 'pcd', 'score_matching']` | Training method for the energy model. `"cd"`: Contrastive Divergence (chains initialized from data); `"pcd"`: Persistent Contrastive Divergence (persistent chains); `"score_matching"`: denoising score matching (gradient-based). |
| `mcmc_steps` | `int` | Number of MCMC steps for sampling negatives. |
| `step_size` | `float` | Step size for MCMC updates. |
| `noise_scale` | `float` | Scale of noise injection in Langevin dynamics. |
| `gradient_clipping` | `float` | Max gradient norm for MCMC updates. |
| `replay_buffer_size` | `int` | Size of replay buffer for PCD (0 = no buffer). |
| `replay_buffer_init_prob` | `float` | Probability of initializing from buffer vs noise. |
| `energy_regularization` | `float` | L2 regularization on energy values. |
| `gradient_penalty_weight` | `float` | Weight for gradient penalty regularization. |
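To make the roles of `training_method` and `energy_regularization` concrete, here is a minimal sketch of a contrastive-divergence style objective in JAX. It is an illustration only, not the trainer's code: `energy_fn`, `cd_loss`, the toy quadratic energy, and the exact form of the penalty (mean squared energy over both batches) are all assumptions.

```python
import jax
import jax.numpy as jnp


def energy_fn(params, x):
    # Hypothetical toy energy: a quadratic bowl centred at params["mu"].
    return 0.5 * jnp.sum((x - params["mu"]) ** 2, axis=-1)


def cd_loss(params, x_pos, x_neg, energy_regularization=0.0):
    """Lower the energy of data samples and raise the energy of negatives."""
    e_pos = energy_fn(params, x_pos)
    # Negatives are treated as constants with respect to the sampler.
    e_neg = energy_fn(params, jax.lax.stop_gradient(x_neg))
    loss = jnp.mean(e_pos) - jnp.mean(e_neg)
    # Assumed penalty form: mean squared energy of both batches.
    reg = jnp.mean(e_pos ** 2) + jnp.mean(e_neg ** 2)
    return loss + energy_regularization * reg


params = {"mu": jnp.zeros(2)}
x_pos = jnp.ones((4, 2))   # stand-in for a data batch
x_neg = jnp.zeros((4, 2))  # stand-in for sampled negatives
loss = cd_loss(params, x_pos, x_neg, energy_regularization=0.1)
grads = jax.grad(cd_loss)(params, x_pos, x_neg)
```

Under this sketch, `"cd"` and `"pcd"` would differ only in where `x_neg` chains start (the data batch versus a persistent buffer), while the loss itself stays the same.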
Runtime-Active Fields
| Parameter | Default | Description |
|---|---|---|
| `training_method` | `"cd"` | One of `cd`, `pcd`, or `score_matching` |
| `mcmc_steps` | `20` | Number of Langevin updates for negative samples |
| `step_size` | `0.01` | Langevin step size |
| `noise_scale` | `0.005` | Noise multiplier in Langevin updates |
| `gradient_clipping` | `1.0` | Clip norm applied to Langevin input gradients |
| `replay_buffer_size` | `10000` | Replay-buffer capacity for `pcd` |
| `replay_buffer_init_prob` | `0.95` | Probability of drawing PCD starts from the replay buffer |
| `energy_regularization` | `0.0` | Optional energy-value penalty |
| `gradient_penalty_weight` | `0.0` | Optional gradient penalty weight |
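The interaction between `replay_buffer_size` and `replay_buffer_init_prob` can be sketched as follows. This is a hedged illustration of one common PCD initialization scheme, not the trainer's actual buffer code: `init_chains`, `write_back`, and the choice of uniform noise in [-1, 1] for fresh chains are assumptions.

```python
import jax
import jax.numpy as jnp


def init_chains(key, buffer, batch_size, init_prob=0.95):
    """Start each chain from the buffer with probability init_prob, else from noise."""
    k_idx, k_noise, k_mask = jax.random.split(key, 3)
    idx = jax.random.randint(k_idx, (batch_size,), 0, buffer.shape[0])
    from_buffer = buffer[idx]
    # Assumption: fresh chains start from uniform noise in [-1, 1].
    from_noise = jax.random.uniform(
        k_noise, from_buffer.shape, minval=-1.0, maxval=1.0
    )
    use_buffer = jax.random.bernoulli(k_mask, init_prob, (batch_size,))
    # Broadcast the per-sample mask over the feature dimensions.
    mask = use_buffer.reshape((batch_size,) + (1,) * (buffer.ndim - 1))
    return jnp.where(mask, from_buffer, from_noise), idx


def write_back(buffer, idx, samples):
    # Store the final negatives so later steps can continue these chains.
    return buffer.at[idx].set(samples)


key = jax.random.key(0)
buffer = jnp.zeros((100, 2))  # stand-in buffer with 100 two-dimensional slots
starts, idx = init_chains(key, buffer, batch_size=8)
buffer = write_back(buffer, idx, starts + 1.0)
```

Mixing in a small fraction of noise-initialized chains keeps the buffer from collapsing onto a few modes, which is the usual motivation for an initialization probability below 1.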
Langevin Sampling
Negative samples are always produced by Langevin dynamics. The trainer does not expose alternate sampler families in its public config.
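As a rough illustration of how `mcmc_steps`, `step_size`, `noise_scale`, and `gradient_clipping` fit together, the sketch below runs a clipped Langevin chain in JAX. The update rule `x <- x - step_size * clip(grad E(x)) + noise_scale * eps` is an assumed form; the trainer's exact scaling and clipping are not shown on this page. The `energy` and `langevin_sample` names are hypothetical.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # Hypothetical toy energy with its minimum at the origin.
    return 0.5 * jnp.sum(x ** 2, axis=-1)


def langevin_sample(key, x_init, mcmc_steps=20, step_size=0.01,
                    noise_scale=0.005, gradient_clipping=1.0):
    """Run mcmc_steps Langevin updates on a (batch, dim) array of chain states."""
    grad_fn = jax.grad(lambda x: jnp.sum(energy(x)))

    def step(i, x):
        g = grad_fn(x)
        # Assumption: clip each sample's gradient to an L2 norm of gradient_clipping.
        norm = jnp.linalg.norm(g, axis=-1, keepdims=True)
        g = g * jnp.minimum(1.0, gradient_clipping / (norm + 1e-12))
        # Derive a fresh noise key for every step.
        noise = jax.random.normal(jax.random.fold_in(key, i), x.shape)
        return x - step_size * g + noise_scale * noise

    x_final = jax.lax.fori_loop(0, mcmc_steps, step, x_init)
    # Negatives are treated as constants when the loss is differentiated.
    return jax.lax.stop_gradient(x_final)


key = jax.random.key(0)
x0 = jnp.full((4, 2), 3.0)
negatives = langevin_sample(key, x0)
# With the noise switched off, the chain simply descends the energy.
descended = langevin_sample(key, x0, mcmc_steps=200, step_size=0.1, noise_scale=0.0)
```

For `"cd"`, `x_init` would be the data batch; for `"pcd"`, it would come from the replay buffer.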