PPO Trainer

Status: Supported runtime training surface

Module: artifex.generative_models.training.rl.ppo

Source: src/artifex/generative_models/training/rl/ppo.py

PPOTrainer operates on typed autoregressive rollout batches with explicit old-policy log probabilities, returns, and advantages.

Quick Start

from flax import nnx
import optax

from artifex.generative_models.training import PPOConfig, PPOTrainer
from artifex.generative_models.training.rl import (
    GeneratedBatch,
    GeneratedSequenceBatch,
    SequenceRolloutBatch,
)

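# ActorCriticModel stands in for your own nnx.Module; it is not defined in this snippet.
model = ActorCriticModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(3e-4), wrt=nnx.Param)
trainer = PPOTrainer(model, optimizer, PPOConfig())

# token_sequences, response_mask, old_log_probs, returns, advantages, and dones
# are arrays collected from earlier rollouts; they are not defined in this snippet.
sequence_batch = GeneratedSequenceBatch(
    generation=GeneratedBatch(outputs=token_sequences),
    response_mask=response_mask,
)
rollout = SequenceRolloutBatch(
    sequence_batch=sequence_batch,
    old_log_probs=old_log_probs,
    returns=returns,
    advantages=advantages,
    dones=dones,
)

# Run one PPO update on the rollout batch.
loss, metrics = trainer.train_step(rollout)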

Typed Batch Contract

PPOTrainer requires a SequenceRolloutBatch with:

  • old_log_probs= aligned to sequences[:, 1:], so there is one value per predicted token and none for the first token
  • returns= aligned to the same action-token layout
  • advantages= aligned to the same action-token layout

The sequence wrapper comes from GeneratedSequenceBatch, which itself wraps the generic GeneratedBatch.
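The alignment to sequences[:, 1:] can be illustrated with a small JAX sketch. It assumes a causal model whose logits at position t score the token at position t + 1; the helper name action_token_log_probs and the toy shapes are hypothetical and not part of artifex.

```python
import jax
import jax.numpy as jnp


def action_token_log_probs(logits: jax.Array, sequences: jax.Array) -> jax.Array:
    """Log-probability of each realized next token (hypothetical helper).

    logits:    [batch, seq_len, vocab]; logits[:, t] scores sequences[:, t + 1]
    sequences: [batch, seq_len] integer token ids
    returns:   [batch, seq_len - 1], aligned to sequences[:, 1:]
    """
    # Drop the last position: it predicts a token that is not in the batch.
    log_probs = jax.nn.log_softmax(logits[:, :-1, :], axis=-1)
    # Drop the first token: nothing in the batch predicts it.
    targets = sequences[:, 1:]
    # Pick the log-probability assigned to each realized token.
    return jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]


# Toy shapes: batch of 2, sequence length 5, vocabulary of 7.
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (2, 5, 7))
sequences = jnp.array([[1, 2, 3, 4, 5], [0, 6, 2, 2, 1]])
old_log_probs = action_token_log_probs(logits, sequences)
print(old_log_probs.shape)  # (2, 4)
```

Arrays for returns and advantages built with the same [batch, seq_len - 1] layout would then line up element by element with old_log_probs.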

Configuration

artifex.generative_models.training.rl.configs.PPOConfig dataclass

PPOConfig(gamma: float = 0.99, gae_lambda: float = 0.95, clip_param: float = 0.2, vf_coeff: float = 0.5, entropy_coeff: float = 0.01, max_grad_norm: float = 0.5)

Configuration for Proximal Policy Optimization.

Implements PPO with clipped surrogate objective and GAE.
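As a rough illustration of the clipped surrogate objective, the sketch below computes the standard PPO policy loss in JAX. The function name clipped_surrogate_loss and the mask argument are assumptions made for this example; how PPOTrainer masks and reduces the loss internally is not documented here.

```python
import jax.numpy as jnp


def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, mask, clip_param=0.2):
    """Standard PPO clipped policy loss, averaged over unmasked tokens.

    All array arguments share one shape, e.g. [batch, seq_len - 1].
    `mask` is 1.0 for tokens that count toward the loss and 0.0 otherwise.
    """
    # Probability ratio between the current policy and the rollout policy.
    ratio = jnp.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    # Take the pessimistic (smaller) objective, then negate to get a loss.
    per_token = -jnp.minimum(unclipped, clipped)
    return jnp.sum(per_token * mask) / jnp.maximum(jnp.sum(mask), 1.0)


# With identical policies the ratio is 1, so the loss is minus the mean advantage.
advantages = jnp.array([[1.0, -1.0, 2.0]])
log_probs = jnp.array([[-0.5, -1.0, -2.0]])
mask = jnp.ones_like(advantages)
loss_same_policy = clipped_surrogate_loss(log_probs, log_probs, advantages, mask)
print(loss_same_policy)
```

With clip_param at 0.2, a ratio above 1.2 on a positive advantage, or below 0.8 on a negative one, stops contributing gradient, which is what limits the size of each policy update.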

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| gamma | float | Discount factor for computing returns. Default 0.99. |
| gae_lambda | float | Lambda for Generalized Advantage Estimation. Default 0.95. |
| clip_param | float | Clipping parameter for surrogate objective. Default 0.2. |
| vf_coeff | float | Coefficient for value function loss. Default 0.5. |
| entropy_coeff | float | Coefficient for entropy bonus. Default 0.01. |
| max_grad_norm | float | Maximum global gradient norm for clipping. Default 0.5. |
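To show how gamma and gae_lambda interact, here is a minimal sketch of Generalized Advantage Estimation for a single trajectory, written with jax.lax.scan. The function compute_gae and its argument layout are hypothetical; this page does not say how the returns and advantages passed to PPOTrainer were produced, so treat this as one common way to build them rather than the library's own routine.

```python
import jax
import jax.numpy as jnp


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """GAE over one trajectory (hypothetical helper).

    rewards, values, dones: [T]; dones[t] is 1 when step t ends the episode.
    last_value: scalar value estimate for the state after the final step.
    Returns (advantages, returns), both of shape [T].
    """
    not_done = 1.0 - dones.astype(values.dtype)
    next_values = jnp.concatenate([values[1:], jnp.asarray(last_value, values.dtype)[None]])
    # One-step TD errors; bootstrapping is cut at episode boundaries.
    deltas = rewards + gamma * next_values * not_done - values

    def step(next_advantage, inputs):
        delta, nd = inputs
        advantage = delta + gamma * gae_lambda * nd * next_advantage
        return advantage, advantage

    # Scan backwards in time; outputs keep the original time order.
    _, advantages = jax.lax.scan(
        step, jnp.zeros((), values.dtype), (deltas, not_done), reverse=True
    )
    return advantages, advantages + values


rewards = jnp.array([1.0, 1.0, 1.0])
values = jnp.zeros(3)
dones = jnp.array([0.0, 0.0, 1.0])
advantages, returns = compute_gae(rewards, values, dones, 0.0, gamma=0.5, gae_lambda=1.0)
print(advantages)  # [1.75 1.5  1.  ]
```

With gae_lambda set to 1.0 and zero value estimates, the advantages reduce to plain discounted returns, which makes the toy output easy to check by hand. Lower gae_lambda values lean more on the value estimates and reduce variance at the cost of bias.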

gamma class-attribute instance-attribute

gamma: float = 0.99

gae_lambda class-attribute instance-attribute

gae_lambda: float = 0.95

clip_param class-attribute instance-attribute

clip_param: float = 0.2

vf_coeff class-attribute instance-attribute

vf_coeff: float = 0.5

entropy_coeff class-attribute instance-attribute

entropy_coeff: float = 0.01

max_grad_norm class-attribute instance-attribute

max_grad_norm: float = 0.5
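The roles of vf_coeff, entropy_coeff, and max_grad_norm can be sketched in plain JAX as follows. Both function names are hypothetical, and the sign convention (entropy subtracted as a bonus) is the usual PPO formulation rather than something confirmed by this page. Optax offers the same clipping operation as optax.clip_by_global_norm.

```python
import jax
import jax.numpy as jnp


def total_ppo_loss(policy_loss, value_loss, entropy, vf_coeff=0.5, entropy_coeff=0.01):
    """Weighted sum of the three PPO terms; entropy acts as a bonus."""
    return policy_loss + vf_coeff * value_loss - entropy_coeff * entropy


def clip_by_global_norm(grads, max_grad_norm=0.5):
    """Rescale a gradient pytree so its global L2 norm is at most max_grad_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # Scale is 1 when the norm is already small enough.
    scale = jnp.minimum(1.0, max_grad_norm / (global_norm + 1e-6))
    clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
    return clipped, global_norm


loss = total_ppo_loss(policy_loss=1.0, value_loss=2.0, entropy=3.0)

# Toy gradients with global norm sqrt(3^2 + 4^2) = 5.
grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}
clipped_grads, norm_before = clip_by_global_norm(grads, max_grad_norm=0.5)
print(loss, norm_before, clipped_grads)
```

Clipping by the global norm rescales every leaf by the same factor, so the update direction is preserved and only its length is capped.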