Energy-Based Models User Guide¤
Complete guide to building, training, and using Energy-Based Models with Artifex.
Overview¤
This guide covers practical usage of EBMs in Artifex, from basic setup to advanced techniques. You'll learn how to:
-
Configure EBMs
Set up energy functions and MCMC sampling parameters
-
Train Models
Train with persistent contrastive divergence and monitor stability
-
Generate Samples
Sample using Langevin dynamics and MCMC methods
-
Tune & Debug
Optimize hyperparameters and troubleshoot common issues
Quick Start¤
Basic EBM Example¤
Artifex provides factory functions for common use cases:
import jax.numpy as jnp
from flax import nnx
from artifex.generative_models.models.energy import create_mnist_ebm
# Initialize RNGs
rngs = nnx.Rngs(0)
# Create EBM optimized for MNIST-like data
model = create_mnist_ebm(rngs=rngs)
# Compute energy for a batch of images
batch = jnp.ones((4, 28, 28, 1))
output = model(batch)
print(f"Energy values: {output['energy']}")
print(f"Score shape: {output['score'].shape}")
# Generate samples using MCMC
model.eval() # Set to evaluation mode
samples = model.generate(n_samples=4, shape=(28, 28, 1))
print(f"Generated samples shape: {samples.shape}")
Configuration-Based Approach¤
For more control, use the config-based API:
from artifex.generative_models.core.configuration.energy_config import (
EBMConfig,
EnergyNetworkConfig,
MCMCConfig,
SampleBufferConfig,
)
from artifex.generative_models.models.energy import EBM
# Configure energy network
energy_config = EnergyNetworkConfig(
name="energy_net",
hidden_dims=(128, 256, 128),
activation="silu",
network_type="mlp",
)
# Configure MCMC sampling
mcmc_config = MCMCConfig(
name="langevin",
n_steps=60, # Number of MCMC steps
step_size=0.01,
noise_scale=0.005,
)
# Configure sample buffer (replay buffer)
buffer_config = SampleBufferConfig(
name="buffer",
capacity=8192, # Maximum samples to store
reinit_prob=0.05, # Probability to reinitialize samples
)
# Create EBM config
config = EBMConfig(
name="mnist_ebm",
input_dim=784, # Flattened MNIST
energy_network=energy_config,
mcmc=mcmc_config,
sample_buffer=buffer_config,
alpha=0.01, # Regularization
)
# Create model
rngs = nnx.Rngs(params=0, noise=1, sample=2)
model = EBM(config, rngs=rngs)
Creating EBM Models¤
1. Standard EBM (MLP Energy Function)¤
For tabular or low-dimensional data:
from artifex.generative_models.core.configuration.energy_config import (
EBMConfig,
EnergyNetworkConfig,
MCMCConfig,
SampleBufferConfig,
)
from artifex.generative_models.models.energy import EBM
# MLP energy function configuration
energy_config = EnergyNetworkConfig(
name="tabular_energy",
hidden_dims=(256, 256, 128),
activation="gelu",
dropout_rate=0.1,
network_type="mlp",
)
# MCMC sampling configuration
mcmc_config = MCMCConfig(
name="mcmc",
n_steps=60,
step_size=0.01,
noise_scale=0.005,
)
# Sample buffer configuration
buffer_config = SampleBufferConfig(
name="buffer",
capacity=4096,
)
# Create EBM config
config = EBMConfig(
name="tabular_ebm",
input_dim=784, # Flattened input
energy_network=energy_config,
mcmc=mcmc_config,
sample_buffer=buffer_config,
alpha=0.01,
)
rngs = nnx.Rngs(params=0, noise=1, sample=2)
model = EBM(config, rngs=rngs)
Key Parameters:
| Config Class | Parameter | Default | Description |
|---|---|---|---|
EnergyNetworkConfig |
network_type |
"mlp" | Energy function architecture ("mlp" or "cnn") |
EnergyNetworkConfig |
hidden_dims |
(128, 128) | Hidden layer dimensions |
EnergyNetworkConfig |
activation |
"gelu" | Activation function name |
MCMCConfig |
n_steps |
60 | Number of Langevin dynamics steps |
MCMCConfig |
step_size |
0.01 | Step size for gradient descent |
MCMCConfig |
noise_scale |
0.005 | Noise scale for exploration |
SampleBufferConfig |
capacity |
8192 | Maximum samples in replay buffer |
SampleBufferConfig |
reinit_prob |
0.05 | Probability to reinitialize from scratch |
EBMConfig |
alpha |
0.01 | Regularization strength |
2. CNN Energy Function (for Images)¤
For image data:
from artifex.generative_models.core.configuration.energy_config import (
DeepEBMConfig,
EnergyNetworkConfig,
MCMCConfig,
SampleBufferConfig,
)
from artifex.generative_models.models.energy import DeepEBM
# CNN energy network for images
energy_config = EnergyNetworkConfig(
name="image_energy",
hidden_dims=(64, 128, 256),
activation="silu",
network_type="cnn",
)
mcmc_config = MCMCConfig(
name="mcmc",
n_steps=100,
step_size=0.005,
noise_scale=0.001,
)
buffer_config = SampleBufferConfig(
name="buffer",
capacity=8192,
reinit_prob=0.05,
)
config = DeepEBMConfig(
name="image_ebm",
input_shape=(32, 32, 3), # CIFAR-10 dimensions (H, W, C)
energy_network=energy_config,
mcmc=mcmc_config,
sample_buffer=buffer_config,
alpha=0.001,
)
rngs = nnx.Rngs(params=0, noise=1, sample=2)
model = DeepEBM(config, rngs=rngs)
3. Deep EBM (Complex Data)¤
For complex datasets requiring deeper architectures:
from artifex.generative_models.core.configuration.energy_config import (
DeepEBMConfig,
EnergyNetworkConfig,
MCMCConfig,
SampleBufferConfig,
)
from artifex.generative_models.models.energy import DeepEBM
energy_config = EnergyNetworkConfig(
name="deep_energy",
hidden_dims=(32, 64, 128, 256),
activation="silu",
network_type="cnn",
use_residual=True,
)
mcmc_config = MCMCConfig(
name="mcmc",
n_steps=100,
step_size=0.005,
noise_scale=0.001,
)
buffer_config = SampleBufferConfig(
name="buffer",
capacity=8192,
)
config = DeepEBMConfig(
name="deep_ebm",
input_shape=(32, 32, 3),
energy_network=energy_config,
mcmc=mcmc_config,
sample_buffer=buffer_config,
alpha=0.001,
)
rngs = nnx.Rngs(params=0, noise=1, sample=2)
model = DeepEBM(config, rngs=rngs)
Deep EBM Features:
- Residual connections: Enable deeper networks (10+ layers)
- Group normalization: Keeps CNN energy blocks stable during sampling
- GroupNorm: Better than BatchNorm for MCMC sampling
Training EBMs¤
Basic Training Loop¤
import optax
from flax import nnx
from artifex.generative_models.training.trainers import (
EnergyTrainer,
EnergyTrainingConfig,
)
# Training configuration
train_config = EnergyTrainingConfig(
training_method="pcd", # Persistent Contrastive Divergence
mcmc_sampler="langevin",
mcmc_steps=20,
step_size=0.01,
noise_scale=0.005,
replay_buffer_size=10000,
replay_buffer_init_prob=0.95,
)
# Create trainer, optimizer, and RNG key
trainer = EnergyTrainer(train_config)
optimizer = nnx.Optimizer(model, optax.adam(1e-4))
rng = jax.random.key(42)
# Training loop
for epoch in range(100):
for batch in train_loader:
rng, step_rng = jax.random.split(rng)
loss, metrics = trainer.train_step(model, optimizer, batch, step_rng)
Training with Monitoring¤
Monitor key metrics during training:
def train_step_with_monitoring(model, batch):
"""Training step with detailed monitoring using model.train_step()."""
loss_dict = model.train_step(batch)
# Log metrics (keys from contrastive_divergence_loss)
print(f"Step metrics:")
print(f" Loss: {loss_dict['total_loss']:.4f}")
print(f" Real energy: {loss_dict['real_energy_mean']:.4f}")
print(f" Fake energy: {loss_dict['fake_energy_mean']:.4f}")
# Compute energy gap
energy_gap = float(loss_dict['fake_energy_mean'] - loss_dict['real_energy_mean'])
print(f" Energy gap: {energy_gap:.4f}")
# Check for issues
if energy_gap < 0:
print("WARNING: Negative energy gap - real data has higher energy!")
if abs(float(loss_dict['real_energy_mean'])) > 100:
print("WARNING: Energy explosion detected!")
return loss_dict
# Training loop
for epoch in range(num_epochs):
for batch in train_loader:
loss_dict = train_step_with_monitoring(model, batch)
Hyperparameter Guidelines¤
MCMC Sampling:
# Quick sampling (less accurate)
quick_config = {
"mcmc_steps": 20,
"mcmc_step_size": 0.02,
"mcmc_noise_scale": 0.01,
}
# Standard sampling (balanced)
standard_config = {
"mcmc_steps": 60,
"mcmc_step_size": 0.01,
"mcmc_noise_scale": 0.005,
}
# High-quality sampling (slower)
quality_config = {
"mcmc_steps": 200,
"mcmc_step_size": 0.005,
"mcmc_noise_scale": 0.001,
}
Learning Rates:
# EBMs typically need lower learning rates than supervised models
learning_rates = {
"small_model": 1e-4,
"medium_model": 5e-5,
"large_model": 1e-5,
}
Generating Samples¤
Sampling from the Model¤
# Generate samples using Langevin dynamics MCMC
n_samples = 16
model.eval() # Set to evaluation mode for sampling
samples = model.generate(
n_samples=n_samples,
shape=(784,), # Shape of each sample (must match input_dim for MLP)
n_steps=100, # More steps = better quality
step_size=0.01,
noise_scale=0.005,
)
print(f"Generated samples shape: {samples.shape}")
Sampling with Different Configurations¤
# Quick sampling (fewer steps, larger step size)
quick_samples = model.generate(
n_samples=16,
shape=(784,),
n_steps=30,
step_size=0.02,
noise_scale=0.01,
)
# High-quality sampling (more steps, smaller step size)
hq_samples = model.generate(
n_samples=16,
shape=(784,),
n_steps=200,
step_size=0.005,
noise_scale=0.001,
)
# Sample from the replay buffer (returns buffered MCMC samples)
buffer_samples = model.sample_from_buffer(n_samples=16)
Using the EnergyTrainer for Generation¤
The EnergyTrainer also provides a generate_samples method for MCMC-based generation:
from artifex.generative_models.training.trainers import (
EnergyTrainer,
EnergyTrainingConfig,
)
trainer = EnergyTrainer(EnergyTrainingConfig(mcmc_steps=20, step_size=0.01))
rng = jax.random.key(0)
# Generate samples via the trainer's MCMC chain
samples = trainer.generate_samples(
model=model,
batch_size=16,
key=rng,
shape=(784,), # Shape of each sample
num_steps=200, # Defaults to 10x config mcmc_steps if None
)
Advanced Techniques¤
1. Sample Buffer Management¤
The sample buffer is critical for stable training:
# Access buffer statistics
buffer_size = model.sample_buffer.num_samples
print(f"Buffer contains {buffer_size} retained samples")
# Manually populate buffer by running train steps
# The train_step method automatically updates the sample buffer
for batch in train_loader:
loss_dict = model.train_step(batch)
# Samples are automatically added to the buffer during training
# Clear buffer (for reinitialization)
model.sample_buffer.clear()
2. Energy Landscape Visualization¤
Visualize the energy landscape:
import matplotlib.pyplot as plt
def visualize_energy_landscape(model, data_range=(-3, 3), resolution=100):
"""Visualize 2D energy landscape."""
x = jnp.linspace(data_range[0], data_range[1], resolution)
y = jnp.linspace(data_range[0], data_range[1], resolution)
X, Y = jnp.meshgrid(x, y)
# Compute energy for each point
points = jnp.stack([X.ravel(), Y.ravel()], axis=1)
energies = model.energy(points)
energies = energies.reshape(resolution, resolution)
# Plot
plt.figure(figsize=(10, 8))
plt.contourf(X, Y, energies, levels=50, cmap='viridis')
plt.colorbar(label='Energy')
plt.title('Energy Landscape')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()
# For 2D data
visualize_energy_landscape(model)
3. Annealed Importance Sampling¤
For better sampling quality:
def annealed_sampling(model, n_samples, sample_shape, n_steps=1000, rng_key=None):
"""Annealed importance sampling for high-quality samples."""
if rng_key is None:
rng_key = jax.random.key(0)
# Start with high temperature
temperatures = jnp.linspace(10.0, 1.0, n_steps)
# Initialize samples
rng_key, init_key = jax.random.split(rng_key)
samples = jax.random.normal(init_key, (n_samples, *sample_shape))
for i, temp in enumerate(temperatures):
rng_key, step_key = jax.random.split(rng_key)
# Compute energy gradient
energy_grad = jax.grad(lambda x: jnp.sum(model.energy(x)))(samples)
# Langevin step with temperature
step_size = 0.01 * temp
noise_scale = jnp.sqrt(2 * step_size * temp)
samples = samples - step_size * energy_grad
samples = samples + noise_scale * jax.random.normal(
step_key, samples.shape
)
return samples
# Use annealed sampling (pass sample shape explicitly)
high_quality_samples = annealed_sampling(
model, n_samples=16, sample_shape=(784,), rng_key=jax.random.key(0)
)
Troubleshooting¤
Common Issues and Solutions¤
-
Energy Explosion
Symptoms: Energy values grow unbounded, NaN losses
Solutions: - Reduce learning rate (try 1e-5) - Add/increase regularization (alpha=0.01 to 0.1) - Use stronger regularization or smaller learning rates - Clip gradients:
max_grad_norm=1.0 -
Poor Sample Quality
Symptoms: Samples look like noise or blurry
Solutions: - Increase MCMC steps (60 → 100+) - Better step size tuning - Larger buffer capacity - Deeper energy function
-
Mode Collapse
Symptoms: All samples look similar
Solutions: - Increase buffer reinit probability - Use data augmentation - Longer MCMC chains - Larger buffer
-
Training Instability
Symptoms: Oscillating losses, sudden divergence
Solutions: - Lower learning rate - Use persistent buffer - Add gradient clipping - Monitor energy gap
Debugging Checklist¤
def diagnose_ebm(model, batch, sample_shape):
"""Diagnostic checks for EBM training.
Args:
model: Trained EBM or DeepEBM instance
batch: A batch dict with 'image' or 'data' key
sample_shape: Shape of each sample (e.g. (784,) or (28, 28, 1))
"""
# Extract data from batch
real_data = batch.get('image', batch.get('data'))
# 1. Check energy values
real_energy = model.energy(real_data).mean()
print(f"Real data energy: {real_energy:.3f}")
# Generate samples
model.eval()
fake_samples = model.generate(n_samples=16, shape=sample_shape, n_steps=100)
fake_energy = model.energy(fake_samples).mean()
print(f"Generated samples energy: {fake_energy:.3f}")
# Energy gap should be positive
gap = fake_energy - real_energy
print(f"Energy gap: {gap:.3f}")
# 2. Check MCMC convergence
rng_key = jax.random.key(0)
init_samples = jax.random.uniform(rng_key, (16, *sample_shape), minval=-1.0, maxval=1.0)
init_energy = model.energy(init_samples).mean()
# Use generate() to run MCMC from random init
final_samples = model.generate(
n_samples=16,
shape=sample_shape,
n_steps=100,
step_size=0.01,
noise_scale=0.005,
)
final_energy = model.energy(final_samples).mean()
energy_decrease = init_energy - final_energy
print(f"MCMC energy decrease: {energy_decrease:.3f}")
# 3. Check buffer health
buffer_size = model.sample_buffer.num_samples
print(f"Buffer samples: {buffer_size}/{model.sample_buffer.capacity}")
# 4. Check sample validity
sample_min, sample_max = float(fake_samples.min()), float(fake_samples.max())
print(f"Sample range: [{sample_min:.3f}, {sample_max:.3f}]")
return {
"real_energy": real_energy,
"fake_energy": fake_energy,
"energy_gap": gap,
"mcmc_decrease": energy_decrease,
"buffer_usage": buffer_size / model.sample_buffer.capacity,
}
# Run diagnostics
diagnostics = diagnose_ebm(model, batch, sample_shape=(784,))
Best Practices¤
1. Start Simple¤
# Begin with a small model and simple data using the factory function
from artifex.generative_models.models.energy import create_simple_ebm
rngs = nnx.Rngs(params=0, noise=1, sample=2)
model = create_simple_ebm(
input_dim=2, # 2D for visualization
rngs=rngs,
hidden_dims=(64, 64),
activation="relu",
mcmc_steps=30,
step_size=0.02,
sample_buffer_capacity=1024,
)
2. Gradually Increase Complexity¤
from artifex.generative_models.models.energy import create_mnist_ebm, create_cifar_ebm
# Once stable, increase capacity (MNIST-like data)
medium_model = create_mnist_ebm(
rngs=rngs,
hidden_dims=(128, 256),
mcmc_steps=60,
sample_buffer_capacity=4096,
)
# For complex data (CIFAR-like images with residual blocks)
complex_model = create_cifar_ebm(
rngs=rngs,
hidden_dims=(64, 128, 256, 512),
use_residual=True,
mcmc_steps=100,
sample_buffer_capacity=8192,
)
3. Monitor Training Carefully¤
# Log detailed metrics
def detailed_training_step(model, batch, step, sample_shape):
loss_dict = model.train_step(batch)
if step % 100 == 0:
# Detailed logging (keys from contrastive_divergence_loss)
print(f"\nStep {step}:")
print(f" Loss: {loss_dict['total_loss']:.4f}")
print(f" Real energy: {loss_dict['real_energy_mean']:.4f}")
print(f" Fake energy: {loss_dict['fake_energy_mean']:.4f}")
gap = float(loss_dict['fake_energy_mean'] - loss_dict['real_energy_mean'])
print(f" Gap: {gap:.4f}")
# Generate samples for visual inspection
if step % 1000 == 0:
model.eval()
samples = model.generate(n_samples=64, shape=sample_shape)
model.train()
visualize_samples(samples, f"step_{step}.png")
return loss_dict
4. Use Proper Preprocessing¤
def preprocess_for_ebm(images):
"""Proper preprocessing for image EBMs."""
# Normalize to [-1, 1]
images = (images - 127.5) / 127.5
# Add small noise during training
if training:
noise = jax.random.normal(rng_key, images.shape) * 0.005
images = images + noise
images = jnp.clip(images, -1.0, 1.0)
return images
Performance Optimization¤
GPU Acceleration¤
# EBMs benefit significantly from GPU
from artifex.generative_models.core.device_manager import DeviceManager
device_manager = DeviceManager()
device = device_manager.get_device()
print(f"Using device: {device}")
# Move data to GPU
batch_gpu = jax.device_put(batch, device)
Batch Size Tuning¤
# Larger batches = more stable gradients
# But: limited by GPU memory
batch_sizes = {
"small_model": 256,
"medium_model": 128,
"large_model": 64,
}
JIT Compilation¤
# Use nnx.jit with the EnergyTrainer for JIT-compiled training
from artifex.generative_models.training.trainers import EnergyTrainer, EnergyTrainingConfig
trainer = EnergyTrainer(EnergyTrainingConfig(training_method="pcd"))
jit_step = nnx.jit(trainer.train_step)
# Much faster after first call
rng = jax.random.key(0)
rng, step_rng = jax.random.split(rng)
loss, metrics = jit_step(model, optimizer, batch, step_rng)
Example: Complete MNIST Training¤
import jax
import jax.numpy as jnp
from flax import nnx
from artifex.generative_models.models.energy import create_mnist_ebm
# Create model using factory function
rngs = nnx.Rngs(params=0, noise=1, sample=2)
model = create_mnist_ebm(
rngs=rngs,
hidden_dims=(128, 256, 512),
activation="silu",
mcmc_steps=60,
step_size=0.01,
noise_scale=0.005,
sample_buffer_capacity=8192,
alpha=0.01,
)
# Preprocessing: normalize images to [-1, 1]
def preprocess(images):
images = jnp.array(images, dtype=jnp.float32) / 255.0
images = (images - 0.5) / 0.5 # Normalize to [-1, 1]
return {"image": images}
# Training loop using model.train_step()
num_epochs = 50
for epoch in range(num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
for step, batch in enumerate(train_loader):
batch = preprocess(batch)
loss_dict = model.train_step(batch)
if step % 100 == 0:
gap = float(loss_dict['fake_energy_mean'] - loss_dict['real_energy_mean'])
print(f" Step {step}: Loss={loss_dict['total_loss']:.4f}, "
f"Gap={gap:.4f}")
# Generate samples
if (epoch + 1) % 10 == 0:
model.eval()
samples = model.generate(n_samples=64, shape=(28, 28, 1))
model.train()
save_image_grid(samples, f"epoch_{epoch+1}.png")
print("Training complete!")
Further Reading¤
- EBM Explained - Theoretical foundations
- EBM API Reference - Complete API documentation
- Training Guide - General training workflows
- Examples - More EBM examples
Summary¤
Key Takeaways:
- EBMs learn by assigning low energy to data, high energy to non-data
- Persistent Contrastive Divergence (PCD) with MCMC sampling is the standard training method
- Sample buffer management is critical for stable training
- Monitor energy gap: fake_energy should be > real_energy
- Start simple, increase complexity gradually
- Use regularization, gradient clipping, and careful monitoring for stability
Recommended Workflow:
- Start with simple 2D data to verify training works
- Use MLP energy for tabular, CNN for images
- Monitor energy gap and buffer health
- Tune MCMC steps and step size for your data
- Use DeepEBM for complex distributions
- Visualize samples frequently during training
For theoretical understanding, see the EBM Explained guide.