Loss Functions for Generative Models¤
Level: Intermediate | Runtime: ~30 seconds (CPU) | Format: Python + Jupyter
Prerequisites: Basic understanding of loss functions and JAX | Target Audience: Users learning Artifex's loss API
Overview¤
This example walks through the loss functions in Artifex, from simple functional losses to explicitly composed multi-term objectives. It shows how to use the built-in losses, combine them with fixed or scheduled weights, and apply specialized losses for VAEs, GANs, and geometric models.
What You'll Learn¤
- Functional Losses - Simple loss functions (MSE, MAE) with flexible reduction modes
- Composable System - Combine weighted losses with component tracking
- VAE Losses - Reconstruction + KL divergence for variational autoencoders
- GAN Losses - Generator and discriminator losses (Standard, LS-GAN, Wasserstein)
- Scheduled Losses - Time-varying loss weights for curriculum learning
- Geometric Losses - Chamfer distance and mesh losses for 3D data
Files¤
This example is available in two formats:
- Python Script: loss_examples.py
- Jupyter Notebook: loss_examples.ipynb
Quick Start¤
Run the Python Script¤
# Install Artifex if needed
pip install avitai-artifex
# Run the example
python examples/generative_models/loss_examples.py
Run the Jupyter Notebook¤
# Install Artifex if needed
pip install avitai-artifex
# Launch Jupyter
jupyter lab examples/generative_models/loss_examples.ipynb
Key Concepts¤
1. Functional Losses¤
Simple, stateless loss functions for common use cases:
from artifex.generative_models.core.losses import mse_loss, mae_loss
# Mean Squared Error
loss = mse_loss(predictions, targets, reduction="mean")
# Mean Absolute Error
loss = mae_loss(predictions, targets, reduction="sum")
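To make the effect of the `reduction` argument concrete, here is a minimal sketch in plain JAX. The function `mse_sketch` is a hand-rolled stand-in written for illustration; it is not Artifex's `mse_loss`, and it only assumes the three reduction names used on this page.

```python
import jax.numpy as jnp


def mse_sketch(predictions, targets, reduction="mean"):
    """Illustrative MSE showing how each reduction mode shapes the result."""
    per_element = (predictions - targets) ** 2
    if reduction == "mean":
        return jnp.mean(per_element)  # scalar: average over all elements
    if reduction == "sum":
        return jnp.sum(per_element)  # scalar: sum over all elements
    if reduction == "none":
        return per_element  # same shape as the inputs
    raise ValueError(f"Unknown reduction: {reduction!r}")


predictions = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([[1.0, 0.0], [3.0, 2.0]])

loss_mean = mse_sketch(predictions, targets, reduction="mean")  # 2.0
loss_sum = mse_sketch(predictions, targets, reduction="sum")  # 8.0
loss_none = mse_sketch(predictions, targets, reduction="none")  # shape (2, 2)
```

The `"none"` mode is useful when you want to apply per-sample weights or masks before reducing yourself.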
Available Reductions:
- "mean": Average over all elements (default)
- "sum": Sum all elements
- "none": Return per-element losses
2. Weighted Loss Terms¤
Apply fixed weights to loss components:
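A fixed weight is a scalar multiplier on a loss term. The following is a minimal sketch in plain JAX; the `weighted` helper and `l1_sketch` are hypothetical names introduced here for illustration and are not part of Artifex's API.

```python
import jax.numpy as jnp


def weighted(loss_fn, weight):
    """Wrap a loss function so its output is scaled by a fixed weight."""

    def wrapped(*args, **kwargs):
        return weight * loss_fn(*args, **kwargs)

    return wrapped


def l1_sketch(predictions, targets):
    """Hand-rolled mean absolute error, used only for this example."""
    return jnp.mean(jnp.abs(predictions - targets))


weighted_l1 = weighted(l1_sketch, weight=0.5)

predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.0, 1.0, 1.0])
loss = weighted_l1(predictions, targets)  # 0.5 * mean(|[0, 1, 2]|) = 0.5
```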
3. Explicit Multi-Term Objectives¤
Combine multiple loss functions:
reconstruction_loss = mse_loss(predictions, targets)
l1_penalty = mae_loss(predictions, targets)
total_loss = reconstruction_loss + 0.5 * l1_penalty
components = {
"reconstruction": reconstruction_loss,
"l1_penalty": l1_penalty,
}
4. VAE Losses¤
VAE loss combines reconstruction and KL divergence:
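import jax.numpy as jnp
from artifex.generative_models.core.losses import mse_loss

def vae_loss(reconstruction, targets, mean, logvar, beta=1.0):
    # Reconstruction loss (default reduction: mean over all elements)
    recon_loss = mse_loss(reconstruction, targets)
    # KL divergence to a standard normal prior, summed over batch and latent dimensions
    kl_loss = -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar))
    kl_loss = kl_loss / targets.shape[0]  # Normalize by batch size
    # Total VAE loss: beta scales the KL term
    return recon_loss + beta * kl_loss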
β Parameter:
- β = 1.0: Standard VAE
- β > 1.0: β-VAE (encourages disentanglement)
- β < 1.0: Less regularization, better reconstruction
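A quick sanity check on the KL term helps when tuning β: the divergence is zero when the posterior equals the standard normal prior and grows as the mean moves away from zero. The sketch below uses a hand-written `gaussian_kl_sketch` (an illustrative name, not an Artifex function) that sums over latent dimensions and averages over the batch.

```python
import jax.numpy as jnp


def gaussian_kl_sketch(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)): sum over latents, mean over batch."""
    kl_per_sample = -0.5 * jnp.sum(
        1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1
    )
    return jnp.mean(kl_per_sample)


batch, latent_dim = 4, 8
zeros = jnp.zeros((batch, latent_dim))
ones = jnp.ones((batch, latent_dim))

kl_at_prior = gaussian_kl_sketch(mean=zeros, logvar=zeros)  # 0.0
kl_shifted = gaussian_kl_sketch(mean=ones, logvar=zeros)  # 0.5 * 8 = 4.0

# beta rescales only the KL contribution to the total loss
beta = 4.0
kl_contribution = beta * kl_shifted  # 16.0
```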
5. GAN Losses¤
Artifex keeps generator and discriminator objectives explicit:
from artifex.generative_models.core.losses import (
    least_squares_discriminator_loss,
    least_squares_generator_loss,
)
# Generator loss (want discriminator to output 1 for fake)
g_loss = least_squares_generator_loss(fake_scores)
# Discriminator loss (real→1, fake→0)
d_loss = least_squares_discriminator_loss(real_scores, fake_scores)
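For reference, the vanilla, least-squares, and Wasserstein objectives can each be written in a few lines of plain JAX. These are textbook formulations with illustrative function names; Artifex's own implementations may differ in scaling constants or reduction, so treat this as a sketch rather than a description of the library.

```python
import jax
import jax.numpy as jnp


def vanilla_d_sketch(real_logits, fake_logits):
    # Binary cross-entropy on logits: real -> 1, fake -> 0
    return jnp.mean(jax.nn.softplus(-real_logits)) + jnp.mean(
        jax.nn.softplus(fake_logits)
    )


def vanilla_g_sketch(fake_logits):
    # Non-saturating generator loss: push fake logits toward "real"
    return jnp.mean(jax.nn.softplus(-fake_logits))


def lsgan_d_sketch(real_scores, fake_scores):
    # Squared error to targets 1 (real) and 0 (fake)
    return 0.5 * (jnp.mean((real_scores - 1.0) ** 2) + jnp.mean(fake_scores**2))


def lsgan_g_sketch(fake_scores):
    return 0.5 * jnp.mean((fake_scores - 1.0) ** 2)


def wgan_critic_sketch(real_scores, fake_scores):
    # Critic maximizes E[real] - E[fake]; written here as a loss to minimize
    return jnp.mean(fake_scores) - jnp.mean(real_scores)


def wgan_g_sketch(fake_scores):
    return -jnp.mean(fake_scores)


real_scores = jnp.array([2.0, 3.0])
fake_scores = jnp.array([-1.0, 0.0])

d_ls = lsgan_d_sketch(real_scores, fake_scores)  # 0.5 * (2.5 + 0.5) = 1.5
g_ls = lsgan_g_sketch(fake_scores)  # 0.5 * 2.5 = 1.25
d_w = wgan_critic_sketch(real_scores, fake_scores)  # -0.5 - 2.5 = -3.0
```

Note that the Wasserstein critic loss is unbounded without a Lipschitz constraint on the critic, which is why it is paired with weight clipping or a gradient penalty in practice.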
Available GAN Loss Types:
- "vanilla": Binary cross-entropy (original GAN)
- "lsgan": Least-squares GAN (often more stable)
- "wgan": Wasserstein GAN (requires a Lipschitz constraint on the critic, such as weight clipping or a gradient penalty)
6. Scheduled Loss Terms¤
Time-varying loss weights for curriculum learning:
import jax.numpy as jnp

# Define schedule function
def warmup_schedule(step):
    """Linear warmup from 0 to 1 over 1000 steps."""
    return jnp.minimum(1.0, step / 1000.0)

base_loss = perceptual_loss(...)
loss_value = warmup_schedule(step=500) * base_loss  # weight = 0.5
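Because the schedule is built from `jax.numpy` operations, it can be evaluated on a whole array of steps at once, which is a convenient way to inspect the weight curve before training. In this sketch the `warmup_steps` parameter and the constant `base_loss` are assumptions added for illustration.

```python
import jax.numpy as jnp


def warmup_schedule(step, warmup_steps=1000):
    """Linear warmup from 0 to 1 over `warmup_steps`, then constant at 1."""
    return jnp.minimum(1.0, step / warmup_steps)


steps = jnp.array([0, 250, 500, 1000, 2000])
weights = warmup_schedule(steps)  # [0.0, 0.25, 0.5, 1.0, 1.0]

base_loss = 2.0  # placeholder value standing in for a real loss term
scheduled_losses = weights * base_loss  # [0.0, 0.5, 1.0, 2.0, 2.0]
```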
7. Geometric Losses¤
Specialized losses for 3D data:
Chamfer Distance¤
Measures point cloud similarity:
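import jax
from artifex.generative_models.core.losses import chamfer_distance

# Use separate keys so the two point clouds are not identical
pred_key, target_key = jax.random.split(jax.random.PRNGKey(0))
# Point clouds: (batch, num_points, 3)
pred_points = jax.random.normal(pred_key, (4, 1000, 3))
target_points = jax.random.normal(target_key, (4, 1000, 3))
loss = chamfer_distance(pred_points, target_points)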
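To show what the metric computes, here is a brute-force sketch of a symmetric Chamfer distance in plain JAX. It assumes squared Euclidean distances and mean reduction; `chamfer_sketch` is an illustrative name, and Artifex's `chamfer_distance` may use a different convention (for example unsquared distances or a different normalization).

```python
import jax
import jax.numpy as jnp


def chamfer_sketch(pred, target):
    """Symmetric Chamfer distance for pred (batch, n, 3) and target (batch, m, 3)."""
    diff = pred[:, :, None, :] - target[:, None, :, :]  # (batch, n, m, 3)
    sq_dist = jnp.sum(diff**2, axis=-1)  # (batch, n, m)
    # Nearest target point for every predicted point, and vice versa
    pred_to_target = jnp.mean(jnp.min(sq_dist, axis=2), axis=1)  # (batch,)
    target_to_pred = jnp.mean(jnp.min(sq_dist, axis=1), axis=1)  # (batch,)
    return jnp.mean(pred_to_target + target_to_pred)


pred_key, target_key = jax.random.split(jax.random.PRNGKey(0))
pred_points = jax.random.normal(pred_key, (2, 64, 3))
target_points = jax.random.normal(target_key, (2, 64, 3))

same = chamfer_sketch(pred_points, pred_points)  # 0.0 for identical clouds
different = chamfer_sketch(pred_points, target_points)  # strictly positive
```

The pairwise distance tensor has size batch × n × m, so this brute-force form is only practical for small or moderate point counts.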
Mesh Loss¤
Multi-component loss for mesh quality:
from artifex.generative_models.core.losses import MeshLoss
mesh_loss = MeshLoss(
    vertex_weight=1.0,      # Vertex position accuracy
    normal_weight=0.1,      # Surface normal consistency
    edge_weight=0.1,        # Edge length preservation
    laplacian_weight=0.01,  # Smoothness regularization
)
# Mesh format: (vertices, faces, normals)
pred_mesh = (vertices_pred, faces, normals_pred)
target_mesh = (vertices_target, faces, normals_target)
loss = mesh_loss(pred_mesh, target_mesh)
8. Perceptual Loss¤
Feature-based loss using pre-trained networks:
from artifex.generative_models.core.losses import PerceptualLoss
perceptual = PerceptualLoss(
    content_weight=1.0,
    style_weight=0.01,
)
# Requires feature extraction from images
loss = perceptual(
    pred_images=generated_images,
    target_images=real_images,
    features_pred=extracted_features_pred,
    features_target=extracted_features_target,
)
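As a rough illustration of what content and style weights usually control, the sketch below follows the common formulation from Johnson et al.: a content term compares feature maps directly, and a style term compares their Gram matrices. This is an assumption about the general technique, not a description of how Artifex's `PerceptualLoss` is implemented; the function names and the (batch, height, width, channels) feature layout are chosen for the example.

```python
import jax.numpy as jnp


def gram_matrix(features):
    """Channel correlations: (batch, h, w, c) -> (batch, c, c), scaled by h * w."""
    b, h, w, c = features.shape
    flat = features.reshape(b, h * w, c)
    return jnp.einsum("bnc,bnd->bcd", flat, flat) / (h * w)


def perceptual_sketch(features_pred, features_target,
                      content_weight=1.0, style_weight=0.01):
    # Content: mean squared difference between feature maps
    content = jnp.mean((features_pred - features_target) ** 2)
    # Style: mean squared difference between Gram matrices
    style = jnp.mean(
        (gram_matrix(features_pred) - gram_matrix(features_target)) ** 2
    )
    return content_weight * content + style_weight * style


features = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3) / 100.0

identical = perceptual_sketch(features, features)  # 0.0
shifted = perceptual_sketch(features, features + 1.0)  # positive
```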
9. Total Variation Loss¤
Smoothness regularization for images:
from artifex.generative_models.core.losses import total_variation_loss
# Encourages spatial smoothness
tv_loss = total_variation_loss(generated_images)
# Often combined with other losses
total_loss = reconstruction_loss + 0.001 * tv_loss
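A total variation penalty sums the differences between neighboring pixels, so flat images score zero and high-frequency patterns score high. The following sketch assumes an anisotropic (absolute difference) form and images laid out as (batch, height, width, channels); `tv_sketch` is an illustrative name, and Artifex's `total_variation_loss` may normalize differently.

```python
import jax.numpy as jnp


def tv_sketch(images):
    """Anisotropic total variation, averaged over pixel pairs."""
    dh = jnp.abs(images[:, 1:, :, :] - images[:, :-1, :, :])  # vertical neighbors
    dw = jnp.abs(images[:, :, 1:, :] - images[:, :, :-1, :])  # horizontal neighbors
    return jnp.mean(dh) + jnp.mean(dw)


flat_image = jnp.ones((1, 4, 4, 1))

# Checkerboard: every pixel differs from each of its neighbors by 1
rows = jnp.arange(4)[:, None]
cols = jnp.arange(4)[None, :]
checkerboard = ((rows + cols) % 2).astype(jnp.float32)[None, :, :, None]

tv_flat = tv_sketch(flat_image)  # 0.0
tv_checker = tv_sketch(checkerboard)  # 1.0 + 1.0 = 2.0
```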
Code Structure¤
The example demonstrates seven loss usage patterns:
- Functional Usage - Simple MSE and MAE losses
- Composable Loss - Weighted loss combination
- VAE Training - Reconstruction + KL divergence
- GAN Training - Generator and discriminator losses
- Scheduled Loss - Progressive loss weight ramping
- Geometric Losses - Chamfer distance and mesh losses
- Complete Training - Full training loop with losses
Features Demonstrated¤
- ✅ Functional losses with flexible reduction modes
- ✅ Weighted loss composition with component tracking
- ✅ VAE loss (reconstruction + KL divergence)
- ✅ GAN loss suites (standard, LS-GAN, Wasserstein)
- ✅ Scheduled losses for curriculum learning
- ✅ Geometric losses for 3D data (Chamfer, mesh)
- ✅ Perceptual loss with feature extraction
- ✅ Total variation loss for smoothness
- ✅ Integration with Flax NNX training loops
Experiments to Try¤
- Adjust Loss Weights
# Try different β values for VAE ELBO
reconstruction_loss = mse_loss(reconstructed, targets, reduction="batch_sum")
kl_loss = gaussian_kl_divergence(mean, log_var, reduction="batch_sum")
total_loss = reconstruction_loss + 4.0 * kl_loss # β = 4.0
- Compare GAN Loss Types
from artifex.generative_models.core.losses import (
    least_squares_discriminator_loss,
    least_squares_generator_loss,
    vanilla_discriminator_loss,
    vanilla_generator_loss,
)
# Vanilla GAN
g_loss = vanilla_generator_loss(fake_scores)
d_loss = vanilla_discriminator_loss(real_scores, fake_scores)
# LS-GAN (often more stable)
g_loss = least_squares_generator_loss(fake_scores)
d_loss = least_squares_discriminator_loss(real_scores, fake_scores)
- Custom Schedule Functions
# Exponential warmup
def exp_schedule(step):
    return 1.0 - jnp.exp(-step / 1000.0)

# Cosine annealing from 1 to 0 over total_steps
def cosine_schedule(step, total_steps):
    return 0.5 * (1 + jnp.cos(jnp.pi * step / total_steps))
- Geometric Loss Weights
# Adjust mesh loss components
mesh_loss = MeshLoss(
    vertex_weight=2.0,  # Emphasize position accuracy
    normal_weight=0.5,  # More weight on normals
    edge_weight=0.1,
    laplacian_weight=0.01,
)
Next Steps¤
- VAE Examples - Apply losses in VAE training
- GAN Examples - Use GAN losses in training
- Geometric Models - Apply geometric losses
- Framework Features - Understand composable design
Troubleshooting¤
Shape Mismatch Errors¤
Symptom: ValueError about incompatible shapes
Solution: Ensure predictions and targets have the same shape
print(f"Predictions: {predictions.shape}")
print(f"Targets: {targets.shape}")
# Reshape only if both arrays hold the same elements in the same order
predictions = predictions.reshape(targets.shape)
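Not every shape mismatch raises an error. When the shapes are broadcast-compatible, for example (N, 1) against (N,), JAX silently broadcasts and the loss is computed over an (N, N) array. A small guard catches this early; `checked_mse` below is a hypothetical helper written for this example.

```python
import jax.numpy as jnp

predictions = jnp.ones((4, 1))
targets = jnp.zeros((4,))

# No error here: (4, 1) and (4,) broadcast to (4, 4)
silent = (predictions - targets) ** 2


def checked_mse(predictions, targets):
    """MSE that refuses to broadcast mismatched shapes."""
    if predictions.shape != targets.shape:
        raise ValueError(
            f"Shape mismatch: {predictions.shape} vs {targets.shape}"
        )
    return jnp.mean((predictions - targets) ** 2)


loss = checked_mse(predictions.reshape(targets.shape), targets)  # 1.0
```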
NaN in KL Divergence¤
Symptom: KL loss becomes NaN during VAE training
Cause: Numerical instability in exp(logvar) for large logvar
Solution: Clip logvar values
logvar = jnp.clip(logvar, -10.0, 10.0)
kl_loss = -0.5 * jnp.sum(1 + logvar - mean**2 - jnp.exp(logvar))
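The overflow is easy to reproduce in isolation. In float32, `exp(logvar)` overflows to infinity once `logvar` exceeds roughly 88, and that infinity typically turns into NaN in later arithmetic or in the gradients. A minimal sketch, with `kl_sketch` as an illustrative stand-in for the expression above:

```python
import jax.numpy as jnp


def kl_sketch(mean, logvar):
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar))


mean = jnp.zeros((2, 3), dtype=jnp.float32)
logvar = jnp.full((2, 3), 100.0, dtype=jnp.float32)  # exp(100) overflows float32

unclipped = kl_sketch(mean, logvar)  # not finite
clipped = kl_sketch(mean, jnp.clip(logvar, -10.0, 10.0))  # finite

print(bool(jnp.isfinite(unclipped)), bool(jnp.isfinite(clipped)))
```

Keep in mind that `jnp.clip` passes zero gradient for values outside the clipping range, so an encoder that is stuck far outside it will not be pulled back by the KL term alone.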
GAN Loss Not Converging¤
Symptom: Generator or discriminator loss diverges
Solution: Try LS-GAN loss instead of standard GAN
from artifex.generative_models.core.losses import (
    least_squares_discriminator_loss,
    least_squares_generator_loss,
)
# LS-GAN is often more stable
g_loss = least_squares_generator_loss(fake_scores)
d_loss = least_squares_discriminator_loss(real_scores, fake_scores)
Missing Loss Components¤
Symptom: Metrics dictionaries do not contain the component you want to log
Solution: Build the component dictionary explicitly alongside total_loss
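One way to keep the components available for logging is to return them as auxiliary output from the loss function and use `jax.value_and_grad` with `has_aux=True`. The sketch below uses a toy linear model and hand-written loss terms as assumptions for illustration; only the `has_aux` pattern itself is the point.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, inputs, targets):
    predictions = inputs @ params["w"] + params["b"]
    reconstruction = jnp.mean((predictions - targets) ** 2)
    l1_penalty = jnp.mean(jnp.abs(predictions - targets))
    total = reconstruction + 0.5 * l1_penalty
    # Returned as auxiliary data; gradients are taken w.r.t. `total` only
    components = {"reconstruction": reconstruction, "l1_penalty": l1_penalty}
    return total, components


params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
inputs = jnp.ones((4, 3))
targets = jnp.ones((4, 1))

(total_loss, components), grads = jax.value_and_grad(loss_fn, has_aux=True)(
    params, inputs, targets
)
metrics = {"total": total_loss, **components}  # total = 1.0 + 0.5 * 1.0 = 1.5
```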
Additional Resources¤
Documentation¤
- Loss Functions API Reference - Complete loss function documentation
Related Examples¤
- Framework Features Demo - Explicit loss composition
- VAE MNIST Tutorial - VAE loss in practice
- GAN MNIST Tutorial - GAN loss in practice
- Geometric Benchmark - Geometric losses
Papers¤
- VAE: Auto-Encoding Variational Bayes (Kingma & Welling, 2013)
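- β-VAE: β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework (Higgins et al., 2017)
- LS-GAN: Least Squares Generative Adversarial Networks (Mao et al., 2017)
- Perceptual Loss: Perceptual Losses for Real-Time Style Transfer and Super-Resolution (Johnson et al., 2016)
- Chamfer Distance: Learning Representations and Generative Models for 3D Point Clouds (Achlioptas et al., 2018)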