Simple Image-Text Multimodal Learning¤
Status: Standalone pedagogy
Device: CPU-compatible
This walkthrough is a standalone JAX/Flax NNX concept demo. It does not instantiate shipped Artifex runtime owners.
Files¤
- Python Script:
simple_image_text.py - Jupyter Notebook:
simple_image_text.ipynb
Quick Start¤
python examples/generative_models/multimodal/simple_image_text.py
# Or open the Jupyter notebook
jupyter lab examples/generative_models/multimodal/simple_image_text.ipynb
Overview¤
This standalone walkthrough demonstrates multimodal learning by combining image and text encoders in a unified model. Learn how to build separate encoders for different modalities, create shared embedding spaces, and perform cross-modal retrieval tasks without claiming shipped Artifex multimodal owners.
Learning Objectives¤
After completing this example, you will understand:
- Multimodal model architectures with separate encoders
- Creating shared embedding spaces for multiple modalities
- Computing cross-modal similarities
- Performing cross-modal retrieval (image-to-text, text-to-image)
- Visualizing multimodal embedding spaces
Prerequisites¤
- Understanding of CNNs for image processing
- Familiarity with text embeddings and sequence models
- Knowledge of similarity metrics and representation learning
- Basic understanding of JAX/Flax NNX patterns
Theory¤
Multimodal Learning¤
Multimodal models learn joint representations from multiple input modalities. The goal is to create a shared embedding space where semantically similar inputs from different modalities are close together.
Contrastive Learning¤
The model learns by maximizing similarity between matching pairs while minimizing similarity between non-matching pairs:
where:
- \(f_I\) is the image encoder
- \(f_T\) is the text encoder
- \(\tau\) is the temperature parameter
- \(\text{sim}\) is the similarity function (typically cosine similarity)
Architecture Components¤
- Image Encoder: CNN-based encoder mapping images to embeddings
- Text Encoder: Embedding + MLP mapping text sequences to embeddings
- Fusion Layer: Combines modalities for joint predictions
Code Walkthrough¤
1. Image Encoder¤
class SimpleImageEncoder(nnx.Module):
def __init__(self, image_size=32, embed_dim=128, *, rngs: nnx.Rngs):
super().__init__()
# CNN encoder with global average pooling
self.encoder = nnx.Sequential(
nnx.Conv(3, 32, kernel_size=(3, 3), rngs=rngs),
nnx.relu,
nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs),
nnx.relu,
lambda x: jnp.mean(x, axis=(1, 2)), # Global pooling
nnx.Linear(64, embed_dim, rngs=rngs),
)
2. Text Encoder¤
class SimpleTextEncoder(nnx.Module):
def __init__(self, vocab_size=128, embed_dim=128, *, rngs: nnx.Rngs):
super().__init__()
self.embedding = nnx.Embed(vocab_size, embed_dim, rngs=rngs)
self.encoder = nnx.Sequential(
nnx.Linear(embed_dim, 64, rngs=rngs),
nnx.relu,
nnx.Linear(64, embed_dim, rngs=rngs)
)
def __call__(self, text_ids):
embedded = self.embedding(text_ids)
pooled = jnp.mean(embedded, axis=1) # Average pooling
return self.encoder(pooled)
3. Multimodal Model¤
class SimpleMultimodalModel(nnx.Module):
def __init__(self, image_size=32, vocab_size=128,
embed_dim=128, output_dim=10, *, rngs: nnx.Rngs):
super().__init__()
# Separate encoders
self.image_encoder = SimpleImageEncoder(image_size, embed_dim, rngs=rngs)
self.text_encoder = SimpleTextEncoder(vocab_size, embed_dim, rngs=rngs)
# Fusion layer
self.fusion = nnx.Sequential(
nnx.Linear(embed_dim * 2, embed_dim, rngs=rngs),
nnx.relu,
nnx.Linear(embed_dim, output_dim, rngs=rngs),
)
4. Cross-Modal Similarity¤
def compute_similarity(self, images, text_ids):
image_features = self.encode_image(images)
text_features = self.encode_text(text_ids)
# Normalize features
image_features = image_features / (
jnp.linalg.norm(image_features, axis=-1, keepdims=True) + 1e-8
)
text_features = text_features / (
jnp.linalg.norm(text_features, axis=-1, keepdims=True) + 1e-8
)
# Compute cosine similarity
similarity = jnp.sum(image_features * text_features, axis=-1)
return similarity
Experiments to Try¤
1. Architecture Improvements¤
- Add attention mechanisms for better feature aggregation
- Use pre-trained encoders (ResNet for images, BERT for text)
- Implement transformer-based fusion layers
- Add residual connections
2. Training Enhancements¤
- Implement contrastive loss (InfoNCE, SimCLR)
- Add hard negative mining
- Use temperature-scaled training
- Implement data augmentation
3. Advanced Features¤
- Multi-head attention for fusion
- Cross-modal attention mechanisms
- Hierarchical embeddings
- Multi-task learning objectives
Next Steps¤
-
Multimodal Guide
Compare this standalone walkthrough with the retained modality abstractions
-
Image Modality
Understand the retained image-processing path used by Artifex models
-
Text Modality
Review the retained token and text-preparation surfaces
-
Planned Topics
Track the still-unshipped CLIP, VQA, captioning, and retrieval examples
Troubleshooting¤
Common Issues¤
Embedding Dimension Mismatch:
- Ensure both encoders output same embedding dimension
- Check fusion layer input dimensions
- Verify concatenation axis
Poor Similarity Scores:
- Normalize features before computing similarity
- Check for numerical instability (add epsilon)
- Tune temperature parameter
Memory Issues:
- Reduce batch size or embedding dimensions
- Use gradient checkpointing
- Enable mixed precision training
Additional Resources¤
Documentation¤
Research Papers¤
- Learning Transferable Visual Models From Natural Language Supervision (CLIP)
- ALIGN: Scaling Up Visual and Vision-Language Representation Learning
- Multimodal Learning with Transformers: A Survey
Author: Artifex Team Last Updated: 2025-10-22 Difficulty: Intermediate Time to Complete: ~45 minutes