Loss Functions¤

Artifex exposes a curated public surface of functional loss primitives for generative models. Shared, generic building blocks come from CalibraX where it already provides them; Artifex keeps local implementations only for Artifex-specific needs or for functionality CalibraX does not yet offer.

Overview¤

  • Functional Primitives: narrow public building blocks for model families, trainers, and modalities
  • Explicit Composition: compose objectives directly in JAX instead of through wrapper frameworks
  • CalibraX-Backed Shared Metrics: use CalibraX-backed regression and divergence primitives where they already exist
  • JAX-Optimized: fully JIT-compatible and vectorized implementations
  • Numerically Stable: careful handling of edge cases and numerical precision

Loss Categories¤

graph TD
    A[Loss Functions] --> B[Reconstruction Losses]
    A --> C[Adversarial Losses]
    A --> D[Divergence Losses]
    A --> E[Perceptual Losses]
    A --> F[Composition Pattern]

    B --> B1[MSE]
    B --> B2[MAE]
    B --> B3[Huber]
    B --> B4[Charbonnier]

    C --> C1[Vanilla GAN]
    C --> C2[LSGAN]
    C --> C3[WGAN]
    C --> C4[Hinge]

    D --> D1[KL Divergence]
    D --> D2[JS Divergence]
    D --> D3[Wasserstein]
    D --> D4[MMD]

    E --> E1[Feature Reconstruction]
    E --> E2[Style Loss]
    E --> E3[Contextual Loss]

    F --> F1[Direct weighted sums]
    F --> F2[Explicit component dicts]
    F --> F3[Trainer-owned schedules]

    style A fill:#e1f5ff
    style B fill:#b3e5fc
    style C fill:#b3e5fc
    style D fill:#b3e5fc
    style E fill:#b3e5fc
    style F fill:#81d4fa

The public rule is simple:

  • use shared primitives directly
  • prefer CalibraX-backed generic metrics and divergences when available
  • keep composition explicit inside trainers, family-local objective helpers, or modality-specific glue

Reconstruction Losses¤

Reconstruction losses compare predictions directly with targets and are typically used in autoencoders and regression tasks.

Location: src/artifex/generative_models/core/losses/reconstruction.py

mse_loss (L2 Loss)¤

Mean Squared Error - penalizes squared differences between predictions and targets.

from artifex.generative_models.core.losses.reconstruction import mse_loss
import jax.numpy as jnp

predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.1, 1.9, 3.2])

loss = mse_loss(predictions, targets, reduction="mean")
print(f"MSE Loss: {loss}")  # ≈ 0.02 (squared errors 0.01, 0.01, 0.04)

Parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| predictions | jax.Array | Required | Model predictions |
| targets | jax.Array | Required | Ground truth values |
| reduction | str | "mean" | Reduction: "none", "mean", "sum", or "batch_sum" (used for VAE terms later on this page) |
| weights | jax.Array \| None | None | Optional per-element weights |
| axis | int \| tuple \| None | None | Axis for reduction |

Use Cases:

  • VAE reconstruction loss
  • Image regression
  • General reconstruction tasks

mae_loss (L1 Loss)¤

Mean Absolute Error - more robust to outliers than MSE.

from artifex.generative_models.core.losses.reconstruction import mae_loss

loss = mae_loss(predictions, targets, reduction="mean")
print(f"MAE Loss: {loss}")  # 0.133...

Use Cases:

  • Robust regression
  • Outlier-heavy datasets
  • Image reconstruction with noise

huber_loss¤

Smooth combination of L1 and L2 losses - quadratic for small errors, linear for large errors.

from artifex.generative_models.core.losses.reconstruction import huber_loss

# Delta controls the transition point
loss = huber_loss(predictions, targets, delta=1.0, reduction="mean")

Parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| delta | float | 1.0 | Threshold between quadratic and linear regions |

Use Cases:

  • Robust regression with outliers
  • Reinforcement learning (value functions)
  • Object detection
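
The quadratic-to-linear transition can be sketched in plain JAX. This is an illustrative reference under the common convention of a 0.5 factor on the quadratic branch, not the Artifex source; `huber_reference` is a hypothetical name.

```python
import jax.numpy as jnp


def huber_reference(predictions, targets, delta=1.0):
    """Illustrative Huber loss with mean reduction (hypothetical helper)."""
    error = predictions - targets
    abs_error = jnp.abs(error)
    quadratic = 0.5 * error**2                  # used where |error| <= delta
    linear = delta * (abs_error - 0.5 * delta)  # used where |error| > delta
    return jnp.mean(jnp.where(abs_error <= delta, quadratic, linear))


small = huber_reference(jnp.array([0.5]), jnp.array([0.0]))  # quadratic branch: 0.125
large = huber_reference(jnp.array([3.0]), jnp.array([0.0]))  # linear branch: 2.5
```

The two branches meet at `|error| == delta` with matching value and slope, which is what keeps the gradient bounded for large errors.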

charbonnier_loss¤

Differentiable variant of L1 loss with smoother gradients.

from artifex.generative_models.core.losses.reconstruction import charbonnier_loss

loss = charbonnier_loss(
    predictions,
    targets,
    epsilon=1e-3,  # Smoothing constant
    alpha=1.0,     # Exponent
    reduction="mean"
)

Use Cases:

  • Optical flow estimation
  • Image super-resolution
  • Smooth optimization landscapes

psnr_loss¤

Peak Signal-to-Noise Ratio expressed as a loss (negative PSNR).

from artifex.generative_models.core.losses.reconstruction import psnr_loss

# For normalized images (0-1)
loss = psnr_loss(pred_images, target_images, max_value=1.0)

# For images in range [0, 255]
loss = psnr_loss(pred_images, target_images, max_value=255.0)

Use Cases:

  • Image quality assessment
  • Super-resolution evaluation
  • Compression evaluation
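
For orientation, negative PSNR can be written directly from the mean squared error. The sketch below assumes the standard definition PSNR = 10 log10(max_value² / MSE); `psnr_loss_reference` is a hypothetical name, and the library may add an epsilon or a different reduction.

```python
import jax.numpy as jnp


def psnr_loss_reference(pred, target, max_value=1.0):
    """Illustrative negative PSNR in decibels (hypothetical helper)."""
    mse = jnp.mean((pred - target) ** 2)
    psnr = 10.0 * jnp.log10(max_value**2 / mse)
    return -psnr  # lower loss means higher PSNR


pred_images = jnp.full((4, 4), 0.6)
target_images = jnp.full((4, 4), 0.5)
loss = psnr_loss_reference(pred_images, target_images)  # MSE 0.01, so about -20 dB
```

Because identical inputs give an MSE of zero, a production implementation would normally guard the division.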

Adversarial Losses¤

Adversarial losses for training GANs and adversarial networks.

Location: src/artifex/generative_models/core/losses/adversarial.py

Vanilla GAN Losses¤

Original GAN formulation with binary cross-entropy.

from artifex.generative_models.core.losses.adversarial import (
    vanilla_generator_loss,
    vanilla_discriminator_loss
)

# Generator loss: -log(D(G(z)))
g_loss = vanilla_generator_loss(fake_scores)

# Discriminator loss: -log(D(x)) - log(1 - D(G(z)))
d_loss = vanilla_discriminator_loss(real_scores, fake_scores)

Pros: Simple, well-studied

Cons: Vanishing gradients, mode collapse
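
The two formulas in the comments can be evaluated stably when the discriminator outputs raw logits. Whether Artifex expects logits or probabilities is not stated here, so treat this as a sketch under the logits assumption; both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def vanilla_generator_reference(fake_logits):
    # -log(sigmoid(x)) equals softplus(-x), which avoids log(0).
    return jnp.mean(jax.nn.softplus(-fake_logits))


def vanilla_discriminator_reference(real_logits, fake_logits):
    real_term = jax.nn.softplus(-real_logits)  # -log D(x)
    fake_term = jax.nn.softplus(fake_logits)   # -log(1 - D(G(z)))
    return jnp.mean(real_term) + jnp.mean(fake_term)


# An undecided discriminator (logit 0, probability 0.5) gives log(2) per term.
g_ref = vanilla_generator_reference(jnp.zeros(4))
d_ref = vanilla_discriminator_reference(jnp.zeros(4), jnp.zeros(4))
```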


LSGAN Losses¤

Least Squares GAN - uses mean squared error instead of cross-entropy.

from artifex.generative_models.core.losses.adversarial import (
    least_squares_generator_loss,
    least_squares_discriminator_loss
)

# Generator: minimize (D(G(z)) - 1)^2
g_loss = least_squares_generator_loss(
    fake_scores,
    target_real=1.0
)

# Discriminator: minimize (D(x) - 1)^2 + (D(G(z)) - 0)^2
d_loss = least_squares_discriminator_loss(
    real_scores,
    fake_scores,
    target_real=1.0,
    target_fake=0.0
)

Pros: More stable training, better gradients

Cons: May require more careful tuning


Wasserstein GAN Losses¤

Wasserstein distance-based losses with better convergence properties.

from artifex.generative_models.core.losses.adversarial import (
    wasserstein_generator_loss,
    wasserstein_discriminator_loss
)

# Generator: minimize -D(G(z))
g_loss = wasserstein_generator_loss(fake_scores)

# Critic: minimize D(G(z)) - D(x)
d_loss = wasserstein_discriminator_loss(real_scores, fake_scores)

# Note: Requires gradient penalty or weight clipping

Pros: Stable training, meaningful loss curves

Cons: Requires gradient penalty or weight clipping
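
A gradient penalty of the kind mentioned above can be sketched with `jax.grad` and `jax.vmap`. This is a minimal WGAN-GP style illustration, assuming a critic that maps one example to a scalar; `gradient_penalty` and the toy critic are hypothetical and not part of the API documented here.

```python
import jax
import jax.numpy as jnp


def gradient_penalty(critic_fn, real, fake, key):
    """Mean of (||d critic / d x_hat|| - 1)^2 over random interpolates."""
    batch = real.shape[0]
    # One interpolation coefficient per example, broadcast over feature dims.
    eps = jax.random.uniform(key, (batch,) + (1,) * (real.ndim - 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # critic_fn scores a single example, so vmap lifts its gradient to the batch.
    grads = jax.vmap(jax.grad(critic_fn))(x_hat)
    norms = jnp.sqrt(jnp.sum(grads.reshape(batch, -1) ** 2, axis=1) + 1e-12)
    return jnp.mean((norms - 1.0) ** 2)


w = jnp.array([3.0, 4.0])
toy_critic = lambda x: jnp.dot(w, x)  # linear critic with gradient norm 5
gp = gradient_penalty(toy_critic, jnp.ones((8, 2)), jnp.zeros((8, 2)), jax.random.key(0))
```

The penalty is typically added to the critic loss with a weight (10 is a common choice in the literature), but the weight is a tuning decision for the trainer.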


Hinge Losses¤

Hinge loss formulation used in spectral normalization GANs.

from artifex.generative_models.core.losses.adversarial import (
    hinge_generator_loss,
    hinge_discriminator_loss
)

# Generator: -D(G(z))
g_loss = hinge_generator_loss(fake_scores)

# Discriminator: max(0, 1 - D(x)) + max(0, 1 + D(G(z)))
d_loss = hinge_discriminator_loss(real_scores, fake_scores)

Pros: Stable, works well with spectral normalization

Cons: May need careful architecture design


Divergence Losses¤

Statistical divergence measures between probability distributions.

Location: src/artifex/generative_models/core/losses/divergence.py

kl_divergence¤

Kullback-Leibler divergence - measures information loss.

from artifex.generative_models.core.losses.divergence import kl_divergence
import distrax

# With distribution objects
p = distrax.Normal(loc=0.0, scale=1.0)
q = distrax.Normal(loc=0.5, scale=1.5)
kl = kl_divergence(p, q)

# With probability arrays
p_probs = jnp.array([0.2, 0.5, 0.3])
q_probs = jnp.array([0.3, 0.4, 0.3])
kl = kl_divergence(p_probs, q_probs, reduction="sum")

Formula: KL(P||Q) = Σ P(x) log(P(x) / Q(x))

Use Cases:

  • VAE latent regularization
  • Distribution matching
  • Information theory applications
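
The discrete formula can be checked by hand on the probability arrays from the example. This sketch clips probabilities to avoid log(0), which is an assumption about stabilization rather than a description of the library; `kl_reference` is a hypothetical name.

```python
import jax.numpy as jnp


def kl_reference(p, q, eps=1e-8):
    """Illustrative KL(P||Q) for discrete distributions, summed over outcomes."""
    p = jnp.clip(p, eps, 1.0)
    q = jnp.clip(q, eps, 1.0)
    return jnp.sum(p * jnp.log(p / q))


p_probs = jnp.array([0.2, 0.5, 0.3])
q_probs = jnp.array([0.3, 0.4, 0.3])

kl_pq = kl_reference(p_probs, q_probs)  # about 0.0305 nats
kl_qp = kl_reference(q_probs, p_probs)  # about 0.0324 nats
```

The two directions differ, which is why the argument order matters when KL is used as a training objective.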

js_divergence¤

Jensen-Shannon divergence - symmetric variant of KL divergence.

from artifex.generative_models.core.losses.divergence import js_divergence

p = jnp.array([0.2, 0.5, 0.3])
q = jnp.array([0.1, 0.7, 0.2])
js = js_divergence(p, q)

Formula: JS(P||Q) = 0.5 (KL(P||M) + KL(Q||M)) where M = 0.5 (P + Q)

Properties:

  • Symmetric: JS(P||Q) = JS(Q||P)
  • Bounded: 0 ≤ JS ≤ log(2)
  • Its square root (the Jensen-Shannon distance) is a metric and satisfies the triangle inequality
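
Symmetry and the log(2) bound can be verified numerically from the formula. The sketch assumes strictly positive probabilities and natural logarithms; `js_reference` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def _kl(p, q):
    # Assumes strictly positive entries, so no clipping is applied.
    return jnp.sum(p * jnp.log(p / q))


def js_reference(p, q):
    """Illustrative JS divergence via the mixture M = 0.5 (P + Q)."""
    m = 0.5 * (p + q)
    return 0.5 * (_kl(p, m) + _kl(q, m))


p = jnp.array([0.2, 0.5, 0.3])
q = jnp.array([0.1, 0.7, 0.2])

js_pq = js_reference(p, q)
js_qp = js_reference(q, p)  # same value: the definition is symmetric in P and Q
```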

wasserstein_distance¤

Earth Mover's Distance - optimal transport-based metric.

from artifex.generative_models.core.losses.divergence import wasserstein_distance

# Two batches of 1D samples (one empirical distribution per row)
p_samples = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
q_samples = jnp.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]])

# W1 distance
w1 = wasserstein_distance(p_samples, q_samples, p=1, axis=1)

# W2 distance
w2 = wasserstein_distance(p_samples, q_samples, p=2, axis=1)

Use Cases:

  • GAN training (WGAN)
  • Distribution comparison
  • Robust statistics
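
For equal-size sets of 1D samples, the Wasserstein-p distance reduces to comparing sorted samples. The sketch below illustrates that closed form; `wasserstein_1d_reference` is a hypothetical name, and the library may reduce over the batch differently.

```python
import jax.numpy as jnp


def wasserstein_1d_reference(p_samples, q_samples, p=1, axis=1):
    """Illustrative W_p between equal-size 1D sample sets along `axis`."""
    p_sorted = jnp.sort(p_samples, axis=axis)
    q_sorted = jnp.sort(q_samples, axis=axis)
    cost = jnp.abs(p_sorted - q_sorted) ** p
    return jnp.mean(cost, axis=axis) ** (1.0 / p)


p_samples = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
q_samples = jnp.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]])

# Every sample is shifted by 0.5, so both distances are 0.5 for each row.
w1_ref = wasserstein_1d_reference(p_samples, q_samples, p=1)
w2_ref = wasserstein_1d_reference(p_samples, q_samples, p=2)
```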

maximum_mean_discrepancy¤

Kernel-based distribution distance measure.

from artifex.generative_models.core.losses.divergence import maximum_mean_discrepancy
import jax

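# Generate samples from independent keys so the two sets actually differ
key_pred, key_target = jax.random.split(jax.random.key(0))
pred_samples = jax.random.normal(key_pred, (2, 100, 5))
target_samples = jax.random.normal(key_target, (2, 100, 5))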

# Compute MMD with RBF kernel
mmd = maximum_mean_discrepancy(
    pred_samples,
    target_samples,
    kernel_type="rbf",
    kernel_bandwidth=1.0
)

Kernel Types:

  • "rbf": Radial Basis Function (Gaussian)
  • "linear": Linear kernel
  • "polynomial": Polynomial kernel

Use Cases:

  • Two-sample testing
  • Domain adaptation
  • Generative model evaluation
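
To make the RBF variant concrete, here is a minimal sketch of the biased squared-MMD estimator for two unbatched sample sets of shape (n, d). The function names and the bandwidth convention exp(-d² / (2σ²)) are assumptions for illustration, not the library's definitions.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth=1.0):
    # Pairwise squared Euclidean distances between rows of x and rows of y.
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * bandwidth**2))


def mmd_reference(x, y, bandwidth=1.0):
    """Biased MMD^2 estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
    k_xx = rbf_kernel(x, x, bandwidth)
    k_yy = rbf_kernel(y, y, bandwidth)
    k_xy = rbf_kernel(x, y, bandwidth)
    return jnp.mean(k_xx) + jnp.mean(k_yy) - 2.0 * jnp.mean(k_xy)


key_x, key_y = jax.random.split(jax.random.key(0))
x = jax.random.normal(key_x, (100, 5))
y_same = jax.random.normal(key_y, (100, 5))  # same distribution as x
y_shifted = y_same + 2.0                     # mean-shifted distribution

mmd_same = mmd_reference(x, y_same)
mmd_shifted = mmd_reference(x, y_shifted)    # clearly larger than mmd_same
```

The pairwise kernel matrices cost O(n·m) memory, which is the reason for subsampling large sample sets.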

energy_distance¤

Metric between probability distributions based on Euclidean distance.

from artifex.generative_models.core.losses.divergence import energy_distance

energy = energy_distance(
    pred_samples,
    target_samples,
    beta=1.0  # Power parameter (0 < beta <= 2)
)

Properties:

  • Metric (satisfies triangle inequality)
  • Generalizes Euclidean distance
  • Computationally efficient

Perceptual Losses¤

Feature-based losses using pre-trained networks.

Location: src/artifex/generative_models/core/losses/perceptual.py

feature_reconstruction_loss¤

Compares features extracted from intermediate layers.

from artifex.generative_models.core.losses.perceptual import feature_reconstruction_loss

# Dictionary of features (e.g., from VGG)
features_real = {
    "conv1": jnp.ones((2, 64, 64, 64)),
    "conv2": jnp.ones((2, 32, 32, 128)),
    "conv3": jnp.ones((2, 16, 16, 256)),
}

features_fake = {
    "conv1": jnp.zeros((2, 64, 64, 64)),
    "conv2": jnp.zeros((2, 32, 32, 128)),
    "conv3": jnp.zeros((2, 16, 16, 256)),
}

# Weighted feature loss
loss = feature_reconstruction_loss(
    features_real,
    features_fake,
    weights={"conv1": 1.0, "conv2": 0.5, "conv3": 0.25}
)

Use Cases:

  • Image-to-image translation
  • Style transfer (content loss)
  • Super-resolution

style_loss¤

Gram matrix-based style matching.

from artifex.generative_models.core.losses.perceptual import style_loss

# Captures texture/style information
loss = style_loss(
    features_real,
    features_fake,
    weights={"conv1": 1.0, "conv2": 1.0, "conv3": 1.0}
)

How it works:

  1. Computes Gram matrices of features
  2. Measures distance between Gram matrices
  3. Captures correlations between feature channels

Use Cases:

  • Style transfer
  • Texture synthesis
  • Artistic image generation
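
The three steps can be sketched for a single feature layer in channels-last layout. Normalizing the Gram matrix by the number of spatial positions is an assumption, and `gram_matrix` and `style_reference` are hypothetical names.

```python
import jax.numpy as jnp


def gram_matrix(features):
    """Channel-by-channel correlations for features of shape (B, H, W, C)."""
    b, h, w, c = features.shape
    flat = features.reshape(b, h * w, c)
    # Assumed normalization: divide by the number of spatial positions.
    return jnp.einsum("bnc,bnd->bcd", flat, flat) / (h * w)


def style_reference(features_a, features_b):
    """Mean squared difference between the two Gram matrices."""
    diff = gram_matrix(features_a) - gram_matrix(features_b)
    return jnp.mean(diff**2)


feat_a = jnp.ones((2, 4, 4, 3))
feat_b = jnp.zeros((2, 4, 4, 3))
style_value = style_reference(feat_a, feat_b)  # Gram of ones is all ones, so 1.0
```

Because the spatial axes are summed out, the Gram matrix discards where a pattern occurs and keeps only which channels fire together.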

contextual_loss¤

Measures distributional similarity between feature sets, which makes it robust to spatial misalignment.

from artifex.generative_models.core.losses.perceptual import contextual_loss

# Single-layer features
feat_real = jnp.ones((2, 32, 32, 128))
feat_fake = jnp.zeros((2, 32, 32, 128))

loss = contextual_loss(
    feat_real,
    feat_fake,
    band_width=0.1,
    max_samples=512  # Memory-efficient
)

Use Cases:

  • Non-aligned image matching
  • Texture transfer
  • Semantic image editing

PerceptualLoss Module¤

Composable NNX module combining multiple perceptual losses.

from artifex.generative_models.core.losses.perceptual import PerceptualLoss
from flax import nnx

# Create perceptual loss module
perceptual = PerceptualLoss(
    feature_extractor=vgg_model,  # Pre-trained VGG
    layer_weights={
        "conv1_2": 1.0,
        "conv2_2": 1.0,
        "conv3_3": 1.0,
        "conv4_3": 1.0,
    },
    content_weight=1.0,
    style_weight=10.0,
    contextual_weight=0.1,
)

# Compute combined loss
loss = perceptual(pred_images, target_images)

Explicit Loss Composition¤

Artifex keeps multi-term objectives explicit. Compose primitives directly and return a plain metrics dictionary from the owning model or trainer.

from artifex.generative_models.core.losses.reconstruction import mse_loss, mae_loss

reconstruction_loss = mse_loss(predictions, targets)
l1_penalty = mae_loss(predictions, targets)
total_loss = reconstruction_loss + 0.5 * l1_penalty

loss_dict = {
    "total_loss": total_loss,
    "reconstruction_loss": reconstruction_loss,
    "l1_penalty": l1_penalty,
}

Use explicit schedule weights for curriculum learning:

def warmup_schedule(step: int) -> float:
    return min(1.0, step / 1000.0)

perceptual_term = perceptual_loss(predictions, targets)
scheduled_perceptual = warmup_schedule(step) * perceptual_term
total_loss = reconstruction_loss + 0.1 * scheduled_perceptual
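
If the step counter is passed into a jitted function, Python's built-in `min` fails on the traced value, so a schedule written with `jnp.minimum` is safer. The sketch below is one way to do that; `warmup_weight` and `scheduled_total` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def warmup_weight(step, warmup_steps=1000):
    """Linear ramp from 0 to 1 that also works when `step` is a traced array."""
    return jnp.minimum(1.0, step / warmup_steps)


@jax.jit
def scheduled_total(recon_term, perceptual_term, step):
    # Same 0.1 base weight as the fragment above, scaled by the warmup ramp.
    return recon_term + 0.1 * warmup_weight(step) * perceptual_term


total_early = scheduled_total(jnp.float32(2.0), jnp.float32(10.0), jnp.int32(500))
total_late = scheduled_total(jnp.float32(2.0), jnp.float32(10.0), jnp.int32(5000))
# Halfway through warmup: 2.0 + 0.1 * 0.5 * 10.0 = 2.5; after warmup: 3.0
```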

Base Utilities¤

Helper classes and functions for loss management.

Location: src/artifex/generative_models/core/losses/base.py

reduce_loss¤

Shared reduction helper for loss tensors.

from artifex.generative_models.core.losses.base import reduce_loss

per_pixel_loss = jnp.square(predictions - targets)

mean_loss = reduce_loss(per_pixel_loss, reduction="mean")
sum_loss = reduce_loss(per_pixel_loss, reduction="sum")
vae_loss = reduce_loss(per_pixel_loss, reduction="batch_sum")
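
This page does not define "batch_sum". A plausible reading, consistent with its use for VAE ELBO terms, is "sum over all non-batch axes, then average over the batch". The sketch below illustrates that assumed behavior only; `batch_sum_reference` is a hypothetical name, so check `base.py` for the actual semantics.

```python
import jax.numpy as jnp


def batch_sum_reference(per_element_loss):
    """Assumed semantics: sum over non-batch axes, then mean over the batch."""
    batch = per_element_loss.shape[0]
    per_example = jnp.sum(per_element_loss.reshape(batch, -1), axis=1)
    return jnp.mean(per_example)


per_element = jnp.ones((4, 8, 8, 3))            # batch of 4 images, loss of 1 per element
batch_sum_value = batch_sum_reference(per_element)  # 8 * 8 * 3 = 192
mean_value = jnp.mean(per_element)                  # 1.0
```

Under this reading the value scales with the number of elements per example, which is why it should not be mixed casually with "mean"-reduced terms.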

Common Patterns¤

Pattern 1: VAE Loss¤

from artifex.generative_models.core.losses.reconstruction import mse_loss
from artifex.generative_models.core.losses.divergence import gaussian_kl_divergence

def vae_loss(x, reconstruction, mean, logvar):
    """Complete VAE loss."""
    # Canonical ELBO terms use VAE-style batch_sum reduction.
    recon_loss = mse_loss(reconstruction, x, reduction="batch_sum")
    kl_loss = gaussian_kl_divergence(mean, logvar, reduction="batch_sum")

    # Total loss with a KL weight (beta) of 0.5
    total_loss = recon_loss + 0.5 * kl_loss

    return {
        "total_loss": total_loss,
        "reconstruction_loss": recon_loss,
        "kl_loss": kl_loss,
    }
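
For a diagonal Gaussian posterior and a standard normal prior, the KL term has a closed form. The sketch below writes it out, assuming the reduction sums over latent dimensions and averages over the batch; `gaussian_kl_reference` is a hypothetical name and not the library function.

```python
import jax.numpy as jnp


def gaussian_kl_reference(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)), summed over latents, averaged over batch."""
    per_example = -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)
    return jnp.mean(per_example)


# Posterior equal to the prior: KL is 0.
kl_zero = gaussian_kl_reference(jnp.zeros((4, 2)), jnp.zeros((4, 2)))
# Unit-variance posterior with mean 1 in 2 dims: 0.5 * (1 + 1) = 1.0.
kl_shifted = gaussian_kl_reference(jnp.ones((4, 2)), jnp.zeros((4, 2)))
```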

Pattern 2: GAN with Multiple Losses¤

from artifex.generative_models.core.losses.adversarial import (
    hinge_generator_loss,
    hinge_discriminator_loss
)
from artifex.generative_models.core.losses.perceptual import PerceptualLoss

# Perceptual loss for generator
perceptual_loss = PerceptualLoss(
    feature_extractor=vgg,
    content_weight=1.0,
    style_weight=10.0,
)

def generator_loss(fake_images, real_images, fake_scores):
    """Complete generator loss."""
    # Adversarial loss
    adv_loss = hinge_generator_loss(fake_scores)

    # Perceptual loss
    perc_loss = perceptual_loss(fake_images, real_images)

    # Total loss
    total = adv_loss + 0.1 * perc_loss

    return {
        "total_loss": total,
        "adversarial": adv_loss,
        "perceptual": perc_loss,
    }

Pattern 3: Explicit Multi-Loss System¤

recon_loss = mse_loss(predictions, targets)
perc_loss = perceptual_fn(predictions, targets)
total_loss = recon_loss + 0.1 * perc_loss

components = {
    "reconstruction_loss": recon_loss,
    "perceptual_loss": perc_loss,
}

Best Practices¤

DO¤

  • ✅ Use reduction="batch_sum" for VAE ELBO terms
  • ✅ Use reduction="mean" for ordinary regression-style losses
  • ✅ Scale losses to similar magnitudes when combining
  • ✅ Use perceptual losses for visual quality
  • ✅ Monitor individual loss components during training
  • ✅ Prefer the numerically stable Artifex and CalibraX primitives directly
  • ✅ Normalize inputs when using distance-based losses
  • ✅ Use gradient clipping with adversarial losses

DON'T¤

  • ❌ Mix different reduction types without careful consideration
  • ❌ Use very different loss magnitudes without weighting
  • ❌ Forget to normalize features for perceptual losses
  • ❌ Use high-dimensional MMD without subsampling
  • ❌ Apply losses on unnormalized data
  • ❌ Ignore numerical stability (use eps parameters)

Performance Tips¤

Memory-Efficient Perceptual Loss¤

# Use lower resolution or fewer samples
perceptual = PerceptualLoss(
    max_contextual_samples=256,  # Reduce memory usage
    # ... remaining arguments elided
)

Vectorized Loss Computation¤

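# Use NNX's lifted vmap (built on jax.vmap) to get one loss value per example
from flax import nnx
from artifex.generative_models.core.losses.reconstruction import mse_loss

@nnx.jit
def batch_loss(predictions, targets):
    return nnx.vmap(mse_loss)(predictions, targets)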

Gradient Accumulation for Large Losses¤

# For very large composite losses
@nnx.jit
@nnx.grad
def loss_with_accumulation(params, batch):
    # Compute losses in chunks
    ...
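
One way to fill in the chunked computation is to scan over equal-size micro-batches and average their gradients. The sketch below uses plain `jax.grad` and `jax.lax.scan` with a toy linear model; `loss_fn` and `accumulated_grads` are hypothetical names, and an NNX module would need the corresponding `nnx` transforms instead.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Toy mean-squared-error objective for a linear model.
    x, y = batch
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def accumulated_grads(params, x_chunks, y_chunks):
    """Average gradients over equal-size micro-batches stacked on axis 0."""
    def step(carry, chunk):
        grads = jax.grad(loss_fn)(params, chunk)
        return jax.tree_util.tree_map(jnp.add, carry, grads), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, (x_chunks, y_chunks))
    n_chunks = x_chunks.shape[0]
    return jax.tree_util.tree_map(lambda g: g / n_chunks, total)


params = {"w": jnp.array([1.0, -2.0])}
x = jnp.arange(16.0).reshape(8, 2)
y = jnp.ones(8)

full = jax.grad(loss_fn)(params, (x, y))                               # full-batch gradient
accum = accumulated_grads(params, x.reshape(4, 2, 2), y.reshape(4, 2))  # 4 chunks of 2
```

With a mean-reduced loss and equal chunk sizes, the averaged micro-batch gradients match the full-batch gradient while only one chunk's activations are live at a time.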

Troubleshooting¤

Issue: "NaN in loss computation"¤

Solutions:

  • Use epsilon parameters for numerical stability
  • Check input normalization
  • Prefer the numerically stable divergence and reduction helpers already built into the loss primitives
  • Clip gradients

Issue: "Loss values differ greatly in magnitude"¤

Solution: Scale losses to similar ranges:

# Bad: losses differ by orders of magnitude
total = recon_loss + kl_loss  # e.g., 100.0 + 0.01

# Better: weight the smaller term so both contribute meaningfully
total = recon_loss + 1000.0 * kl_loss  # e.g., 100.0 + 10.0

Issue: "Perceptual loss uses too much memory"¤

Solution: Reduce the sample count (and, if needed, extract features at a lower resolution):

contextual_loss(
    feat_real,
    feat_fake,
    max_samples=256,  # Reduce from default 1024
)
