Flow Trainer¤
Status: Supported runtime training surface
Module: artifex.generative_models.training.trainers.flow_trainer
Source: src/artifex/generative_models/training/trainers/flow_trainer.py
FlowTrainer implements the flow-matching runtime that Artifex actually ships:
linear Gaussian-noise-to-data interpolation plus configurable time sampling.
Quick Start¤
```python
from flax import nnx
import jax
import optax

from artifex.generative_models.training.trainers import (
    FlowTrainer,
    FlowTrainingConfig,
)

# `create_flow_model` and `batch` are placeholders for your own model factory
# and training batch; neither is defined on this page.
model = create_flow_model(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=nnx.Param)

# Sample interpolation times from a logit-normal distribution.
trainer = FlowTrainer(
    FlowTrainingConfig(
        time_sampling="logit_normal",
        logit_normal_loc=0.0,
        logit_normal_scale=1.0,
    )
)

key = jax.random.key(0)
loss, metrics = trainer.train_step(model, optimizer, batch, key)
```
Configuration¤
artifex.generative_models.training.trainers.flow_trainer.FlowTrainingConfig dataclass¤

```python
FlowTrainingConfig(
    time_sampling: Literal['uniform', 'logit_normal', 'u_shaped'] = 'uniform',
    logit_normal_loc: float = 0.0,
    logit_normal_scale: float = 1.0,
)
```

Configuration for flow matching training.
Attributes:
| Name | Type | Description |
|---|---|---|
| `time_sampling` | `Literal['uniform', 'logit_normal', 'u_shaped']` | How to sample time values during training: `"uniform"` (uniform sampling in [0, 1]), `"logit_normal"` (logit-normal, favors middle times), or `"u_shaped"` (U-shaped, favors interpolation endpoints). |
| `logit_normal_loc` | `float` | Location parameter for logit-normal sampling. |
| `logit_normal_scale` | `float` | Scale parameter for logit-normal sampling. |
Runtime-Active Fields¤
| Parameter | Default | Description |
|---|---|---|
| `time_sampling` | `"uniform"` | Time distribution used for interpolation samples |
| `logit_normal_loc` | `0.0` | Mean of the latent normal before the logistic transform |
| `logit_normal_scale` | `1.0` | Scale of the latent normal before the logistic transform |
Time Sampling Strategies¤
Uniform¤
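A minimal sketch of uniform time sampling in JAX, assuming one time value per batch element. The helper name `sample_uniform_times` is hypothetical and not part of the Artifex API.

```python
import jax
import jax.numpy as jnp


def sample_uniform_times(key: jax.Array, batch_size: int) -> jax.Array:
    """Hypothetical helper: draw t ~ U[0, 1) with shape (batch,)."""
    return jax.random.uniform(key, (batch_size,), minval=0.0, maxval=1.0)


t = sample_uniform_times(jax.random.key(0), 8)
```

Every time value is equally likely, so no part of the interpolation path is emphasized during training.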
Logit-Normal¤
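Logit-normal sampling draws a latent normal value and passes it through the logistic function, which concentrates samples around the middle of [0, 1] when the location is 0. The sketch below illustrates that idea; `sample_logit_normal_times` is a hypothetical helper, and its `loc` and `scale` arguments stand in for `logit_normal_loc` and `logit_normal_scale`.

```python
import jax
import jax.numpy as jnp


def sample_logit_normal_times(key, batch_size, loc=0.0, scale=1.0):
    """Hypothetical helper: t = sigmoid(loc + scale * z), z ~ N(0, 1)."""
    z = jax.random.normal(key, (batch_size,))
    return jax.nn.sigmoid(loc + scale * z)


t = sample_logit_normal_times(jax.random.key(0), 1024)
```

Raising `loc` shifts mass toward t = 1, lowering it shifts mass toward t = 0, and a larger `scale` spreads samples toward both ends.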
U-Shaped¤
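This page does not state the exact density FlowTrainer uses for `"u_shaped"`. As an illustration only, the sketch below uses a Beta(0.5, 0.5) distribution, one common U-shaped choice that places more mass near t = 0 and t = 1; the trainer's actual distribution may differ, and `sample_u_shaped_times` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def sample_u_shaped_times(key, batch_size):
    """Hypothetical helper: draw t ~ Beta(0.5, 0.5), an assumed U-shaped density."""
    return jax.random.beta(key, 0.5, 0.5, shape=(batch_size,))


t = sample_u_shaped_times(jax.random.key(0), 4096)
# Fraction of samples in the middle half of the interval.
middle_fraction = jnp.mean((t > 0.25) & (t < 0.75))
```

Under this assumed density, well under half of the samples land in the middle half of the interval, which is the "favors interpolation endpoints" behavior described in the configuration table.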
Objective¤
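The trainer uses the linear interpolation path

$$x_t = (1 - t)\,x_0 + t\,x_1, \qquad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{\text{data}},$$

with target velocity

$$u_t = x_1 - x_0,$$

and minimizes mean-squared error between the model prediction and $u_t$.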
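The objective can be sketched in plain JAX as follows. This is an illustrative sketch, not the FlowTrainer source: it assumes Gaussian noise `x0` at t = 0, data `x1` at t = 1, uniform time sampling, and a model callable as `model(x_t, t)`. The function name `flow_matching_loss` and the toy model in the usage lines are hypothetical.

```python
import jax
import jax.numpy as jnp


def flow_matching_loss(model, x1, key):
    """Hypothetical sketch of a linear-path flow-matching MSE loss.

    Assumes x1 is a batch of data with shape (batch, ...) and that
    model(x_t, t) returns a velocity with the same shape as x_t.
    """
    noise_key, time_key = jax.random.split(key)
    x0 = jax.random.normal(noise_key, x1.shape, dtype=x1.dtype)  # Gaussian noise
    t = jax.random.uniform(time_key, (x1.shape[0],))  # times, shape (batch,)

    # Reshape t so it broadcasts over the non-batch axes of the samples.
    t_b = t.reshape((-1,) + (1,) * (x1.ndim - 1))
    x_t = (1.0 - t_b) * x0 + t_b * x1  # linear interpolation path
    u_t = x1 - x0  # target velocity along that path

    v_pred = model(x_t, t)
    return jnp.mean((v_pred - u_t) ** 2)


# Usage with a toy model that always predicts zero velocity.
x1 = jnp.ones((4, 3))
loss = flow_matching_loss(lambda x_t, t: jnp.zeros_like(x_t), x1, jax.random.key(0))
```

Because the target velocity does not depend on t for a linear path, the time value only affects where along the path the model is evaluated.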
Shared Trainer Integration¤
FlowTrainer can also provide a step-aware objective for the shared Trainer:
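```python
from artifex.generative_models.training import Trainer
from artifex.generative_models.training.callbacks import CallbackList
from artifex.generative_models.training.trainers import (
    FlowTrainer,
    FlowTrainingConfig,
)

# `model` and `training_config` are assumed to be defined elsewhere.
flow_trainer = FlowTrainer(FlowTrainingConfig(time_sampling="logit_normal"))
trainer = Trainer(
    model=model,
    training_config=training_config,
    loss_fn=flow_trainer.create_loss_fn(),
    callbacks=CallbackList([]),
)
```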
Model Contract¤
The model is expected to accept `x_t` and `t` and return a velocity prediction, where `x_t` matches the sample shape and `t` is a `(batch,)` tensor of sampled times.
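A minimal stand-in that satisfies this shape contract, written as a plain callable rather than an `nnx.Module` to keep the sketch small. `ToyVelocityModel` is a hypothetical name, and the call form `model(x_t, t)` is an assumption based on the argument names above rather than a confirmed signature.

```python
import jax
import jax.numpy as jnp


class ToyVelocityModel:
    """Hypothetical stand-in: maps (x_t, t) to a velocity shaped like x_t."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def __call__(self, x_t: jax.Array, t: jax.Array) -> jax.Array:
        # t has shape (batch,); reshape so it broadcasts over the sample axes.
        t_b = t.reshape((-1,) + (1,) * (x_t.ndim - 1))
        return self.scale * (1.0 - t_b) * x_t


model = ToyVelocityModel()
x_t = jnp.ones((2, 8, 8, 3))  # e.g. a batch of two 8x8 RGB samples
t = jnp.array([0.25, 0.75])  # one time per batch element
velocity = model(x_t, t)
```

The key point is the broadcast: a `(batch,)` time vector has to be reshaped before it can be combined with samples that carry extra axes.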