# Training Profiling
Profiling is configured at the shared `Trainer` layer through `CallbackList`.
Use the trace profiler for short focused windows and the memory profiler for
longer-running visibility into device usage.
## JAX Trace Profiling
```python
from artifex.generative_models.training import Trainer
from artifex.generative_models.training.callbacks import (
    CallbackList,
    JAXProfiler,
    ProfilingConfig,
)

# Trace only steps 10 through 20 and write the trace files under logs/profiles.
callbacks = CallbackList(
    [
        JAXProfiler(
            ProfilingConfig(
                log_dir="logs/profiles",
                start_step=10,
                end_step=20,
            )
        )
    ]
)

# model, training_config, loss_fn, and train_data are assumed to be defined elsewhere.
trainer = Trainer(
    model=model,
    training_config=training_config,
    loss_fn=loss_fn,
    callbacks=callbacks,
)
trainer.train(train_data=train_data, num_epochs=10, batch_size=64)
```
`ProfilingConfig` exposes only these parameters:
| Parameter | Description |
|---|---|
| `log_dir` | Output directory for trace files |
| `start_step` | First profiled training step |
| `end_step` | Final profiled training step |
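To make the step window concrete, here is a minimal sketch of how a start/stop window like this could behave. It is not the `JAXProfiler` implementation: `StepWindowTracer` and its `on_step_begin`/`on_step_end` hooks are hypothetical names, and it assumes `end_step` is inclusive, as the table wording suggests. The start and stop functions are injected so the logic can be exercised without a device; in a real JAX run, `jax.profiler.start_trace` and `jax.profiler.stop_trace` could plausibly be passed in their place.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class StepWindowTracer:
    """Hypothetical helper: trace from start_step through end_step (inclusive)."""

    log_dir: str
    start_step: int
    end_step: int
    start_fn: Callable[[str], None]  # e.g. jax.profiler.start_trace (assumption)
    stop_fn: Callable[[], None]      # e.g. jax.profiler.stop_trace (assumption)
    active: bool = False

    def on_step_begin(self, step: int) -> None:
        # Open the trace right before the first profiled step.
        if step == self.start_step and not self.active:
            self.start_fn(self.log_dir)
            self.active = True

    def on_step_end(self, step: int) -> None:
        # Close the trace right after the final profiled step.
        if step == self.end_step and self.active:
            self.stop_fn()
            self.active = False


# Dry run with recording stubs instead of a real profiler.
events = []
tracer = StepWindowTracer(
    log_dir="logs/profiles",
    start_step=10,
    end_step=20,
    start_fn=lambda log_dir: events.append(("start", log_dir)),
    stop_fn=lambda: events.append(("stop",)),
)

profiled_steps = []
for step in range(30):
    tracer.on_step_begin(step)
    if tracer.active:
        profiled_steps.append(step)
    tracer.on_step_end(step)

print(profiled_steps[0], profiled_steps[-1], events)
```

With the values above, the stub records exactly one start and one stop, and the traced steps are 10 through 20.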
## Memory Profiling
```python
from artifex.generative_models.training.callbacks import (
    CallbackList,
    MemoryProfiler,
    MemoryProfileConfig,
)

# Profile every 100 steps and write the results under logs/memory.
callbacks = CallbackList(
    [
        MemoryProfiler(
            MemoryProfileConfig(
                log_dir="logs/memory",
                profile_every_n_steps=100,
                log_device_memory=True,
            )
        )
    ]
)
```
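As a rough illustration of what periodic memory sampling involves, the sketch below appends one JSON record every N steps. It is not the `MemoryProfiler` implementation: `MemorySampler`, the `memory.jsonl` file name, and the `read_stats` callable are all hypothetical. In a JAX program, `read_stats` might wrap `jax.local_devices()[0].memory_stats()`, which can return `None` on backends that do not report memory, so the sketch skips empty readings.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Optional


class MemorySampler:
    """Hypothetical sketch: write one JSON line every `every_n_steps` steps."""

    def __init__(
        self,
        log_dir: str,
        every_n_steps: int,
        read_stats: Callable[[], Optional[dict]],
    ) -> None:
        self.path = Path(log_dir) / "memory.jsonl"  # illustrative file name
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.every_n_steps = every_n_steps
        self.read_stats = read_stats

    def on_step_end(self, step: int) -> bool:
        """Return True if a record was written for this step."""
        if step % self.every_n_steps != 0:
            return False
        stats = self.read_stats()
        if stats is None:  # backend does not expose memory statistics
            return False
        with self.path.open("a") as f:
            f.write(json.dumps({"step": step, **stats}) + "\n")
        return True


def fake_stats() -> dict:
    # Stand-in for a real device query; the key name is illustrative.
    return {"bytes_in_use": 1024}


with tempfile.TemporaryDirectory() as tmp:
    sampler = MemorySampler(tmp, every_n_steps=100, read_stats=fake_stats)
    for step in range(250):
        sampler.on_step_end(step)
    records = [json.loads(line) for line in sampler.path.read_text().splitlines()]

print([r["step"] for r in records])
```

Over 250 steps with an interval of 100, this writes three records (steps 0, 100, and 200), which is why sampling stays cheap on long jobs.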
## Recommendations
- Put `start_step` after initial compilation warmup.
- Keep the trace window small.
- Use memory snapshots for long jobs where full tracing would be too expensive.
- Combine profiling with logging callbacks only when you need both signals in the same run.
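One way to pick a `start_step` that falls after compilation warmup is to time a few steps in an unprofiled run and look for where step durations settle. The helper below is a hypothetical sketch, not part of the library: `first_steady_step` and its `factor` threshold are assumptions, and the timings are made-up illustrative values.

```python
from statistics import median


def first_steady_step(step_seconds: list[float], factor: float = 1.5) -> int:
    """Return the index of the first step after the last unusually slow one.

    A step counts as slow when it takes longer than `factor` times the
    median step duration, which is typical of steps that trigger compilation.
    """
    threshold = factor * median(step_seconds)
    last_slow = -1
    for i, seconds in enumerate(step_seconds):
        if seconds > threshold:
            last_slow = i
    return last_slow + 1


# Illustrative timings: the first two steps include compilation overhead.
timings = [4.2, 1.9, 0.11, 0.10, 0.10, 0.11, 0.10]
suggested_start = first_steady_step(timings)
print(suggested_start)  # → 2
```

Adding a few steps of margin on top of the suggested value keeps the trace window clear of any late recompilation.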