
Autoregressive Trainer

Status: Supported runtime training surface

Module: artifex.generative_models.training.trainers.autoregressive_trainer

Source: src/artifex/generative_models/training/trainers/autoregressive_trainer.py

AutoregressiveTrainer is the retained trainer for sequence models. It owns teacher forcing, scheduled sampling, label smoothing, and causal or padding masks. The caller still owns the model, the optimizer, and the outer training loop.

Quick Start

from flax import nnx
import jax
import optax

from artifex.generative_models.training.trainers import (
    AutoregressiveTrainer,
    AutoregressiveTrainingConfig,
)

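# TransformerModel stands in for your own autoregressive nnx.Module;
# it is not defined or imported in this snippet.
model = TransformerModel(rngs=nnx.Rngs(0))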
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)
trainer = AutoregressiveTrainer(
    AutoregressiveTrainingConfig(
        use_teacher_forcing=True,
        scheduled_sampling="linear",
        label_smoothing=0.1,
        pad_token_id=0,
    )
)

# `batch` is a placeholder for a batch of token sequences from your data pipeline.
key = jax.random.key(0)
loss, metrics = trainer.train_step(model, optimizer, batch, step=10, key=key)
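The quick start sets `scheduled_sampling="linear"` and passes a step and a PRNG key to `train_step`. The sketch below shows what a linear scheduled-sampling schedule commonly computes, assuming `step` is the global training step. The function names and the `total_steps`, `start`, and `end` parameters are hypothetical. This page does not show which config fields control the decay in Artifex, and this is not the trainer's implementation.

```python
import jax
import jax.numpy as jnp


def linear_teacher_forcing_prob(step, total_steps, start=1.0, end=0.0):
    """Hypothetical: probability of feeding the ground-truth token at `step`."""
    frac = jnp.clip(step / total_steps, 0.0, 1.0)
    # Interpolate linearly from `start` (pure teacher forcing) to `end`.
    return start + (end - start) * frac


def mix_inputs(key, ground_truth, predicted, step, total_steps):
    """Hypothetical: per position, keep ground truth with probability p, else use the prediction."""
    p = linear_teacher_forcing_prob(step, total_steps)
    use_truth = jax.random.bernoulli(key, p, shape=ground_truth.shape)
    return jnp.where(use_truth, ground_truth, predicted)
```

Early in training the model sees mostly ground-truth tokens. As `step` grows, it is increasingly fed its own predictions, which narrows the gap between training and free-running generation.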

JIT-Friendly Step Boundary

train_step(...) receives the model and the optimizer as explicit arguments instead of storing them on the trainer, so the caller can wrap the step in nnx.jit when compilation is wanted.

jit_step = nnx.jit(trainer.train_step)
loss, metrics = jit_step(model, optimizer, batch, step=10, key=key)

Mask Helpers

The module also exports three mask helpers:

  • create_causal_mask(seq_length)
  • create_padding_mask(tokens, pad_token_id)
  • create_combined_mask(tokens, pad_token_id=None)
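The following sketch shows what masks with these signatures typically compute. It is not the exported helpers. The `sketch_` names are hypothetical, and several conventions are assumptions: `True` means "may attend", `tokens` has shape `(batch, seq_length)`, and the combined mask blocks padded key positions. Check the module source for the actual shapes and dtypes.

```python
import jax.numpy as jnp


def sketch_causal_mask(seq_length):
    # mask[i, j] is True when query position i may attend to key position j (j <= i).
    return jnp.tril(jnp.ones((seq_length, seq_length), dtype=bool))


def sketch_padding_mask(tokens, pad_token_id):
    # True for real tokens, False for padding; shape (batch, seq_length).
    return tokens != pad_token_id


def sketch_combined_mask(tokens, pad_token_id=None):
    # Causal mask broadcast over the batch: (batch, seq_length, seq_length).
    batch, seq_length = tokens.shape
    mask = jnp.broadcast_to(
        sketch_causal_mask(seq_length), (batch, seq_length, seq_length)
    )
    if pad_token_id is not None:
        # Additionally block attention to padded key positions.
        mask = mask & sketch_padding_mask(tokens, pad_token_id)[:, None, :]
    return mask
```

For `tokens = [[4, 5, 0]]` with `pad_token_id=0`, the last column of every row is blocked, so no query attends to the padded position.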