
GRPO Trainer

Status: Supported runtime training surface

Module: artifex.generative_models.training.rl.grpo

Source: src/artifex/generative_models/training/rl/grpo.py

GRPOTrainer consumes typed, grouped rollout batches. Prompt-group structure must be made explicit by wrapping each rollout in a GroupRolloutBatch.

Quick Start

from flax import nnx
import optax

from artifex.generative_models.training import GRPOConfig, GRPOTrainer
from artifex.generative_models.training.rl import (
    GeneratedBatch,
    GeneratedSequenceBatch,
    GroupRolloutBatch,
    SequenceRolloutBatch,
)

# PolicyModel stands in for your own nnx.Module policy (not defined here).
model = PolicyModel(rngs=nnx.Rngs(0))
reference_model = PolicyModel(rngs=nnx.Rngs(1))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

trainer = GRPOTrainer(
    model,
    optimizer,
    GRPOConfig(num_generations=4, clip_param=0.2, beta=0.01),
    reference_model=reference_model,
)

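# token_sequences, sequence_rewards, response_mask, and old_log_probs are
# placeholders for arrays produced by your own rollout/sampling code.
sequence_batch = GeneratedSequenceBatch(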
    generation=GeneratedBatch(
        outputs=token_sequences,
        rewards=sequence_rewards,
    ),
    response_mask=response_mask,
)
rollout = SequenceRolloutBatch(
    sequence_batch=sequence_batch,
    old_log_probs=old_log_probs,
)
grouped_rollout = GroupRolloutBatch(
    rollout=rollout,
    group_size=4,  # same value as num_generations in the config above
)

loss, metrics = trainer.train_step(grouped_rollout)

Typed Batch Contract

GRPOTrainer expects:

  • a SequenceRolloutBatch carrying old_log_probs=
  • sequence-level rewards on the wrapped GeneratedBatch via rewards=
  • a GroupRolloutBatch with explicit group_size=
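Grouping by prompt can be pictured as a reshape of the flat reward vector. The sketch below is illustrative only: it assumes a hypothetical layout in which each prompt's group_size completions occupy contiguous rows, which may not match how your rollout code orders the batch, and group_rewards is a hypothetical helper, not part of artifex.

```python
import jax.numpy as jnp


def group_rewards(rewards: jnp.ndarray, group_size: int) -> jnp.ndarray:
    """Reshape flat sequence rewards of shape (batch,) into (num_prompts, group_size).

    Hypothetical layout assumption: rows [0:group_size] belong to prompt 0,
    rows [group_size:2 * group_size] to prompt 1, and so on.
    """
    batch = rewards.shape[0]
    # Every prompt must contribute exactly group_size completions.
    if batch % group_size != 0:
        raise ValueError(
            f"batch size {batch} is not divisible by group_size {group_size}"
        )
    return rewards.reshape(batch // group_size, group_size)
```

For example, eight rewards with group_size=4 become a (2, 4) array, one row per prompt, while six rewards with group_size=4 are rejected.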

When a reference_model is provided, KL regularization is computed from the reference policy directly. Batch-level reference log probabilities are not part of the public GRPO surface.
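A reference policy is typically turned into a per-token KL penalty through an estimator of KL(pi_theta || pi_ref). The sketch below uses the non-negative estimator from the DeepSeekMath GRPO formulation, exp(r) - r - 1 with r = log pi_ref - log pi_theta. It is an assumption that artifex uses this particular estimator, and per_token_kl is a hypothetical helper.

```python
import jax.numpy as jnp


def per_token_kl(
    policy_log_probs: jnp.ndarray,
    ref_log_probs: jnp.ndarray,
    response_mask: jnp.ndarray,
) -> jnp.ndarray:
    """Per-token estimate of KL(policy || reference), zeroed outside the response.

    All inputs have shape (batch, seq_len). exp(r) - r - 1 is >= 0 for any r,
    and equals 0 exactly when the two policies agree on a token.
    """
    log_ratio = ref_log_probs - policy_log_probs
    kl = jnp.exp(log_ratio) - log_ratio - 1.0
    # Prompt and padding tokens (mask == 0) contribute nothing.
    return kl * response_mask
```

Because the estimate is computed token by token from the two policies' log probabilities, it only needs the reference model's outputs on the same sequences.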

Configuration

artifex.generative_models.training.rl.configs.GRPOConfig (dataclass)

GRPOConfig(num_generations: int = 4, clip_param: float = 0.2, beta: float = 0.01, entropy_coeff: float = 0.01)

Configuration for Group Relative Policy Optimization.

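GRPO is a critic-free RL algorithm, introduced in DeepSeekMath and later used to train DeepSeek-R1, that:

  • generates multiple completions per prompt (num_generations)
  • normalizes advantages within each prompt's group of completions
  • uses PPO-style clipping of the policy probability ratio
  • cuts memory use (roughly in half) by eliminating the value network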
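The group-relative advantage can be sketched as a per-group z-score of sequence rewards. This is a minimal illustration under stated assumptions: contiguous grouping, population standard deviation, and a small eps to avoid division by zero when every completion in a group receives the same reward. artifex's exact normalization is not documented here, and group_relative_advantages is a hypothetical name.

```python
import jax.numpy as jnp


def group_relative_advantages(
    rewards: jnp.ndarray, group_size: int, eps: float = 1e-6
) -> jnp.ndarray:
    """Normalize rewards of shape (batch,) within each group of group_size.

    Assumes completions of the same prompt are stored contiguously.
    """
    grouped = rewards.reshape(-1, group_size)           # (num_prompts, group_size)
    mean = grouped.mean(axis=1, keepdims=True)          # per-prompt baseline
    std = grouped.std(axis=1, keepdims=True)            # population std (ddof=0)
    advantages = (grouped - mean) / (std + eps)
    return advantages.reshape(-1)                       # back to (batch,)
```

The group mean acts as the baseline a critic would otherwise provide. In this sketch a group whose rewards are all equal gets zero advantage, so it contributes no policy-gradient signal.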

Attributes:

| Name            | Type  | Description                                                   |
|-----------------|-------|---------------------------------------------------------------|
| num_generations | int   | Number of completions to generate per prompt. Default 4.      |
| clip_param      | float | Clipping parameter for the surrogate objective. Default 0.2.  |
| beta            | float | KL penalty coefficient for regularization. Default 0.01.      |
| entropy_coeff   | float | Coefficient for the entropy bonus. Default 0.01.              |

num_generations (class attribute, instance attribute)

num_generations: int = 4

clip_param (class attribute, instance attribute)

clip_param: float = 0.2

beta (class attribute, instance attribute)

beta: float = 0.01

entropy_coeff (class attribute, instance attribute)

entropy_coeff: float = 0.01
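To show where clip_param, beta, and entropy_coeff enter, here is a generic per-token GRPO-style objective: a PPO clipped surrogate on the probability ratio, minus a KL penalty weighted by beta, plus an entropy bonus weighted by entropy_coeff. This is a sketch under assumptions (token-level masked averaging, sequence advantages broadcast over tokens, precomputed per-token KL and entropy). artifex's reduction and sign conventions may differ, and grpo_style_loss is hypothetical.

```python
import jax.numpy as jnp


def grpo_style_loss(
    log_probs: jnp.ndarray,       # (batch, seq_len) current-policy token log probs
    old_log_probs: jnp.ndarray,   # (batch, seq_len) log probs at rollout time
    advantages: jnp.ndarray,      # (batch,) group-relative sequence advantages
    kl: jnp.ndarray,              # (batch, seq_len) per-token KL estimate
    entropy: jnp.ndarray,         # (batch, seq_len) per-token policy entropy
    mask: jnp.ndarray,            # (batch, seq_len) 1.0 on response tokens
    clip_param: float = 0.2,
    beta: float = 0.01,
    entropy_coeff: float = 0.01,
) -> jnp.ndarray:
    """Scalar loss to minimize (hypothetical GRPO-style objective)."""
    ratio = jnp.exp(log_probs - old_log_probs)
    adv = advantages[:, None]  # broadcast each sequence's advantage over its tokens
    unclipped = ratio * adv
    clipped = jnp.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
    surrogate = jnp.minimum(unclipped, clipped)  # pessimistic PPO bound
    # Maximize surrogate and entropy, penalize KL; negate to get a loss.
    per_token = -(surrogate - beta * kl + entropy_coeff * entropy)
    # Masked mean over response tokens only.
    return (per_token * mask).sum() / jnp.maximum(mask.sum(), 1.0)
```

With clip_param=0.2, a token whose probability doubled under a positive advantage contributes a ratio of 1.2 rather than 2.0, which bounds how far one update can move the policy.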