
Model Factory

The factory module provides a centralized, type-safe interface for creating generative models in Artifex. It uses dataclass-based configurations to infer the model type automatically, so you do not need string-based model class specifications.

Overview

  • Unified Interface: a single create_model() function for all model types
  • Type-Safe: dataclass configs with automatic validation
  • Extensible: register custom builders for new model types
  • Modality Support: optional modality adaptation for domain-specific models

Quick Start

Basic Model Creation

from artifex.generative_models.factory import create_model
from artifex.generative_models.core.configuration import (
    VAEConfig,
    EncoderConfig,
    DecoderConfig,
)
from flax import nnx

# Create configuration
encoder = EncoderConfig(
    name="encoder",
    input_shape=(28, 28, 1),
    latent_dim=32,
    hidden_dims=(256, 128),
    activation="relu",
)
decoder = DecoderConfig(
    name="decoder",
    output_shape=(28, 28, 1),
    latent_dim=32,
    hidden_dims=(128, 256),
    activation="relu",
)
config = VAEConfig(
    name="my_vae",
    encoder=encoder,
    decoder=decoder,
    kl_weight=1.0,
)

# Create model - type is inferred from config
rngs = nnx.Rngs(params=42, dropout=43, sample=44)
model = create_model(config, rngs=rngs)

Model Type Inference

The factory automatically infers the model type from the configuration class:

| Config Class | Model Type | Created Model |
|---|---|---|
| VAEConfig, BetaVAEConfig, ConditionalVAEConfig, VQVAEConfig | vae | VAE variants |
| DCGANConfig, WGANConfig, LSGANConfig, ConditionalGANConfig, CycleGANConfig | gan | Concrete GAN variants |
| DiffusionConfig, DDPMConfig, ScoreDiffusionConfig | diffusion | Diffusion models |
| EBMConfig, DeepEBMConfig | ebm | Energy-based models |
| FlowConfig | flow | Normalizing flows |
| AutoregressiveConfig, TransformerConfig, PixelCNNConfig, WaveNetConfig | autoregressive | Autoregressive models |
| GeometricConfig, PointCloudConfig, MeshConfig, VoxelConfig, GraphConfig | geometric | Geometric models |
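The sketch below shows one way this kind of lookup can work. It is not Artifex's implementation. The config classes are hypothetical stand-ins with no real fields, and MODEL_TYPES and infer_model_type are illustrative names.

The dictionary is keyed on the exact config class. That choice produces the behavior implied by the "Concrete GAN variants" row: registered subclasses are accepted, while an unregistered base class is rejected.

```python
from dataclasses import dataclass, is_dataclass

# Hypothetical stand-ins for Artifex config classes (illustrative only).
@dataclass
class VAEConfig:
    name: str

@dataclass
class BetaVAEConfig(VAEConfig):
    beta: float = 4.0

@dataclass
class GANConfig:
    name: str

@dataclass
class DCGANConfig(GANConfig):
    pass

# Exact config class -> model family, mirroring a few rows of the table.
MODEL_TYPES: dict[type, str] = {
    VAEConfig: "vae",
    BetaVAEConfig: "vae",
    DCGANConfig: "gan",
}

def infer_model_type(config: object) -> str:
    """Return the model family for a dataclass config instance."""
    # Dicts, other non-dataclass objects, and dataclass *classes* are rejected.
    if not is_dataclass(config) or isinstance(config, type):
        raise TypeError(f"Expected dataclass config, got {type(config).__name__}.")
    # Exact-class lookup: the base GANConfig is not a key, so it is rejected
    # even though its subclass DCGANConfig is accepted.
    try:
        return MODEL_TYPES[type(config)]
    except KeyError:
        raise TypeError(f"Unknown config type: {type(config).__name__}") from None
```

With this table, infer_model_type(BetaVAEConfig(name="b")) returns "vae". Passing a GANConfig instance or a plain dict raises TypeError instead.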

API Reference

create_model

The main function for model creation.

def create_model(
    config: DataclassConfig,
    *,
    modality: str | None = None,
    rngs: nnx.Rngs,
    **kwargs,
) -> Any:
    """Create a model from configuration.

    Args:
        config: Dataclass model configuration (DDPMConfig, VAEConfig, etc.)
        modality: Optional modality for adaptation ('image', 'molecular', or 'protein')
        rngs: Random number generators
        **kwargs: Additional arguments passed to the builder

    Returns:
        Created model instance

    Raises:
        TypeError: If config is not a supported dataclass config
        ValueError: If builder not found for model type
    """

Example:

from flax import nnx
from artifex.generative_models.factory import create_model
from artifex.generative_models.core.configuration import (
    DDPMConfig,
    NoiseScheduleConfig,
    UNetBackboneConfig,
)

# Create diffusion model config
backbone = UNetBackboneConfig(
    name="unet",
    in_channels=3,
    out_channels=3,
    base_channels=64,
    channel_mults=(1, 2, 4),
)
noise_schedule = NoiseScheduleConfig(
    name="schedule",
    schedule_type="linear",
    num_timesteps=1000,
    beta_start=1e-4,
    beta_end=2e-2,
)
config = DDPMConfig(
    name="ddpm",
    input_shape=(3, 32, 32),
    backbone=backbone,
    noise_schedule=noise_schedule,
)

# Create model (rngs is required and keyword-only; same streams as the Quick Start)
rngs = nnx.Rngs(params=42, dropout=43, sample=44)
model = create_model(config, rngs=rngs)

create_model_with_extensions

Create a model together with its extensions. The extensions are returned alongside the model as a dictionary keyed by extension name.

def create_model_with_extensions(
    config: DataclassConfig,
    *,
    extensions_config: dict[str, ExtensionConfig] | None = None,
    modality: str | None = None,
    rngs: nnx.Rngs,
    **kwargs,
) -> tuple[Any, dict[str, ModelExtension]]:
    """Create a model and its extensions from configuration.

    Returns:
        Tuple of (model, extensions_dict)
    """

Example:

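from artifex.generative_models.factory import create_model_with_extensions

# Assumes config and rngs from the examples above. augmentation_config and
# reg_config are ExtensionConfig instances defined elsewhere.
# Create model with extensions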
model, extensions = create_model_with_extensions(
    config,
    extensions_config={
        "augmentation": augmentation_config,
        "regularization": reg_config,
    },
    rngs=rngs,
)
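The return shape is a model plus a dictionary of extensions keyed by the same names used in extensions_config. The sketch below illustrates that shape only; it is not Artifex's implementation. ExtensionConfig and ModelExtension here are hypothetical stand-ins for the real classes, and create_model_with_extensions_sketch takes an explicit build_model callable in place of the real factory dispatch.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical stand-ins: Artifex's real ExtensionConfig and ModelExtension differ.
@dataclass
class ExtensionConfig:
    kind: str
    weight: float = 1.0

@dataclass
class ModelExtension:
    name: str
    config: ExtensionConfig
    model: Any  # the model this extension is attached to

def create_model_with_extensions_sketch(
    config: Any,
    build_model: Callable[[Any], Any],
    *,
    extensions_config: dict[str, ExtensionConfig] | None = None,
) -> tuple[Any, dict[str, ModelExtension]]:
    """Build the model first, then one extension per named extension config."""
    model = build_model(config)
    # No extensions_config means an empty extensions dict, not an error.
    extensions = {
        name: ModelExtension(name=name, config=ext_cfg, model=model)
        for name, ext_cfg in (extensions_config or {}).items()
    }
    return model, extensions
```

Building the model before its extensions lets each extension hold a reference to the finished model.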

ModelFactory

The underlying factory class for advanced usage.

class ModelFactory:
    """Centralized factory for all generative models."""

    def __init__(self):
        """Initialize with default builders."""

    def create(
        self,
        config: DataclassConfig,
        *,
        modality: str | None = None,
        rngs: nnx.Rngs,
        **kwargs,
    ) -> Any:
        """Create a model from dataclass configuration."""
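The sketch below shows the two-step dispatch behind create(): config class to model type, then model type to builder. It illustrates where the documented TypeError and ValueError can arise. It is not Artifex's ModelFactory. ModelFactorySketch, EchoBuilder, and the VAEConfig stand-in are hypothetical, and modality handling is omitted.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

@dataclass
class VAEConfig:  # hypothetical stand-in for an Artifex config
    name: str

class EchoBuilder:
    """Hypothetical builder that returns its inputs, making the dispatch visible."""

    def build(self, config: Any, *, rngs: Any, **kwargs: Any) -> dict[str, Any]:
        return {"config": config, "rngs": rngs, **kwargs}

class ModelFactorySketch:
    """Illustrative two-step dispatch: config class -> model type -> builder."""

    def __init__(self) -> None:
        self.config_types: dict[type, str] = {VAEConfig: "vae"}
        self.builders: dict[str, Any] = {}

    def register(self, model_type: str, builder: Any) -> None:
        self.builders[model_type] = builder

    def create(self, config: Any, *, rngs: Any, **kwargs: Any) -> Any:
        model_type = self.config_types.get(type(config))
        if model_type is None:
            raise TypeError(f"Unknown config type: {type(config).__name__}")
        builder = self.builders.get(model_type)
        if builder is None:
            # Config dispatch and builder registry disagree: no builder for this type.
            raise ValueError(f"No builder registered for model type {model_type!r}")
        # Extra keyword arguments are forwarded to the builder unchanged.
        return builder.build(config, rngs=rngs, **kwargs)
```

Separating the two lookups is why a ValueError is only possible when the config dispatch knows a type that has no registered builder.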

Builders

Each model family has a dedicated builder that handles model instantiation:

VAE Builder

Creates VAE variants based on configuration type:

  • VAEConfig → VAE
  • BetaVAEConfig → BetaVAE
  • ConditionalVAEConfig → ConditionalVAE
  • VQVAEConfig → VQVAE

VAE Builder Reference
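Within a family, the builder still has to pick a variant. The sketch below picks it by exact config class. It is a hypothetical illustration, not the real VAE builder: VAEConfig, BetaVAEConfig, VAE, BetaVAE, and VAEBuilderSketch are minimal stand-ins.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical stand-ins: the real Artifex configs and models carry many more fields.
@dataclass
class VAEConfig:
    name: str
    kl_weight: float = 1.0

@dataclass
class BetaVAEConfig(VAEConfig):
    beta: float = 4.0

class VAE:
    def __init__(self, config: VAEConfig, *, rngs: Any) -> None:
        self.config = config
        self.rngs = rngs

class BetaVAE(VAE):
    pass

class VAEBuilderSketch:
    """Select the VAE variant from the exact config class."""

    # Exact-type keys: a BetaVAEConfig is also an instance of VAEConfig, so an
    # isinstance() chain would depend on check order; a dict lookup does not.
    VARIANTS: dict[type, type] = {VAEConfig: VAE, BetaVAEConfig: BetaVAE}

    def build(self, config: Any, *, rngs: Any, **kwargs: Any) -> VAE:
        model_cls = self.VARIANTS.get(type(config))
        if model_cls is None:
            raise TypeError(f"VAE builder cannot handle {type(config).__name__}")
        return model_cls(config, rngs=rngs, **kwargs)
```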

GAN Builder

Creates GAN variants:

  • DCGANConfig → DCGAN
  • WGANConfig → WGAN
  • LSGANConfig → LSGAN
  • ConditionalGANConfig → ConditionalGAN
  • CycleGANConfig → CycleGAN

The base GANConfig is not factory-ready: create_model() rejects it, so pass one of the concrete configs above instead.

GAN Builder Reference

Diffusion Builder

Creates diffusion models:

  • DDPMConfig → DDPMModel
  • ScoreDiffusionConfig → ScoreDiffusionModel
  • DiffusionConfig → DiffusionModel

Diffusion Builder Reference

Flow Builder

Creates normalizing flows:

  • FlowConfig → NormalizingFlow

Flow Builder Reference

EBM Builder

Creates energy-based models:

  • EBMConfig → EBM
  • DeepEBMConfig → DeepEBM

EBM Builder Reference

Autoregressive Builder

Creates autoregressive models:

  • TransformerConfig → Transformer
  • PixelCNNConfig → PixelCNN
  • WaveNetConfig → WaveNet

Autoregressive Builder Reference

Geometric Builder

Creates geometric models:

  • PointCloudConfig → PointCloudModel
  • MeshConfig → MeshModel
  • VoxelConfig → VoxelModel
  • GraphConfig → GraphModel

Geometric Builder Reference

Modality Adaptation

The factory supports optional modality adaptation for domain-specific models:

# Create image-adapted model
model = create_model(config, modality="image", rngs=rngs)

# Create molecular-adapted model
model = create_model(config, modality="molecular", rngs=rngs)

# Create protein-adapted model
model = create_model(config, modality="protein", rngs=rngs)

Available Modalities:

  • image: Convolutional layers, FID/IS metrics
  • molecular: Chemical constraints and pharmacophore features
  • protein: Structure prediction, sequence modeling
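Modality adaptation can be pictured as an optional post-build step keyed by the modality string. The sketch below is hypothetical throughout. ADAPTERS, apply_modality, and the tagging adapters are illustrative names, and the real adapters change layers, metrics, and constraints rather than wrapping the model. Raising ValueError for an unknown modality is also an assumption about the error type.

```python
from __future__ import annotations

from typing import Any, Callable

def _tag(modality: str) -> Callable[[Any], Any]:
    """Hypothetical adapter: wraps the model and records its modality."""
    def adapt(model: Any) -> Any:
        return {"modality": modality, "model": model}
    return adapt

# One adapter per supported modality name.
ADAPTERS: dict[str, Callable[[Any], Any]] = {
    name: _tag(name) for name in ("image", "molecular", "protein")
}

def apply_modality(model: Any, modality: str | None) -> Any:
    """Return the model unchanged when modality is None, else adapt it."""
    if modality is None:
        return model
    try:
        adapter = ADAPTERS[modality]
    except KeyError:
        raise ValueError(
            f"Unsupported modality {modality!r}; expected one of {sorted(ADAPTERS)}"
        ) from None
    return adapter(model)
```

Keeping adaptation as a separate step means create_model(config, rngs=rngs) without a modality returns the unadapted model.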

Custom Builders

Create custom builders for new model families inside Artifex:

from flax import nnx
from typing import Any

class CustomBuilder:
    """Builder for custom model type."""

    def build(self, config: Any, *, rngs: nnx.Rngs, **kwargs):
        """Build the model from configuration."""
        # CustomModel stands for your own model class (e.g. an nnx.Module) defined elsewhere
        return CustomModel(config, rngs=rngs, **kwargs)

# Register inside the canonical factory implementation
from artifex.generative_models.factory.core import ModelFactory

factory = ModelFactory()
factory.registry.register("custom", CustomBuilder())

The builder registry is an implementation detail for extending Artifex itself. Normal package users should stay on create_model() and create_model_with_extensions().
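A builder only needs a build(config, *, rngs, **kwargs) method, which typing.Protocol can express directly. The registry below is a hypothetical sketch, not Artifex's factory.registry. BuilderRegistry, CustomConfig, CustomModel, and CustomBuilder are illustrative, and rejecting duplicate registrations is an assumed policy.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class Builder(Protocol):
    """Shape a builder needs: build(config, *, rngs, **kwargs) -> model."""
    def build(self, config: Any, *, rngs: Any, **kwargs: Any) -> Any: ...

class BuilderRegistry:
    """Hypothetical registry mapping model-type names to builders."""

    def __init__(self) -> None:
        self._builders: dict[str, Builder] = {}

    def register(self, model_type: str, builder: Builder) -> None:
        # runtime_checkable only checks that a build attribute exists.
        if not isinstance(builder, Builder):
            raise TypeError(f"{type(builder).__name__} has no build() method")
        if model_type in self._builders:
            raise ValueError(f"Builder for {model_type!r} already registered")
        self._builders[model_type] = builder

    def get(self, model_type: str) -> Builder:
        try:
            return self._builders[model_type]
        except KeyError:
            raise ValueError(f"No builder found for model type {model_type!r}") from None

@dataclass
class CustomConfig:  # hypothetical config for the new model family
    name: str
    width: int = 16

class CustomModel:  # hypothetical model; a real one would be an nnx.Module
    def __init__(self, config: CustomConfig, *, rngs: Any, **options: Any) -> None:
        self.config, self.rngs, self.options = config, rngs, options

class CustomBuilder:
    def build(self, config: Any, *, rngs: Any, **kwargs: Any) -> CustomModel:
        return CustomModel(config, rngs=rngs, **kwargs)
```

Checking the builder at registration time surfaces a malformed builder immediately, rather than on the first create call.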

Best Practices

DO

  • ✅ Use dataclass configs instead of dictionaries
  • ✅ Validate configs before passing them to the factory
  • ✅ Use type hints for better IDE support
  • ✅ Pass all required RNG streams to nnx.Rngs

DON'T

  • ❌ Pass dictionary configs (raises TypeError)
  • ❌ Use string-based model class specifications
  • ❌ Omit the rngs argument (it is keyword-only and has no default)

Error Handling

The factory raises descriptive errors for unsupported inputs:

# TypeError: Dictionary configs not supported
create_model({"model_class": "vae"}, rngs=rngs)
# Raises: TypeError: Expected dataclass config, got dict.

# TypeError: Unknown config type
create_model(UnknownConfig(), rngs=rngs)
# Raises: TypeError: Unknown config type: UnknownConfig

# ValueError: Builder not found
# (Only possible when the internal builder registry and config dispatch disagree)

Module Reference

| Module | Description |
|---|---|
| core | Core factory implementation and create_model function |
| vae | VAE model builder |
| gan | GAN model builder |
| diffusion | Diffusion model builder |
| flow | Normalizing flow builder |
| ebm | Energy-based model builder |
| autoregressive | Autoregressive model builder |
| geometric | Geometric model builder |