PPO Trainer
Status: Supported runtime training surface
Module: artifex.generative_models.training.rl.ppo
Source: src/artifex/generative_models/training/rl/ppo.py
PPOTrainer operates on typed autoregressive rollout batches with explicit
old-policy log probabilities, returns, and advantages.
Quick Start
```python
from flax import nnx
import optax

from artifex.generative_models.training import PPOConfig, PPOTrainer
from artifex.generative_models.training.rl import (
    GeneratedBatch,
    GeneratedSequenceBatch,
    SequenceRolloutBatch,
)

# ActorCriticModel is assumed to be defined elsewhere (an actor-critic nnx.Module).
model = ActorCriticModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(3e-4), wrt=nnx.Param)
trainer = PPOTrainer(model, optimizer, PPOConfig())

# token_sequences and response_mask are assumed to come from a prior generation step.
sequence_batch = GeneratedSequenceBatch(
    generation=GeneratedBatch(outputs=token_sequences),
    response_mask=response_mask,
)

# old_log_probs, returns, advantages, and dones are precomputed rollout arrays.
rollout = SequenceRolloutBatch(
    sequence_batch=sequence_batch,
    old_log_probs=old_log_probs,
    returns=returns,
    advantages=advantages,
    dones=dones,
)

loss, metrics = trainer.train_step(rollout)
```
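The Quick Start takes `returns` and `advantages` as given. One common way to produce them is Generalized Advantage Estimation, driven by the `gamma` and `gae_lambda` values that PPOConfig exposes. The sketch below is a NumPy illustration under stated assumptions: `compute_gae` is a hypothetical helper, not an Artifex API, and the `rewards` and `values` arrays (with one extra bootstrap column in `values`) are made-up inputs. This page does not say whether Artifex computes these quantities the same way.

```python
import numpy as np


def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over [batch, steps] arrays.

    `values` is [batch, steps + 1]; the extra trailing column is the
    bootstrap value for the state after the last step (an assumption
    of this sketch, not a documented Artifex layout).
    """
    batch, steps = rewards.shape
    advantages = np.zeros((batch, steps), dtype=np.float64)
    running = np.zeros(batch, dtype=np.float64)
    for t in reversed(range(steps)):
        not_done = 1.0 - dones[:, t]
        # One-step TD error; the bootstrap term is dropped at episode ends.
        delta = rewards[:, t] + gamma * values[:, t + 1] * not_done - values[:, t]
        # Exponentially weighted sum of future TD errors.
        running = delta + gamma * gae_lambda * not_done * running
        advantages[:, t] = running
    # Returns are the advantages added back onto the value baseline.
    returns = advantages + values[:, :-1]
    return returns, advantages


# Toy rollout: one sequence, three steps, reward only at the final step.
rewards = np.array([[0.0, 0.0, 1.0]])
values = np.array([[0.5, 0.5, 0.5, 0.0]])
dones = np.array([[0.0, 0.0, 1.0]])
returns, advantages = compute_gae(rewards, values, dones)
```

In this toy rollout the final step has advantage 0.5 and return 1.0, because the episode ends there and no bootstrap value is added.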
Typed Batch Contract
PPOTrainer requires a SequenceRolloutBatch with:
- old_log_probs: aligned to sequences[:, 1:]
- returns: aligned to the same action-token layout
- advantages: aligned to the same action-token layout
The sequence wrapper comes from GeneratedSequenceBatch, which itself wraps the
generic GeneratedBatch.
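To make the `sequences[:, 1:]` alignment concrete, here is a minimal NumPy sketch. It assumes a model that emits logits of shape [batch, time, vocab], where position t scores token t + 1; `action_token_log_probs` is a hypothetical helper for illustration, not part of Artifex. The same operations are available in `jax.numpy`.

```python
import numpy as np


def action_token_log_probs(logits, sequences):
    """Per-token log-probabilities aligned to sequences[:, 1:].

    logits: [batch, time, vocab] scores; position t predicts token t + 1.
    sequences: [batch, time] integer token ids.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Drop the last position (nothing left to predict) and the first token
    # (nothing predicts it), so both sides have time - 1 columns.
    targets = sequences[:, 1:]
    picked = np.take_along_axis(log_probs[:, :-1], targets[..., None], axis=-1)
    return picked[..., 0]


rng = np.random.default_rng(0)
sequences = rng.integers(0, 5, size=(2, 4))
logits = rng.normal(size=(2, 4, 5))
old_log_probs = action_token_log_probs(logits, sequences)
print(old_log_probs.shape)  # (2, 3)
```

A batch of sequences with 4 tokens therefore yields 3 action-token columns, and `returns` and `advantages` would need that same [batch, time - 1] shape.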
Configuration
artifex.generative_models.training.rl.configs.PPOConfig (dataclass)
PPOConfig(gamma: float = 0.99, gae_lambda: float = 0.95, clip_param: float = 0.2, vf_coeff: float = 0.5, entropy_coeff: float = 0.01, max_grad_norm: float = 0.5)
Configuration for Proximal Policy Optimization.
Implements PPO with clipped surrogate objective and GAE.
Attributes:
| Name | Type | Description |
|---|---|---|
| gamma | float | Discount factor for computing returns. Default 0.99. |
| gae_lambda | float | Lambda for Generalized Advantage Estimation. Default 0.95. |
| clip_param | float | Clipping parameter for surrogate objective. Default 0.2. |
| vf_coeff | float | Coefficient for value function loss. Default 0.5. |
| entropy_coeff | float | Coefficient for entropy bonus. Default 0.01. |
| max_grad_norm | float | Maximum global gradient norm for clipping. Default 0.5. |
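As a rough illustration of how `clip_param`, `vf_coeff`, and `entropy_coeff` typically combine, the sketch below implements the textbook clipped-surrogate PPO loss in NumPy. `ppo_loss` and its argument names are hypothetical; this is not the Artifex implementation, and it omits details a real trainer would likely need, such as weighting by a response mask.

```python
import numpy as np


def ppo_loss(new_log_probs, old_log_probs, advantages, values, returns,
             entropy, clip_param=0.2, vf_coeff=0.5, entropy_coeff=0.01):
    """Textbook clipped-surrogate PPO loss, averaged over all action tokens."""
    # Probability ratio between the current and the old policy.
    ratio = np.exp(new_log_probs - old_log_probs)
    clipped = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Pessimistic bound: take the smaller of the two surrogate terms.
    policy_loss = -np.minimum(ratio * advantages, clipped * advantages).mean()
    # Squared error between predicted values and target returns.
    value_loss = ((values - returns) ** 2).mean()
    # The entropy bonus is subtracted because the total is minimized.
    return policy_loss + vf_coeff * value_loss - entropy_coeff * entropy.mean()


# One action token whose probability doubled under the new policy.
loss_value = ppo_loss(
    new_log_probs=np.log(np.array([[2.0]])),
    old_log_probs=np.zeros((1, 1)),
    advantages=np.ones((1, 1)),
    values=np.zeros((1, 1)),
    returns=np.zeros((1, 1)),
    entropy=np.zeros((1, 1)),
)
```

With a ratio of 2.0 and a positive advantage, the clipped term caps the surrogate at 1 + clip_param = 1.2, so the policy loss here is about -1.2 rather than -2.0.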