DPO Trainer

Status: Supported runtime training surface

Module: artifex.generative_models.training.rl.dpo

Source: src/artifex/generative_models/training/rl/dpo.py

DPOTrainer scores typed preference batches built from sequence-native generation contracts.

Quick Start

from flax import nnx
import optax

from artifex.generative_models.training import DPOConfig, DPOTrainer
from artifex.generative_models.training.rl import (
    GeneratedSequenceBatch,
    PreferenceBatch,
)

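# PolicyModel is a placeholder for your own nnx.Module policy.
model = PolicyModel(rngs=nnx.Rngs(0))
# In standard DPO the reference model is a fixed copy of the starting policy.
reference_model = PolicyModel(rngs=nnx.Rngs(1))
optimizer = nnx.Optimizer(model, optax.adam(1e-5), wrt=nnx.Param)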

trainer = DPOTrainer(
    model=model,
    reference_model=reference_model,
    optimizer=optimizer,
    config=DPOConfig(beta=0.1, label_smoothing=0.0),
)

# chosen_sequences, rejected_sequences, and the two loss masks are arrays you supply.
batch = PreferenceBatch(
    chosen=GeneratedSequenceBatch.from_sequences(
        chosen_sequences,
        response_mask=chosen_loss_mask,
    ),
    rejected=GeneratedSequenceBatch.from_sequences(
        rejected_sequences,
        response_mask=rejected_loss_mask,
    ),
)

loss, metrics = trainer.train_step(batch)
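
The Quick Start assumes that chosen_sequences, rejected_sequences, and the two loss masks already exist. The following is a minimal sketch of one way to build them, assuming each example is a prompt followed by a response, right-padded to a common length, and that a mask value of 1 marks the response tokens that should contribute to the loss. The helper build_padded_pair and the pad_id value are hypothetical and not part of artifex.

```python
def build_padded_pair(prompts, responses, pad_id=0):
    """Concatenate prompt + response, right-pad, and mark response tokens.

    Hypothetical helper (not part of artifex). Returns (sequences, masks)
    as nested lists with shape [batch, max_len].
    """
    max_len = max(len(p) + len(r) for p, r in zip(prompts, responses))
    sequences, masks = [], []
    for prompt, response in zip(prompts, responses):
        tokens = list(prompt) + list(response)
        # 0 for prompt tokens, 1 for response tokens.
        mask = [0] * len(prompt) + [1] * len(response)
        padding = max_len - len(tokens)
        sequences.append(tokens + [pad_id] * padding)
        masks.append(mask + [0] * padding)  # padding never contributes
    return sequences, masks


# Two prompts, each paired with a chosen and a rejected response.
prompts = [[5, 6], [7]]
chosen_sequences, chosen_loss_mask = build_padded_pair(prompts, [[8, 9, 10], [11]])
rejected_sequences, rejected_loss_mask = build_padded_pair(prompts, [[12], [13, 14]])
```

The nested lists would typically be converted to arrays (for example with jnp.asarray) before being handed to GeneratedSequenceBatch.from_sequences.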

Typed Preference Contract

DPOTrainer consumes PreferenceBatch[GeneratedSequenceBatch].

  • PreferenceBatch keeps chosen and rejected samples aligned
  • GeneratedSequenceBatch carries the token sequences and optional response_mask
  • Pass chosen_loss_mask as the response_mask of the chosen batch and rejected_loss_mask as the response_mask of the rejected batch
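
To illustrate what a response_mask is for, here is a minimal sketch of how a mask can restrict a sequence score to response tokens. It assumes the score is a masked sum of per-token log-probabilities; whether DPOTrainer sums or averages, and how it computes the per-token values, is not stated on this page. The function masked_sequence_logprob is hypothetical.

```python
def masked_sequence_logprob(token_logprobs, response_mask):
    """Sum per-token log-probabilities where response_mask is 1.

    Hypothetical helper: an assumed reduction, not the artifex implementation.
    """
    return [
        sum(lp * m for lp, m in zip(row_logprobs, row_mask))
        for row_logprobs, row_mask in zip(token_logprobs, response_mask)
    ]


# Two sequences of three tokens; the first token of each is prompt.
token_logprobs = [[-0.5, -1.0, -2.0], [-0.25, -0.75, -3.0]]
response_mask = [[0, 1, 1], [0, 1, 0]]
sequence_logprobs = masked_sequence_logprob(token_logprobs, response_mask)
```

Masked-out positions (prompt or padding) contribute nothing, so only response tokens influence the preference comparison.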

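SimPO / Reference-Free Mode

Set reference_free=True in DPOConfig and pass reference_model=None to train without a reference model: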

trainer = DPOTrainer(
    model=model,
    reference_model=None,
    optimizer=optimizer,
    config=DPOConfig(beta=0.1, reference_free=True),
)
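
A minimal sketch of what a reference-free preference loss can look like, assuming the loss is the negative log-sigmoid of the beta-scaled gap between chosen and rejected sequence log-probabilities. This is an assumption for illustration: the published SimPO method also uses length normalization and a target margin, and this page does not say which of those artifex applies. The function reference_free_loss is hypothetical.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def reference_free_loss(chosen_logps, rejected_logps, beta=0.1):
    """Mean of -log(sigmoid(beta * (chosen - rejected))) over the batch.

    Hypothetical sketch: no reference-model terms appear anywhere.
    """
    losses = [
        -log_sigmoid(beta * (chosen - rejected))
        for chosen, rejected in zip(chosen_logps, rejected_logps)
    ]
    return sum(losses) / len(losses)


# The loss shrinks as the policy assigns more probability to the chosen side.
tied = reference_free_loss([-10.0], [-10.0])
prefers_chosen = reference_free_loss([-5.0], [-15.0])
```

When both sides score equally the loss is log 2, and it falls below that once the chosen sequence is more likely than the rejected one.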

Configuration

artifex.generative_models.training.rl.configs.DPOConfig dataclass

DPOConfig(beta: float = 0.1, label_smoothing: float = 0.0, reference_free: bool = False)

Configuration for Direct Preference Optimization.

DPO enables preference learning without an explicit reward model. SimPO mode (reference_free=True) eliminates the need for a reference model.
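
To make the role of beta and label_smoothing concrete, here is a minimal sketch of the standard DPO objective with the commonly used "conservative" label-smoothing form. It assumes sequence-level log-probabilities from the policy and the reference model are already available; the function dpo_loss is hypothetical, and the exact formula used by DPOTrainer may differ.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected,
             beta=0.1, label_smoothing=0.0):
    """Mean DPO loss over a batch of sequence log-probabilities.

    Hypothetical sketch of the textbook objective, not the artifex code.
    """
    losses = []
    for pc, pr, rc, rr in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
        # Implicit reward margin: how much more the policy favors the chosen
        # sequence than the reference model does.
        margin = (pc - rc) - (pr - rr)
        losses.append(
            -(1.0 - label_smoothing) * log_sigmoid(beta * margin)
            - label_smoothing * log_sigmoid(-beta * margin)
        )
    return sum(losses) / len(losses)


# Policy identical to the reference: zero margin, loss is log 2.
baseline = dpo_loss([-10.0], [-12.0], [-10.0], [-12.0])
# Policy favors the chosen sequence more than the reference does.
improved = dpo_loss([-8.0], [-14.0], [-10.0], [-12.0])
```

In this form beta scales the margin inside the sigmoid, and a nonzero label_smoothing puts some weight on the flipped preference, which limits how far a possibly mislabeled pair can push the policy.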

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| beta | float | Reward scaling parameter. Higher values = stronger preference. Default 0.1. |
| label_smoothing | float | Label smoothing for preference loss. Default 0.0. |
| reference_free | bool | Whether to use SimPO-style reference-free training. When True, no reference model is needed. Default False. |

beta class-attribute instance-attribute

beta: float = 0.1

label_smoothing class-attribute instance-attribute

label_smoothing: float = 0.0

reference_free class-attribute instance-attribute

reference_free: bool = False