Checkpointing Callbacks
Status: Supported runtime training surface
Module: artifex.generative_models.training.callbacks.checkpoint
Source: src/artifex/generative_models/training/callbacks/checkpoint.py
Overview
Model checkpointing callback that saves Orbax-managed checkpoints on the configured epoch cadence. Retention and best-checkpoint selection are handled by Orbax using the monitored metric.
Classes
CheckpointConfig
```python
from dataclasses import dataclass
from pathlib import Path
from typing import Literal


@dataclass(slots=True)
class CheckpointConfig:
    """Configuration for model checkpointing."""

    dirpath: str | Path = "checkpoints"
    monitor: str = "val_loss"
    mode: Literal["min", "max"] = "min"
    save_top_k: int = 3
    every_n_epochs: int = 1
```
Attributes:

| Attribute | Type | Default | Description |
|---|---|---|---|
| dirpath | str \| Path | "checkpoints" | Directory where checkpoints are saved |
| monitor | str | "val_loss" | Name of the metric to monitor |
| mode | Literal["min", "max"] | "min" | Whether a lower ("min") or higher ("max") metric value is better |
| save_top_k | int | 3 | Number of best checkpoints to keep (-1 = all, 0 = none) |
| every_n_epochs | int | 1 | Save a checkpoint every n epochs |
ModelCheckpoint
```python
class ModelCheckpoint(BaseCallback):
    """Save model checkpoints based on monitored metrics."""

    def __init__(self, config: CheckpointConfig): ...
```
Callback that writes a checkpoint through the Orbax checkpointing infrastructure on each epoch selected by every_n_epochs, and delegates best-step tracking and retention (automatic cleanup of checkpoints beyond save_top_k, ranked by the monitored metric) to Orbax.
Key Properties:

| Property | Type | Description |
|---|---|---|
| best_score | float \| None | Best metric value seen so far |
| best_checkpoint_step | int \| None | Step index of the best retained checkpoint |
| saved_checkpoint_steps | list[int] | Retained checkpoint steps managed by Orbax |
Usage
Basic Checkpointing
```python
from artifex.generative_models.training import Trainer
from artifex.generative_models.training.callbacks import (
    CallbackList,
    ModelCheckpoint,
    CheckpointConfig,
)

# `model`, `training_config`, `loss_fn`, `train_data`, and `val_data`
# are assumed to be defined elsewhere in your training script.

# Keep the best 3 checkpoints based on validation loss
checkpoint = ModelCheckpoint(CheckpointConfig(
    dirpath="./checkpoints",
    monitor="val_loss",
    mode="min",
    save_top_k=3,
))

trainer = Trainer(
    model=model,
    training_config=training_config,
    loss_fn=loss_fn,
    callbacks=CallbackList([checkpoint]),
)
trainer.train(train_data=train_data, num_epochs=10, batch_size=64, val_data=val_data)

# Inspect best-checkpoint metadata after training
print(f"Best checkpoint step: {checkpoint.best_checkpoint_step}")
print(f"Best score: {checkpoint.best_score}")
```
Monitor Accuracy (Higher is Better)
```python
# The validation loop must report a metric named "val_accuracy"
checkpoint = ModelCheckpoint(CheckpointConfig(
    dirpath="./checkpoints",
    monitor="val_accuracy",
    mode="max",    # Higher accuracy is better
    save_top_k=1,  # Keep only the best checkpoint
))
```
Save All Checkpoints
```python
checkpoint = ModelCheckpoint(CheckpointConfig(
    dirpath="./checkpoints",
    save_top_k=-1,     # Keep all checkpoints
    every_n_epochs=5,  # Save every 5 epochs
))
```
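With every_n_epochs=5, only some epochs produce a checkpoint. The sketch below lists which ones. It assumes epochs are counted from 1 and that a save happens at the end of every epoch whose number is a multiple of every_n_epochs; whether the callback counts epochs from 0 or 1 is not stated here, so treat the exact epoch numbers as an assumption. The helper saved_epochs is hypothetical.

```python
def saved_epochs(num_epochs: int, every_n_epochs: int) -> list[int]:
    """Epochs (1-based) on which a checkpoint would be written.

    Hypothetical helper; assumes a save at the end of epoch e whenever
    e is a multiple of every_n_epochs.
    """
    if every_n_epochs < 1:
        raise ValueError("every_n_epochs must be at least 1")
    return [epoch for epoch in range(1, num_epochs + 1) if epoch % every_n_epochs == 0]


print(saved_epochs(num_epochs=20, every_n_epochs=5))  # → [5, 10, 15, 20]
print(saved_epochs(num_epochs=12, every_n_epochs=5))  # → [5, 10]
```

Under this assumption, save_top_k=-1 keeps every one of these checkpoints, and a run shorter than every_n_epochs epochs writes none.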
Combined with Other Callbacks
```python
from artifex.generative_models.training import Trainer
from artifex.generative_models.training.callbacks import (
    CallbackList,
    ModelCheckpoint,
    CheckpointConfig,
    EarlyStopping,
    EarlyStoppingConfig,
    ProgressBarCallback,
    ProgressBarConfig,
)

# `model`, `training_config`, `loss_fn`, `train_data`, and `val_data`
# are assumed to be defined elsewhere in your training script.
callbacks = CallbackList([
    ModelCheckpoint(CheckpointConfig(
        dirpath="./checkpoints",
        monitor="val_loss",
        save_top_k=3,
    )),
    EarlyStopping(EarlyStoppingConfig(
        monitor="val_loss",
        patience=10,
    )),
    ProgressBarCallback(ProgressBarConfig()),
])

trainer = Trainer(
    model=model,
    training_config=training_config,
    loss_fn=loss_fn,
    callbacks=callbacks,
)
trainer.train(train_data=train_data, num_epochs=10, batch_size=64, val_data=val_data)
```
How It Works
- Metric Monitoring: Tracks the specified metric (monitor) at the end of each epoch
- Orbax Save: On epochs selected by every_n_epochs, saves the model state through the shared Orbax checkpoint utilities
- Retention Policy: Orbax keeps the best save_top_k checkpoints, ranked by the monitored metric according to mode
- Best Tracking: Orbax tracks the best retained step, which is exposed via best_checkpoint_step
Integration with Orbax
ModelCheckpoint builds on the existing Orbax-based checkpointing infrastructure, so the checkpoints it writes can be restored with the same utilities:
```python
from artifex.generative_models.core.checkpointing import (
    save_checkpoint,
    load_checkpoint,
    setup_checkpoint_manager,
)

# Checkpoints are stored under step-numbered Orbax directories
checkpoint_manager, _ = setup_checkpoint_manager("./checkpoints")
# Restore the checkpoint saved at step 10 into `model`
model = load_checkpoint(checkpoint_manager, model, step=10)
```
See the Checkpointing Guide for advanced checkpointing features, including saving optimizer state and recovering from corrupted checkpoints.
Module Statistics
- Classes: 2 (CheckpointConfig, ModelCheckpoint)
- Dependencies: Orbax checkpointing infrastructure
- Slots: Uses __slots__ for memory efficiency