GRPO Trainer
Status: Supported runtime training surface
Module: artifex.generative_models.training.rl.grpo
Source: src/artifex/generative_models/training/rl/grpo.py
GRPOTrainer consumes typed rollout batches grouped by prompt. The prompt-group
structure must be explicit, so wrap the rollout in a GroupRolloutBatch.
Quick Start
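from flax import nnx
import optax

from artifex.generative_models.training import GRPOConfig, GRPOTrainer
from artifex.generative_models.training.rl import (
    GeneratedBatch,
    GeneratedSequenceBatch,
    GroupRolloutBatch,
    SequenceRolloutBatch,
)

# PolicyModel is a placeholder for your own nnx.Module policy network.
model = PolicyModel(rngs=nnx.Rngs(0))
# In practice the reference model is usually a frozen copy of the initial
# policy weights (for example, loaded from the same checkpoint).
reference_model = PolicyModel(rngs=nnx.Rngs(1))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

trainer = GRPOTrainer(
    model,
    optimizer,
    GRPOConfig(num_generations=4, clip_param=0.2, beta=0.01),
    reference_model=reference_model,
)

# token_sequences, sequence_rewards, response_mask and old_log_probs are
# placeholders for arrays produced by your own generation/rollout step.
sequence_batch = GeneratedSequenceBatch(
    generation=GeneratedBatch(
        outputs=token_sequences,
        rewards=sequence_rewards,
    ),
    response_mask=response_mask,
)
rollout = SequenceRolloutBatch(
    sequence_batch=sequence_batch,
    old_log_probs=old_log_probs,
)
# group_size is the number of completions per prompt; keep it equal to
# GRPOConfig.num_generations.
grouped_rollout = GroupRolloutBatch(
    rollout=rollout,
    group_size=4,
)

loss, metrics = trainer.train_step(grouped_rollout)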
Typed Batch Contract
GRPOTrainer expects:
- a SequenceRolloutBatch carrying old_log_probs=
- sequence-level rewards on the wrapped GeneratedBatch (rewards=sequence_rewards in the Quick Start)
- a GroupRolloutBatch with an explicit group_size=
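The grouping contract can be illustrated with plain NumPy. This is a minimal sketch, assuming (hypothetically; the layout is not stated above) that the group_size completions for each prompt are stored contiguously along the batch axis. The helper name group_rewards is illustrative and not part of artifex.

```python
import numpy as np


def group_rewards(sequence_rewards: np.ndarray, group_size: int) -> np.ndarray:
    """Reshape flat per-sequence rewards into (num_prompts, group_size).

    Hypothetical helper: assumes each prompt's completions are stored
    contiguously along the batch axis.
    """
    rewards = np.asarray(sequence_rewards, dtype=np.float32)
    if rewards.ndim != 1:
        raise ValueError("expected one scalar reward per sequence")
    # Check group_size first so the modulo never divides by zero.
    if group_size < 1 or rewards.shape[0] % group_size != 0:
        raise ValueError(
            f"batch of {rewards.shape[0]} sequences cannot be split "
            f"into groups of {group_size}"
        )
    return rewards.reshape(-1, group_size)


# Two prompts, four completions each (group_size=4).
rewards = np.array([1.0, 0.0, 0.5, 0.5, 0.0, 0.0, 1.0, 1.0])
print(group_rewards(rewards, group_size=4).shape)  # (2, 4)
```

Failing loudly on a batch that does not divide evenly is the point of making group_size explicit: a silent reshape would mix completions from different prompts into one group.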
When a reference_model is provided, the KL regularization term (weighted by
beta) is computed by evaluating the reference policy directly. Reference log
probabilities are therefore not passed in the batch; they are not part of the
public GRPO batch contract.
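For intuition about what the reference-policy KL term measures, the sketch below computes the per-token estimator written in the GRPO objective of the DeepSeekMath paper, r - log(r) - 1 with r = pi_ref / pi_theta, averaged over response tokens. Whether artifex uses this exact estimator is an assumption; the function name masked_kl_estimate and its argument layout are illustrative.

```python
import numpy as np


def masked_kl_estimate(
    policy_log_probs: np.ndarray,
    reference_log_probs: np.ndarray,
    response_mask: np.ndarray,
) -> float:
    """Mean per-token KL(policy || reference) estimate over response tokens.

    Hypothetical sketch, not artifex's implementation. All arrays have
    shape (batch, seq_len); the mask is 1 for response tokens, 0 otherwise.
    """
    log_ratio = reference_log_probs - policy_log_probs  # log(pi_ref / pi_theta)
    per_token_kl = np.exp(log_ratio) - log_ratio - 1.0  # always >= 0
    mask = response_mask.astype(per_token_kl.dtype)
    # Average only over response tokens; guard against an all-zero mask.
    return float((per_token_kl * mask).sum() / np.maximum(mask.sum(), 1.0))


policy_lp = np.log(np.array([[0.5, 0.25, 0.25]]))
ref_lp = np.log(np.array([[0.5, 0.5, 0.25]]))
mask = np.array([[1, 1, 0]])  # last position is padding
kl = masked_kl_estimate(policy_lp, ref_lp, mask)
```

Because the estimate uses only log probabilities of the sampled tokens under both policies, it can be computed from a forward pass of the reference model, which is consistent with the trainer evaluating the reference policy itself.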
Configuration
artifex.generative_models.training.rl.configs.GRPOConfig (dataclass)
GRPOConfig(num_generations: int = 4, clip_param: float = 0.2, beta: float = 0.01, entropy_coeff: float = 0.01)
Configuration for Group Relative Policy Optimization.
GRPO is a critic-free RL algorithm, introduced with DeepSeekMath and used to train DeepSeek-R1, that:
- generates multiple completions per prompt (num_generations)
- normalizes advantages within each group
- uses PPO-style clipping
- saves roughly 50% of memory relative to PPO by eliminating the value network
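Group-relative advantage normalization can be sketched as follows: each completion's advantage is its reward minus the group mean, divided by the group standard deviation. This is a minimal NumPy sketch, assuming rewards are already shaped (num_prompts, num_generations); the eps stabilizer, the use of population standard deviation (NumPy's default ddof=0), and the function name are illustrative choices, not artifex's documented implementation.

```python
import numpy as np


def group_relative_advantages(
    grouped_rewards: np.ndarray, eps: float = 1e-6
) -> np.ndarray:
    """Standardize rewards within each prompt group.

    grouped_rewards has shape (num_prompts, num_generations). The row mean
    acts as the baseline that a learned value network would otherwise supply.
    """
    mean = grouped_rewards.mean(axis=1, keepdims=True)
    std = grouped_rewards.std(axis=1, keepdims=True)  # population std (ddof=0)
    return (grouped_rewards - mean) / (std + eps)


rewards = np.array([
    [1.0, 0.0, 1.0, 0.0],  # mixed outcomes -> positive and negative advantages
    [1.0, 1.0, 1.0, 1.0],  # identical rewards -> zero advantage
])
adv = group_relative_advantages(rewards)
```

Because the baseline comes from the group itself, a group in which every completion receives the same reward contributes no learning signal. This is the mechanism that lets GRPO drop the value network.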
Attributes:
| Name | Type | Description |
|---|---|---|
| num_generations | int | Number of completions generated per prompt (the group size). Default 4. |
| clip_param | float | Clipping parameter ε for the PPO-style surrogate objective; the probability ratio is clipped to [1 - ε, 1 + ε]. Default 0.2. |
| beta | float | Coefficient of the KL penalty toward the reference policy. Default 0.01. |
| entropy_coeff | float | Coefficient of the entropy bonus. Default 0.01. |
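To show how these fields could fit together, here is a minimal sketch of a GRPO-style per-token loss: the PPO clipped surrogate controlled by clip_param, a KL penalty scaled by beta, and an entropy bonus scaled by entropy_coeff. How artifex actually combines these terms is an assumption; the function name grpo_style_loss, its arguments, and the sign convention (a loss to minimize) are illustrative.

```python
import numpy as np


def grpo_style_loss(
    log_probs: np.ndarray,      # (batch, seq_len) log pi_theta of sampled tokens
    old_log_probs: np.ndarray,  # (batch, seq_len) log pi_old from the rollout
    ref_log_probs: np.ndarray,  # (batch, seq_len) log pi_ref
    advantages: np.ndarray,     # (batch,) group-relative advantage per sequence
    entropy: np.ndarray,        # (batch, seq_len) per-token policy entropy
    response_mask: np.ndarray,  # (batch, seq_len) 1 for response tokens
    clip_param: float = 0.2,
    beta: float = 0.01,
    entropy_coeff: float = 0.01,
) -> float:
    """Hypothetical GRPO-style loss sketch; not artifex's implementation."""
    mask = response_mask.astype(np.float64)
    adv = advantages[:, None]  # broadcast the sequence advantage to every token

    # PPO-style clipped surrogate (a quantity to maximize).
    ratio = np.exp(log_probs - old_log_probs)
    clipped = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param)
    surrogate = np.minimum(ratio * adv, clipped * adv)

    # Non-negative per-token KL estimate toward the reference policy.
    log_ratio_ref = ref_log_probs - log_probs
    kl = np.exp(log_ratio_ref) - log_ratio_ref - 1.0

    # Minimize: negate the surrogate, penalize KL, reward entropy.
    per_token = -surrogate + beta * kl - entropy_coeff * entropy
    return float((per_token * mask).sum() / np.maximum(mask.sum(), 1.0))


zeros = np.zeros((2, 3))
loss = grpo_style_loss(
    log_probs=zeros,
    old_log_probs=zeros,
    ref_log_probs=zeros,
    advantages=np.array([1.0, 0.0]),
    entropy=zeros,
    response_mask=np.ones((2, 3)),
)
```

The advantage is defined per sequence while the ratio is per token, so the advantage is broadcast across response tokens; the mask keeps prompt and padding positions out of the average.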