
Generative Models

The generative_models package is Artifex's main module. It contains all generative model implementations along with the core infrastructure, modalities, extensions, and training systems.

Overview

  • Model Architectures: VAE, GAN, Diffusion, Flow, EBM, and Autoregressive models
  • Core Infrastructure: configuration, losses, distributions, sampling, and metrics
  • Modalities: image, text, audio, protein, and multimodal support
  • Extensions: domain-specific extensions for proteins, geometric data, and more

Quick Start

Creating a Model

from artifex.generative_models.factory import create_model
from artifex.generative_models.core.configuration import DecoderConfig, EncoderConfig, VAEConfig
from flax import nnx

# Create VAE model
encoder = EncoderConfig(
    name="vae_encoder",
    input_shape=(28, 28, 1),
    latent_dim=64,
    hidden_dims=(256, 128),
)
decoder = DecoderConfig(
    name="vae_decoder",
    output_shape=(28, 28, 1),
    latent_dim=64,
    hidden_dims=(128, 256),
)
config = VAEConfig(
    name="my_vae",
    encoder=encoder,
    decoder=decoder,
)

rngs = nnx.Rngs(0)
model = create_model(config, rngs=rngs)

Training a Model

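from artifex.generative_models.training import VAETrainer

# `training_config` and `train_data` stand for your own training
# configuration and dataset; they are not defined on this page.
trainer = VAETrainer(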
    model=model,
    config=training_config,
    train_dataset=train_data,
)

trainer.train()
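The trainer API hides the optimization loop. As a rough mental model only, the plain-Python sketch below shows the kind of loop a trainer's train() method runs: compute the loss and its gradients over the dataset, apply an update scaled by the configured learning rate, and record the loss each epoch. ToyTrainer, ToyTrainingConfig, and the linear-regression objective are hypothetical stand-ins and say nothing about how VAETrainer is actually implemented.

```python
from dataclasses import dataclass


@dataclass
class ToyTrainingConfig:
    # Hypothetical fields, not Artifex's training config schema
    learning_rate: float = 0.1
    num_epochs: int = 50


class ToyTrainer:
    """Minimal trainer: full-batch gradient descent on y ~ w * x + b."""

    def __init__(self, params, config, train_dataset):
        self.params = dict(params)  # {"w": float, "b": float}
        self.config = config
        self.train_dataset = train_dataset  # list of (x, y) pairs

    def loss_and_grads(self):
        # Mean squared error with hand-derived gradients
        n = len(self.train_dataset)
        w, b = self.params["w"], self.params["b"]
        loss = grad_w = grad_b = 0.0
        for x, y in self.train_dataset:
            err = w * x + b - y
            loss += err * err / n
            grad_w += 2.0 * err * x / n
            grad_b += 2.0 * err / n
        return loss, {"w": grad_w, "b": grad_b}

    def train(self):
        history = []
        for _ in range(self.config.num_epochs):
            loss, grads = self.loss_and_grads()
            for key in self.params:
                self.params[key] -= self.config.learning_rate * grads[key]
            history.append(loss)  # loss before this epoch's update
        return history


data = [(x, 2.0 * x + 1.0) for x in (0.0, 0.5, 1.0, 1.5, 2.0)]
trainer = ToyTrainer({"w": 0.0, "b": 0.0}, ToyTrainingConfig(), data)
history = trainer.train()
```

With a small enough learning rate the recorded loss decreases every epoch, which is the first thing worth checking in any real training run.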

Generating Samples

Artifex has no top-level artifex.inference namespace and no single shared inference pipeline in the current runtime; each model family owns its own generation methods.

# `model` above is a VAE built from `VAEConfig`.
samples = model.sample(num_samples=16)

See Inference Overview for the retained loading and generation workflow, and Inference Reference for the one shared production-optimization surface.
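Because generation is family-owned, what sample does depends on the model family. For a VAE, sampling typically means drawing latent codes from the prior N(0, I) and passing them through the decoder. The stdlib sketch below uses a fixed linear map as a toy decoder; sample_vae and toy_decode are hypothetical names, not Artifex functions.

```python
import random


def sample_vae(decode, latent_dim, num_samples, seed=0):
    """Draw z ~ N(0, I) from the prior and map each draw through the decoder."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        samples.append(decode(z))
    return samples


# Toy "decoder": a fixed linear map from a 2-D latent to a 3-D output
W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]


def toy_decode(z):
    return [sum(w_i * z_i for w_i, z_i in zip(row, z)) for row in W]


samples = sample_vae(toy_decode, latent_dim=2, num_samples=16)
```

A trained VAE replaces toy_decode with its learned decoder network; the sampling logic itself stays this simple.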

Module Structure

The generative_models package is organized into the following submodules:

Models

Implementations of all generative model architectures.

| Model Type | Description |
|---|---|
| VAE | Variational Autoencoders |
| GAN | Generative Adversarial Networks |
| Diffusion | Denoising Diffusion Models |
| Flow | Normalizing Flow Models |
| EBM | Energy-Based Models |
| Autoregressive | Autoregressive Models |

Models Reference

Core

Foundational abstractions and utilities.

| Component | Description |
|---|---|
| Configuration | Unified configuration system |
| Losses | Loss functions |
| Distributions | Probability distributions |
| Sampling | Sampling methods |
| Metrics | Evaluation metrics |
| Layers | Neural network layers |

Core Reference
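As an example of the building blocks that the Losses and Distributions components cover, a VAE objective usually includes the closed-form KL divergence between a diagonal Gaussian posterior N(mu, diag(sigma^2)) and a standard normal prior: KL = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2). The plain-Python sketch below implements that standard formula using log-variances; the function name is hypothetical, not an Artifex API.

```python
import math


def gaussian_kl_to_standard_normal(mean, log_var):
    """KL(N(mean, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return 0.5 * sum(
        math.exp(lv) + m * m - 1.0 - lv  # per-dimension KL term (times 2)
        for m, lv in zip(mean, log_var)
    )


print(gaussian_kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0
print(gaussian_kl_to_standard_normal([1.0], [0.0]))  # 0.5
```

Parameterizing by log-variance keeps sigma^2 = exp(log_var) positive without an explicit constraint, which is why encoders commonly output log_var rather than sigma.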

Modalities

Registry-backed modalities plus family-scoped owner pages.

Use Modalities Overview for the retained registry-backed surface and the owner pages below only for family-scoped helper details.

| Reference | Description |
|---|---|
| Registry Owner | Shared registry-backed surface for the image, molecular, and protein modalities |
| Timeseries Base | Timeseries helper owner page |
| Timeseries Datasets | Synthetic timeseries data factories |
| Protein Modality | Protein-specific adapter and extension lookup |
| Protein Losses | Protein structure loss builders |

Modalities Overview

Training

Training infrastructure and utilities.

| Component | Description |
|---|---|
| VAE Trainer | VAE model trainer |
| GAN Trainer | GAN model trainer |
| Diffusion Trainer | Diffusion model trainer |
| Checkpoint | Training checkpointing |
| Data Parallel | Multi-device training |
| AdamW | Optimization algorithms |
| Scheduler | Learning rate schedules |

Training Reference
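A common learning rate schedule shape is linear warmup followed by cosine decay. Libraries such as Optax provide this ready-made (optax.warmup_cosine_decay_schedule); the stdlib sketch below only makes the arithmetic explicit. Whether Artifex's Scheduler component uses this exact shape is not stated on this page, and warmup_cosine_lr is a hypothetical name.

```python
import math


def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        # Ramp up: reaches base_lr at step warmup_steps - 1
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

For example, with base_lr=1e-3, warmup_steps=10, and total_steps=100, the rate climbs to 1e-3 over the first 10 steps, passes about 5e-4 at step 55, and reaches 0 at step 100.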

Extensions

Domain-specific extensions with one curated overview plus live owner pages.

Use Extensions Overview for the curated scope and the owner pages below for live module details.

| Reference | Description |
|---|---|
| Base Extensions | Shared extension hierarchy and base contracts |
| Registry Owner | Registry enum, discovery helpers, and factory surface |
| Protein Constraints | Protein constraint owners and measurement helpers |
| NLP Embeddings | RoPE, sinusoidal, and text embedding owners |
| Audio Analysis | Temporal audio-analysis owner page |

Extensions Overview
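The NLP Embeddings owners include sinusoidal embeddings. For reference, the standard Transformer sinusoidal encoding assigns position p the values sin(p / base^(2i/d)) and cos(p / base^(2i/d)) for i = 0, ..., d/2 - 1, with base = 10000. The sketch below interleaves sine and cosine pairs; some implementations concatenate all sines before all cosines instead, and this page does not say which layout Artifex uses. sinusoidal_embedding is a hypothetical name.

```python
import math


def sinusoidal_embedding(position, dim, base=10000.0):
    """Transformer-style sinusoidal position embedding (dim must be even)."""
    if dim % 2:
        raise ValueError("dim must be even")
    emb = []
    for i in range(dim // 2):
        # Frequencies decrease geometrically across dimension pairs
        freq = 1.0 / (base ** (2 * i / dim))
        emb.append(math.sin(position * freq))
        emb.append(math.cos(position * freq))
    return emb


print(sinusoidal_embedding(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```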

Factory

Model creation and registration.

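from flax import nnx

from artifex.generative_models.factory import create_model, create_model_with_extensions

# `config` is a family-specific typed config such as VAEConfig, DDPMConfig,
# WGANConfig, or PointCloudConfig. `extension_configs` holds the extension
# settings to attach and is assumed to be defined elsewhere.
rngs = nnx.Rngs(0)

# Create a model from a dataclass config
model = create_model(config, rngs=rngs)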

# Create a model with extensions
model, extensions = create_model_with_extensions(
    config,
    extensions_config=extension_configs,
    rngs=rngs,
)

Factory Reference
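The architecture tree lists a builder registry (factory/registry.py) alongside per-family builders. A common way to implement config-driven creation like create_model is to map each config class to a builder function and dispatch on type(config). The sketch below shows that pattern with hypothetical names (register_builder, create_toy_model, VAESketchConfig, FlowSketchConfig); it is not Artifex's ModelFactory, and the "models" are plain dicts.

```python
from dataclasses import dataclass
from typing import Callable, Dict


# Hypothetical config types standing in for VAEConfig, DDPMConfig, ...
@dataclass(frozen=True)
class VAESketchConfig:
    name: str
    latent_dim: int


@dataclass(frozen=True)
class FlowSketchConfig:
    name: str
    num_layers: int


_BUILDERS: Dict[type, Callable[[object], dict]] = {}


def register_builder(config_type):
    """Decorator that maps a config class to the function that builds its model."""
    def decorator(fn):
        _BUILDERS[config_type] = fn
        return fn
    return decorator


@register_builder(VAESketchConfig)
def build_vae(config):
    return {"family": "vae", "name": config.name, "latent_dim": config.latent_dim}


@register_builder(FlowSketchConfig)
def build_flow(config):
    return {"family": "flow", "name": config.name, "num_layers": config.num_layers}


def create_toy_model(config):
    """Dispatch on the config's exact type to the registered builder."""
    try:
        builder = _BUILDERS[type(config)]
    except KeyError:
        raise TypeError(f"No builder registered for {type(config).__name__}") from None
    return builder(config)


model = create_toy_model(VAESketchConfig(name="my_vae", latent_dim=64))
print(model["family"])  # vae
```

Keying on the config class means adding a new family only requires a new config type and one decorated builder, with no changes to the dispatch function.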

Architecture

generative_models/
├── core/                 # Core infrastructure
│   ├── configuration/    # Configuration system
│   ├── losses/           # Loss functions
│   ├── distributions/    # Probability distributions
│   ├── sampling/         # Sampling methods
│   ├── metrics/          # Evaluation metrics
│   └── layers/           # Neural network layers
├── models/               # Model implementations
│   ├── vae/              # VAE variants
│   ├── gan/              # GAN variants
│   ├── diffusion/        # Diffusion models
│   ├── flow/             # Flow models
│   ├── ebm/              # Energy-based models
│   └── autoregressive/   # Autoregressive models
├── modalities/           # Data modality support
│   ├── image/            # Image modality
│   ├── text/             # Text modality
│   ├── audio/            # Audio modality
│   ├── protein/          # Protein modality
│   └── multi_modal/      # Multimodal support
├── training/             # Training infrastructure
│   ├── trainers/         # Model trainers
│   ├── callbacks/        # Training callbacks
│   ├── distributed/      # Distributed training
│   └── optimizers/       # Optimizers and schedulers
├── extensions/           # Domain extensions
│   ├── protein/          # Protein modeling
│   └── geometric/        # Geometric deep learning
├── factory/              # Model creation
│   ├── core.py           # ModelFactory implementation
│   ├── registry.py       # Builder registry
│   └── builders/         # Model-family builders
└── zoo/                  # Retired preset compatibility boundary

Design Principles

1. Protocol-Based Interfaces

All components use Python Protocols for type-safe interfaces:

import jax

from artifex.generative_models.core.protocols import GenerativeModel

class MyModel(GenerativeModel):
    def generate(self, num_samples: int, **kwargs) -> jax.Array:
        ...

    def loss_fn(self, batch: jax.Array, **kwargs) -> jax.Array:
        ...
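Python Protocols are structural: a class satisfies one by having the right methods, with no inheritance required (explicit subclassing, as in the example above, is also allowed). The stdlib sketch below uses a hypothetical SupportsGenerate protocol, not Artifex's GenerativeModel. Note that isinstance checks on a runtime_checkable protocol only verify that the methods exist, not their signatures.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsGenerate(Protocol):
    """Hypothetical protocol: anything with a `generate` method conforms."""

    def generate(self, num_samples: int) -> list: ...


class ConstantModel:
    # Does not inherit from SupportsGenerate; it conforms structurally
    def __init__(self, value: float) -> None:
        self.value = value

    def generate(self, num_samples: int) -> list:
        return [self.value] * num_samples


class NotAModel:
    pass


print(isinstance(ConstantModel(1.0), SupportsGenerate))  # True
print(isinstance(NotAModel(), SupportsGenerate))  # False
```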

2. Unified Configuration

All models use the unified configuration system:

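from artifex.generative_models.core.configuration import DecoderConfig, EncoderConfig, VAEConfig

# Typed, validated configuration: as in the Quick Start, `latent_dim`
# is set on the encoder and decoder configs rather than on VAEConfig.
config = VAEConfig(
    name="my_vae",
    encoder=EncoderConfig(name="vae_encoder", input_shape=(28, 28, 1), latent_dim=64, hidden_dims=(256, 128)),
    decoder=DecoderConfig(name="vae_decoder", output_shape=(28, 28, 1), latent_dim=64, hidden_dims=(128, 256)),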
)
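"Type-safe, validated" configuration can be illustrated with frozen dataclasses that check their fields in __post_init__, including cross-field checks such as requiring the encoder and decoder to agree on the latent size. The Factory example calls Artifex's configs dataclass configs, but the specific classes and checks below are hypothetical, not Artifex's implementation.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class ToyEncoderConfig:
    latent_dim: int
    hidden_dims: tuple = (256, 128)

    def __post_init__(self):
        # Field-level check
        if self.latent_dim <= 0:
            raise ValueError("latent_dim must be positive")


@dataclass(frozen=True)
class ToyDecoderConfig:
    latent_dim: int
    hidden_dims: tuple = (128, 256)


@dataclass(frozen=True)
class ToyVAEConfig:
    name: str
    encoder: ToyEncoderConfig
    decoder: ToyDecoderConfig

    def __post_init__(self):
        # Cross-field check: the decoder must consume what the encoder produces
        if self.encoder.latent_dim != self.decoder.latent_dim:
            raise ValueError("encoder and decoder latent_dim must match")


config = ToyVAEConfig("my_vae", ToyEncoderConfig(latent_dim=64), ToyDecoderConfig(latent_dim=64))

try:
    config.name = "renamed"  # frozen: configs are immutable once built
except FrozenInstanceError:
    print("config is immutable")
```

Failing at construction time means a mismatched config is rejected before any model parameters are allocated.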

3. Modality-Agnostic Design

Models work with any data modality through adapters:

from artifex.generative_models.modalities import get_modality

# Get modality handler
image_modality = get_modality("image", rngs=rngs)

# Adapt model for modality
adapted_model = image_modality.get_adapter("vae").adapt(model)
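The idea behind an adapter is that the model keeps one internal data format while the adapter converts modality-specific inputs and outputs around it. The stdlib sketch below wraps a toy model that only handles flat vectors so it accepts and returns 2-D images; FlatVectorModel, ImageAdapter, and the reconstruct method are hypothetical and do not reflect the interface of Artifex's get_adapter(...).adapt(...).

```python
class FlatVectorModel:
    """Toy model that only understands flat lists of floats."""

    def reconstruct(self, x):
        return [0.5 * v for v in x]


class ImageAdapter:
    """Hypothetical adapter: flattens HxW images in, reshapes outputs back."""

    def __init__(self, model, height, width):
        self.model = model
        self.height = height
        self.width = width

    def reconstruct(self, image):
        # Row-major flatten, run the wrapped model, then restore the 2-D shape
        flat = [pixel for row in image for pixel in row]
        out = self.model.reconstruct(flat)
        return [out[r * self.width:(r + 1) * self.width] for r in range(self.height)]


adapted = ImageAdapter(FlatVectorModel(), height=2, width=3)
print(adapted.reconstruct([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]]))
# [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```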

4. Hardware-Aware

All components are hardware-aware with automatic device management:

from artifex.generative_models.core import DeviceManager

device_manager = DeviceManager()
device = device_manager.get_device()  # Auto-selects GPU/CPU
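In JAX, the available backends can be queried with jax.devices() and jax.default_backend(); the selection policy on top of that can be as simple as a preference order with a CPU fallback. The stdlib sketch below takes the available platform names as input so it runs without JAX. select_platform and the GPU-then-TPU-then-CPU order are assumptions for illustration, not DeviceManager's documented behavior.

```python
def select_platform(available, preference=("gpu", "tpu", "cpu")):
    """Return the first preferred platform that is available, else raise."""
    available = {p.lower() for p in available}
    for platform in preference:
        if platform in available:
            return platform
    raise RuntimeError(f"none of {preference} available in {sorted(available)}")


print(select_platform(["cpu"]))  # cpu
print(select_platform(["CPU", "GPU"]))  # gpu
```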