Text Modality Guide¤
This guide covers working with text data in Artifex, including tokenization, vocabulary management, text datasets, and best practices for text-based generative models.
Overview¤
Artifex's text modality provides a unified interface for processing text data, handling tokenization, vocabulary management, and sequence processing for generative models.
TextModality is the preprocessing owner for tokenization, detokenization, and sequence shaping. It is not a standalone text generator; model-owned generation stays outside this helper surface. Public evaluation lives in TextEvaluationSuite, and representation processing lives in TextProcessor and TokenizationProcessor.
- Tokenization: Built-in tokenization with special token handling (BOS, EOS, PAD, UNK)
- Vocabulary Management: Configurable vocabulary size and token mapping
- Sequence Handling: Padding, truncation, and sequence length management
- Synthetic Datasets: Ready-to-use synthetic text datasets for testing
- Text Augmentation: Token masking, replacement, and sequence augmentation
- JAX-Native: Full JAX compatibility with efficient batch processing
Text Configuration¤
Basic Configuration¤
from artifex.configs import ModalityConfig
from artifex.generative_models.modalities import TextModality
from flax import nnx
# Initialize RNG
rngs = nnx.Rngs(0)
# Configure text modality
text_config = ModalityConfig(
name="text",
modality_name="text",
metadata={
"text_params": {
"vocab_size": 10000,
"max_length": 512,
"pad_token_id": 0,
"unk_token_id": 1,
"bos_token_id": 2,
"eos_token_id": 3,
"case_sensitive": False
}
}
)
# Create modality
text_modality = TextModality(config=text_config, rngs=rngs)
# Access configuration
print(f"Vocab size: {text_modality.vocab_size}") # 10000
print(f"Max length: {text_modality.max_length}") # 512
Special Tokens¤
Artifex uses standard special tokens for sequence processing:
| Token | ID | Purpose |
|---|---|---|
| PAD | 0 | Padding token for variable-length sequences |
| UNK | 1 | Unknown token for out-of-vocabulary words |
| BOS | 2 | Beginning-of-sequence marker |
| EOS | 3 | End-of-sequence marker |
# Special token configuration
text_config = ModalityConfig(
name="text",
modality_name="text",
metadata={
"text_params": {
"vocab_size": 50000,
"max_length": 1024,
"pad_token_id": 0, # Padding
"unk_token_id": 1, # Unknown
"bos_token_id": 2, # Beginning
"eos_token_id": 3, # End
"case_sensitive": True # Preserve case
}
}
)
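The special token IDs above are a convention that the rest of the pipeline depends on, so it can help to check them before building a modality. The following is a minimal sketch in plain Python; `validate_text_params` is a hypothetical helper written for this guide, not part of Artifex.

```python
def validate_text_params(text_params: dict) -> None:
    """Raise ValueError if special token IDs collide or fall outside the vocabulary."""
    vocab_size = text_params["vocab_size"]
    special = {
        name: text_params[name]
        for name in ("pad_token_id", "unk_token_id", "bos_token_id", "eos_token_id")
    }
    ids = list(special.values())
    # Two special tokens sharing an ID would make them indistinguishable
    if len(set(ids)) != len(ids):
        raise ValueError(f"Special token IDs must be distinct, got {special}")
    # Every ID must be a valid index into the embedding table
    for name, token_id in special.items():
        if not 0 <= token_id < vocab_size:
            raise ValueError(f"{name}={token_id} is outside [0, {vocab_size})")
    # A sequence needs room for at least BOS and EOS
    if text_params["max_length"] < 2:
        raise ValueError("max_length must leave room for BOS and EOS")


params = {
    "vocab_size": 50000,
    "max_length": 1024,
    "pad_token_id": 0,
    "unk_token_id": 1,
    "bos_token_id": 2,
    "eos_token_id": 3,
}
validate_text_params(params)  # passes silently
```

Running the check once at configuration time is cheaper than debugging a model that silently treats EOS as padding.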
Text Datasets¤
Synthetic Text Datasets¤
Artifex provides several synthetic text dataset types:
Random Sentences¤
from artifex.generative_models.modalities.text.datasets import create_text_dataset
# Create dataset with random sentences
random_text = create_text_dataset(
"synthetic",
rngs=rngs,
dataset_size=5000,
pattern_type="random_sentences",
vocab_size=10000,
max_length=512,
)
# Iterate over samples
for sample in random_text:
    print(sample["text_tokens"].shape)  # (512,) - padded to max_length
    break
Generated patterns:
- Simple subject-verb-adverb sentences
- Random selection from vocabulary
- Natural-looking structure
Repeated Phrases¤
# Dataset with repeated phrases
repeated_text = create_text_dataset(
"synthetic",
rngs=rngs,
dataset_size=5000,
pattern_type="repeated_phrases",
vocab_size=10000,
max_length=512,
)
# Iterate to inspect
for sample in repeated_text:
    print(sample["text_tokens"].shape)
    break
Useful for:
- Testing sequence models
- Pattern recognition
- Repetition detection
Numerical Sequences¤
# Dataset with numerical sequences
sequences = create_text_dataset(
"synthetic",
rngs=rngs,
dataset_size=5000,
pattern_type="sequences",
vocab_size=10000,
max_length=512,
)
# Iterate to inspect
for sample in sequences:
    print(sample["text_tokens"].shape)
    break
Useful for:
- Sequence learning tasks
- Arithmetic operations
- Ordering and counting
Palindromes¤
# Dataset with palindromic patterns
palindromes = create_text_dataset(
"synthetic",
rngs=rngs,
dataset_size=5000,
pattern_type="palindromes",
vocab_size=10000,
max_length=512,
)
# Iterate to inspect
for sample in palindromes:
    print(sample["text_tokens"].shape)
    break
Useful for:
- Reversibility testing
- Symmetry detection
- Pattern recognition
Simple Text Datasets¤
For custom text data:
from artifex.generative_models.modalities.text.datasets import create_text_dataset
# Your text data
texts = [
"The quick brown fox jumps over the lazy dog",
"Machine learning is a subset of artificial intelligence",
"Deep learning uses neural networks with multiple layers",
"Natural language processing enables text understanding",
"Transformers revolutionized NLP with attention mechanisms",
]
# Create dataset
text_dataset = create_text_dataset(
"simple",
rngs=rngs,
texts=texts,
vocab_size=10000,
max_length=512,
)
# Iterate over samples
for sample in text_dataset:
    print(f"Tokens: {sample['text_tokens'].shape}")
    print(f"Index: {sample['index']}")
    break
Factory Function¤
The create_text_dataset() factory returns a MemorySource instance. It accepts
dataset_type as the first positional argument ("synthetic" or "simple"),
rngs as a required keyword argument, and forwards remaining keyword arguments
to the underlying data generation functions.
from artifex.generative_models.modalities.text.datasets import create_text_dataset
# Create synthetic dataset
dataset = create_text_dataset(
"synthetic",
rngs=rngs,
dataset_size=10000,
pattern_type="random_sentences",
vocab_size=10000,
max_length=512,
)
# Create simple dataset
custom_dataset = create_text_dataset(
"simple",
rngs=rngs,
texts=["text 1", "text 2", "text 3"],
vocab_size=10000,
max_length=512,
)
Tokenization¤
Basic Tokenization¤
Artifex's text datasets use simple hash-based tokenization:
import jax
import jax.numpy as jnp

def tokenize_text(text: str, config) -> jax.Array:
    """Tokenize text to token IDs.
    Args:
        text: Input text string
        config: Text configuration
    Returns:
        Token sequence (max_length,)
    """
    # Get parameters from config
    text_params = config.metadata.get("text_params", {})
    vocab_size = text_params.get("vocab_size", 10000)
    max_length = text_params.get("max_length", 512)
    pad_token_id = text_params.get("pad_token_id", 0)
    bos_token_id = text_params.get("bos_token_id", 2)
    eos_token_id = text_params.get("eos_token_id", 3)
    case_sensitive = text_params.get("case_sensitive", False)
    # Normalize case
    if not case_sensitive:
        text = text.lower()
    # Split into words
    words = text.strip().split()
    # Convert to tokens
    tokens = [bos_token_id]  # Add BOS
    for word in words:
        # Simple hash-based token ID; IDs 0-3 are reserved for special tokens.
        # Python salts str hashes per process, so IDs differ between runs
        # unless PYTHONHASHSEED is fixed.
        token_id = hash(word) % (vocab_size - 4) + 4
        tokens.append(token_id)
    tokens.append(eos_token_id)  # Add EOS
    # Pad or truncate (truncation drops the trailing EOS)
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
    else:
        tokens.extend([pad_token_id] * (max_length - len(tokens)))
    return jnp.array(tokens, dtype=jnp.int32)

# Usage
text = "Hello world, this is a test"
tokens = tokenize_text(text, text_config)
print(tokens.shape)  # (512,)
print(tokens[:10])  # e.g. [2, 5234, 8761, 1234, 9876, 4321, 6543, 3, 0, 0]; word IDs vary per run
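Python's built-in `hash()` is salted per process for strings, so the IDs produced by a hash-based tokenizer like the one above are not reproducible across runs. The sketch below shows one way to get stable IDs by swapping in `zlib.crc32`. The names `stable_token_id` and `stable_tokenize` are hypothetical and not part of Artifex. Unlike the function above, this variant also keeps EOS when truncating, which is a deliberate design choice of the sketch.

```python
import zlib

NUM_SPECIAL = 4  # PAD=0, UNK=1, BOS=2, EOS=3 are reserved


def stable_token_id(word: str, vocab_size: int) -> int:
    """Map a word to an ID in [NUM_SPECIAL, vocab_size) using a process-independent hash."""
    digest = zlib.crc32(word.encode("utf-8"))
    return digest % (vocab_size - NUM_SPECIAL) + NUM_SPECIAL


def stable_tokenize(
    text: str,
    vocab_size: int = 10000,
    max_length: int = 16,
    case_sensitive: bool = False,
) -> list[int]:
    """Tokenize to a fixed-length list: BOS, word IDs, EOS, then PAD."""
    if not case_sensitive:
        text = text.lower()
    body = [stable_token_id(word, vocab_size) for word in text.split()]
    # Reserve two slots so BOS and EOS survive truncation
    tokens = [2] + body[: max_length - 2] + [3]
    return tokens + [0] * (max_length - len(tokens))


tokens = stable_tokenize("Hello world")
print(len(tokens))   # 16
print(tokens[0], tokens[3])  # 2 3
```

Hash-based IDs still collide (different words can share an ID) and cannot be inverted, so this remains a testing aid, not a production tokenizer.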
Detokenization¤
def detokenize_tokens(tokens: jax.Array, config) -> str:
    """Convert tokens back to text.
    Args:
        tokens: Token sequence
        config: Text configuration
    Returns:
        Detokenized text string
    """
    text_params = config.metadata.get("text_params", {})
    pad_token_id = text_params.get("pad_token_id", 0)
    bos_token_id = text_params.get("bos_token_id", 2)
    eos_token_id = text_params.get("eos_token_id", 3)
    # Convert to list
    token_list = tokens.tolist()
    # Remove special tokens
    filtered_tokens = []
    for token in token_list:
        if token in [pad_token_id, bos_token_id, eos_token_id]:
            if token == eos_token_id:
                break  # Stop at EOS
            continue
        filtered_tokens.append(token)
    # Convert back to words (placeholder - in practice use vocabulary)
    words = [f"token_{token}" for token in filtered_tokens]
    return " ".join(words)

# Usage
text = "Hello world"
tokens = tokenize_text(text, text_config)
recovered = detokenize_tokens(tokens, text_config)
print(recovered)
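The placeholder detokenizer above cannot recover words because hashing is one-way. A real round trip needs an explicit vocabulary. The following sketch builds a frequency-based word vocabulary and uses it in both directions. `build_vocab`, `encode`, and `decode` are hypothetical helpers for illustration and do not reflect Artifex's internal vocabulary handling.

```python
from collections import Counter

SPECIALS = ["<pad>", "<unk>", "<bos>", "<eos>"]  # IDs 0, 1, 2, 3


def build_vocab(texts: list[str], vocab_size: int):
    """Assign IDs to the most frequent words, after the four special tokens."""
    counts = Counter(word for text in texts for word in text.lower().split())
    words = [word for word, _ in counts.most_common(vocab_size - len(SPECIALS))]
    id_to_word = SPECIALS + words
    word_to_id = {word: idx for idx, word in enumerate(id_to_word)}
    return word_to_id, id_to_word


def encode(text: str, word_to_id: dict[str, int]) -> list[int]:
    """Wrap word IDs in BOS/EOS; unseen words map to UNK."""
    unk = word_to_id["<unk>"]
    body = [word_to_id.get(word, unk) for word in text.lower().split()]
    return [word_to_id["<bos>"]] + body + [word_to_id["<eos>"]]


def decode(token_ids: list[int], id_to_word: list[str]) -> str:
    """Drop PAD/BOS, stop at EOS, and look every other ID up in the vocabulary."""
    words = []
    for token_id in token_ids:
        if token_id == 3:  # EOS
            break
        if token_id in (0, 2):  # PAD, BOS
            continue
        words.append(id_to_word[token_id])
    return " ".join(words)


corpus = ["the quick brown fox", "the lazy dog"]
word_to_id, id_to_word = build_vocab(corpus, vocab_size=100)
ids = encode("the quick dog", word_to_id)
print(decode(ids, id_to_word))  # the quick dog
print(decode(encode("the purple fox", word_to_id), id_to_word))  # the <unk> fox
```

Out-of-vocabulary words come back as `<unk>`, which makes vocabulary gaps visible instead of silently mapping them to an arbitrary ID.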
Custom Tokenizers¤
For production use, integrate real tokenizers:
import jax.numpy as jnp
from datarax.sources import MemorySource, MemorySourceConfig
from flax import nnx
# Tokenize with a custom tokenizer, then wrap in MemorySource
texts = ["hello world", "machine learning", "deep neural networks"]
max_length = 128
# Replace with your real tokenizer (e.g., HuggingFace)
token_arrays = []
for text in texts:
    # encoded = tokenizer.encode(text, max_length=max_length)
    encoded = jnp.array([2, 100, 200, 300, 3] + [0] * (max_length - 5))
    token_arrays.append(encoded)
data = {
"text_tokens": jnp.stack(token_arrays),
"index": jnp.arange(len(texts)),
}
config = MemorySourceConfig(shuffle=True)
dataset = MemorySource(config, data, rngs=nnx.Rngs(0))
batch = dataset.get_batch(2)
print(batch["text_tokens"].shape) # (2, 128)
Text Preprocessing¤
Padding and Truncation¤
import jax
import jax.numpy as jnp

def pad_sequence(tokens: jax.Array, max_length: int, pad_token_id: int = 0):
    """Pad token sequence to max_length.
    Args:
        tokens: Token sequence
        max_length: Target length
        pad_token_id: Padding token ID
    Returns:
        Padded sequence
    """
    current_length = len(tokens)
    if current_length >= max_length:
        return tokens[:max_length]
    padding = jnp.full((max_length - current_length,), pad_token_id, dtype=tokens.dtype)
    return jnp.concatenate([tokens, padding])

def truncate_sequence(tokens: jax.Array, max_length: int, eos_token_id: int = 3):
    """Truncate sequence and add EOS token.
    Args:
        tokens: Token sequence
        max_length: Maximum length
        eos_token_id: EOS token ID
    Returns:
        Truncated sequence with EOS
    """
    if len(tokens) <= max_length:
        return tokens
    # Truncate and add EOS
    truncated = tokens[:max_length - 1]
    return jnp.concatenate([truncated, jnp.array([eos_token_id], dtype=tokens.dtype)])

# Usage
tokens = jnp.array([2, 100, 200, 300, 400, 3])  # BOS ... EOS
# Pad to 10
padded = pad_sequence(tokens, max_length=10, pad_token_id=0)
print(padded)  # [2, 100, 200, 300, 400, 3, 0, 0, 0, 0]
# Truncate to 5
truncated = truncate_sequence(tokens, max_length=5, eos_token_id=3)
print(truncated)  # [2, 100, 200, 300, 3]
Batch Padding¤
def pad_batch(token_sequences: list[jax.Array], pad_token_id: int = 0):
    """Pad batch of sequences to same length.
    Args:
        token_sequences: List of token sequences
        pad_token_id: Padding token ID
    Returns:
        Padded batch (batch_size, max_length)
    """
    # Find maximum length
    max_length = max(len(seq) for seq in token_sequences)
    # Pad all sequences
    padded = []
    for seq in token_sequences:
        padded_seq = pad_sequence(seq, max_length, pad_token_id)
        padded.append(padded_seq)
    return jnp.stack(padded)

# Usage
sequences = [
    jnp.array([2, 100, 200, 3]),
    jnp.array([2, 300, 400, 500, 600, 3]),
    jnp.array([2, 700, 3])
]
batch = pad_batch(sequences, pad_token_id=0)
print(batch.shape)  # (3, 6) - padded to longest sequence
print(batch)
# [[  2 100 200   3   0   0]
#  [  2 300 400 500 600   3]
#  [  2 700   3   0   0   0]]
Attention Masks¤
def create_attention_mask(tokens: jax.Array, pad_token_id: int = 0):
    """Create attention mask for padded sequences.
    Args:
        tokens: Token sequence with padding
        pad_token_id: Padding token ID
    Returns:
        Attention mask (1 for real tokens, 0 for padding)
    """
    return (tokens != pad_token_id).astype(jnp.int32)

def create_causal_mask(seq_length: int):
    """Create causal mask for autoregressive generation.
    Args:
        seq_length: Sequence length
    Returns:
        Causal mask (seq_length, seq_length); row i is 1 for positions <= i
    """
    # Integer dtype keeps the mask consistent with create_attention_mask
    mask = jnp.tril(jnp.ones((seq_length, seq_length), dtype=jnp.int32))
    return mask

# Usage
tokens = jnp.array([2, 100, 200, 300, 3, 0, 0, 0])
# Padding mask
pad_mask = create_attention_mask(tokens, pad_token_id=0)
print(pad_mask)  # [1 1 1 1 1 0 0 0]
# Causal mask for generation
causal_mask = create_causal_mask(8)
print(causal_mask)
# [[1 0 0 0 0 0 0 0]
#  [1 1 0 0 0 0 0 0]
#  [1 1 1 0 0 0 0 0]
#  ...
#  [1 1 1 1 1 1 1 1]]
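In autoregressive training the padding mask and the causal mask are usually applied together: a query position may attend to a key position only if the key is not in the future and is not padding. The sketch below combines the two for a batch. `combined_attention_mask` is a hypothetical helper, and the `(batch, query, key)` layout is an assumption that may need an extra head axis for a particular attention implementation.

```python
import jax.numpy as jnp


def combined_attention_mask(tokens, pad_token_id: int = 0):
    """Return a boolean mask of shape (batch, seq, seq).

    Entry [b, i, j] is True when query position i may attend to key position j:
    j must not lie in the future (j <= i) and must not be a padding token.
    """
    seq_len = tokens.shape[-1]
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # (seq, seq)
    not_pad = tokens != pad_token_id                             # (batch, seq)
    # Broadcast: causal over the batch axis, padding over the query axis
    return causal[None, :, :] & not_pad[:, None, :]


tokens = jnp.array([[2, 100, 3, 0]])
mask = combined_attention_mask(tokens)
print(mask.shape)  # (1, 4, 4)
print(mask[0, 1])  # [ True  True False False]
```

Rows that belong to padded query positions still contain `True` entries here. Their outputs are normally discarded by masking the loss, so no all-`False` rows (which would produce NaNs after softmax) are created.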
Positional Embeddings¤
Artifex provides multiple positional encoding methods for transformer-based models.
Learned Position Embeddings¤
The default approach using learnable position embeddings:
import jax.numpy as jnp
from artifex.generative_models.extensions.nlp.embeddings import TextEmbeddings
from artifex.generative_models.core.configuration import ExtensionConfig
from flax import nnx
rngs = nnx.Rngs(0)
# Configure embeddings
config = ExtensionConfig(
weight=1.0,
enabled=True,
extensions={
"embeddings": {
"embedding_dim": 512,
"vocab_size": 50000,
"max_position_embeddings": 1024,
"dropout_rate": 0.1,
"use_position_embeddings": True
}
}
)
# Create embedding module
embeddings = TextEmbeddings(config=config, rngs=rngs)
# Embed tokens with learned positions
tokens = jnp.array([[2, 100, 200, 300, 3]]) # [batch, seq_len]
embedded = embeddings.embed(tokens, deterministic=True)
print(embedded.shape) # (1, 5, 512)
Rotary Position Embeddings (RoPE)¤
RoPE is a positional encoding widely used in modern LLMs such as Llama 2. It encodes position by rotating pairs of embedding dimensions through position-dependent angles, so attention scores depend on the relative offset between tokens:
# Embed with RoPE (Rotary Position Embeddings)
embedded_rope = embeddings.embed_with_rope(
tokens,
deterministic=True,
base=10000.0 # RoPE base frequency
)
print(embedded_rope.shape) # (1, 5, 512)
# Apply RoPE to existing embeddings
raw_embeddings = embeddings.get_token_embeddings(tokens[0])
rotated = embeddings.apply_rope_embeddings(raw_embeddings[None], base=10000.0)
Key benefits of RoPE:
- Enables relative position attention patterns
- Better length generalization
- No learned parameters for positions
- Used in Llama 2, PaLM, and other modern LLMs
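The relative-position property in the list above can be checked numerically. The following sketch implements the rotation directly in `jax.numpy` and confirms that the dot product between a rotated query and a rotated key depends only on the distance between their positions. `rope_rotate` is a hypothetical stand-alone function using the interleaved-pair layout. Artifex's own `apply_rope` may lay out dimensions differently.

```python
import jax.numpy as jnp


def rope_rotate(x, positions, base: float = 10000.0):
    """Rotate consecutive dimension pairs of x by position-dependent angles.

    x: (seq, dim) with even dim; positions: (seq,) float positions.
    """
    dim = x.shape[-1]
    # One frequency per dimension pair, decreasing geometrically
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))   # (dim/2,)
    angles = positions[:, None] * inv_freq[None, :]            # (seq, dim/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    # Standard 2D rotation applied to each (even, odd) pair
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    # Re-interleave the pairs back into the original layout
    return jnp.stack([rot_even, rot_odd], axis=-1).reshape(x.shape)


q = jnp.arange(8.0) / 8.0
k = jnp.cos(jnp.arange(8.0))


def score(q_pos: float, k_pos: float) -> float:
    """Attention logit between q placed at q_pos and k placed at k_pos."""
    q_rot = rope_rotate(q[None], jnp.array([q_pos]))[0]
    k_rot = rope_rotate(k[None], jnp.array([k_pos]))[0]
    return float(jnp.dot(q_rot, k_rot))


# Both pairs are two positions apart, so the scores agree up to float error
print(abs(score(5.0, 3.0) - score(12.0, 10.0)) < 1e-4)  # True
```

Because each pair is only rotated, the vector norm is unchanged, which is why RoPE can be applied to queries and keys without rescaling.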
Sinusoidal Position Embeddings¤
Fixed positional encodings from the original Transformer paper, "Attention Is All You Need":
# Embed with sinusoidal positions
embedded_sin = embeddings.embed_with_sinusoidal_positions(
tokens,
deterministic=True,
base=10000.0
)
print(embedded_sin.shape) # (1, 5, 512)
# Get raw sinusoidal encodings
sin_encodings = embeddings.get_sinusoidal_embeddings(
seq_len=100,
dim=512,
base=10000.0
)
print(sin_encodings.shape) # (100, 512)
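Formula:
$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$
Here pos is the token position, i indexes the dimension pairs, d is the embedding dimension (dim), and 10000 corresponds to the base argument.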
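The sinusoidal encoding from the original Transformer paper can be sketched in a few lines of `jax.numpy`: even dimensions hold sines and odd dimensions hold cosines of the same angle. `sinusoidal_positions` below is a hypothetical stand-alone function. Whether Artifex's `get_sinusoidal_embeddings` interleaves or concatenates the sine and cosine halves is not stated here, so treat the layout as an assumption.

```python
import jax.numpy as jnp


def sinusoidal_positions(seq_len: int, dim: int, base: float = 10000.0):
    """Return a (seq_len, dim) table of fixed sinusoidal position encodings (dim even)."""
    positions = jnp.arange(seq_len, dtype=jnp.float32)[:, None]          # (seq, 1)
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    angles = positions * inv_freq[None, :]                                # (seq, dim/2)
    table = jnp.zeros((seq_len, dim), dtype=jnp.float32)
    table = table.at[:, 0::2].set(jnp.sin(angles))  # even dimensions: sine
    table = table.at[:, 1::2].set(jnp.cos(angles))  # odd dimensions: cosine
    return table


table = sinusoidal_positions(seq_len=100, dim=512)
print(table.shape)  # (100, 512)
```

At position 0 every sine entry is 0 and every cosine entry is 1, which is a quick sanity check for any implementation of this scheme.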
Standalone RoPE Functions¤
For custom implementations, use the standalone utility functions:
from artifex.generative_models.extensions.nlp.embeddings import (
precompute_rope_freqs,
apply_rope,
create_sinusoidal_positions
)
# Precompute RoPE frequencies
freqs_sin, freqs_cos = precompute_rope_freqs(
dim=64, # Must be even
max_seq_len=512,
base=10000.0
)
# Apply to query/key tensors in attention
q = jnp.ones((2, 8, 128, 64)) # [batch, heads, seq, dim]
k = jnp.ones((2, 8, 128, 64))
q_rotated = apply_rope(q, freqs_sin, freqs_cos)
k_rotated = apply_rope(k, freqs_sin, freqs_cos)
# Create standalone sinusoidal positions
positions = create_sinusoidal_positions(
max_seq_len=1024,
dim=512
)
Text Augmentation¤
Token Masking¤
import jax
import jax.numpy as jnp

def mask_tokens(tokens: jax.Array, key, mask_prob: float = 0.15, mask_token_id: int = 1):
    """Randomly mask tokens (BERT-style).
    Args:
        tokens: Token sequence
        key: Random key
        mask_prob: Probability of masking
        mask_token_id: Token ID for masked positions
    Returns:
        Masked tokens, original tokens
    """
    # Create mask (don't mask special tokens)
    special_tokens = jnp.array([0, 1, 2, 3])  # PAD, UNK, BOS, EOS
    is_special = jnp.isin(tokens, special_tokens)
    # Random mask
    mask = jax.random.bernoulli(key, mask_prob, tokens.shape)
    mask = mask & (~is_special)  # Don't mask special tokens
    # Apply mask
    masked_tokens = jnp.where(mask, mask_token_id, tokens)
    return masked_tokens, tokens

# Usage
tokens = jnp.array([2, 100, 200, 300, 400, 3, 0, 0])
key = jax.random.key(0)
masked, original = mask_tokens(tokens, key, mask_prob=0.15)
print("Original:", original)
print("Masked:  ", masked)
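`mask_tokens` returns both the masked and the original sequence because masked-language-model training usually computes the loss only at the positions that were masked. One common way to express this is a label array that carries the original token at masked positions and a sentinel everywhere else. The sketch below assumes a hypothetical sentinel value of -100 and a hypothetical helper name. Neither comes from Artifex.

```python
import jax.numpy as jnp

IGNORE_ID = -100  # hypothetical sentinel marking positions excluded from the loss


def make_mlm_labels(original, masked, ignore_id: int = IGNORE_ID):
    """Keep the original token where masking changed the input; ignore the rest."""
    was_masked = original != masked
    return jnp.where(was_masked, original, ignore_id)


original = jnp.array([2, 100, 200, 300, 3, 0])
masked = jnp.array([2, 1, 200, 1, 3, 0])  # positions 1 and 3 were replaced by ID 1
labels = make_mlm_labels(original, masked)
print(labels)  # [-100  100 -100  300 -100 -100]
```

Comparing the two arrays works here because special tokens are never masked, so a position equal in both arrays was definitely left untouched.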
Token Replacement¤
def replace_tokens(
    tokens: jax.Array,
    key,
    replace_prob: float = 0.1,
    vocab_size: int = 10000
):
    """Randomly replace tokens with random tokens.
    Args:
        tokens: Token sequence
        key: Random key
        replace_prob: Probability of replacement
        vocab_size: Vocabulary size
    Returns:
        Augmented tokens
    """
    keys = jax.random.split(key, 2)
    # Create replacement mask (don't replace special tokens)
    special_tokens = jnp.array([0, 1, 2, 3])
    is_special = jnp.isin(tokens, special_tokens)
    mask = jax.random.bernoulli(keys[0], replace_prob, tokens.shape)
    mask = mask & (~is_special)
    # Generate random tokens (from vocab, excluding special tokens)
    random_tokens = jax.random.randint(keys[1], tokens.shape, 4, vocab_size)
    # Apply replacement
    augmented = jnp.where(mask, random_tokens, tokens)
    return augmented

# Usage
tokens = jnp.array([2, 100, 200, 300, 400, 3, 0, 0])
key = jax.random.key(0)
augmented = replace_tokens(tokens, key, replace_prob=0.1, vocab_size=10000)
print("Original: ", tokens)
print("Augmented: ", augmented)
Sequence Shuffling¤
def shuffle_tokens(
    tokens: jax.Array,
    key,
    shuffle_prob: float = 0.1
):
    """Randomly shuffle the non-special tokens of a sequence.
    Args:
        tokens: Token sequence (1D)
        key: Random key
        shuffle_prob: Probability of shuffling the sequence
    Returns:
        Shuffled tokens (special tokens keep their positions)
    """
    # Separate keys for the shuffle decision and the permutation
    decide_key, perm_key = jax.random.split(key)
    # Don't shuffle special tokens
    special_tokens = jnp.array([0, 1, 2, 3])
    is_special = jnp.isin(tokens, special_tokens)
    # For simplicity, shuffle the entire sequence or leave it untouched
    should_shuffle = jax.random.bernoulli(decide_key, shuffle_prob)
    n = tokens.shape[0]
    positions = jnp.arange(n)
    # Destination slots: non-special positions in original order, then special ones
    dst = jnp.argsort(jnp.where(is_special, n + positions, positions))
    # Source slots: non-special positions in random order, then special ones
    priority = jnp.where(is_special, 2.0 + positions, jax.random.uniform(perm_key, (n,)))
    src = jnp.argsort(priority)
    # Fixed-shape scatter (no boolean indexing), so this works under jit and vmap
    shuffled = jnp.zeros_like(tokens).at[dst].set(tokens[src])
    return jnp.where(should_shuffle, shuffled, tokens)

# Usage
tokens = jnp.array([2, 100, 200, 300, 400, 3, 0, 0])
key = jax.random.key(0)
shuffled = shuffle_tokens(tokens, key, shuffle_prob=0.5)
print("Original: ", tokens)
print("Shuffled: ", shuffled)
Complete Augmentation Pipeline¤
@jax.jit
def augment_text(tokens: jax.Array, key, vocab_size: int = 10000):
    """Apply complete text augmentation.
    Args:
        tokens: Token sequence
        key: Random key
        vocab_size: Vocabulary size
    Returns:
        Augmented tokens
    """
    keys = jax.random.split(key, 3)
    # Token masking (15%)
    tokens, _ = mask_tokens(tokens, keys[0], mask_prob=0.15)
    # Token replacement (5%)
    tokens = replace_tokens(tokens, keys[1], replace_prob=0.05, vocab_size=vocab_size)
    # Note: Shuffling typically not used with masking
    # tokens = shuffle_tokens(tokens, keys[2], shuffle_prob=0.05)
    return tokens

# Batch augmentation
def augment_text_batch(token_batch: jax.Array, key, vocab_size: int = 10000):
    """Augment batch of text sequences.
    Args:
        token_batch: Batch of token sequences (N, max_length)
        key: Random key
        vocab_size: Vocabulary size
    Returns:
        Augmented batch
    """
    batch_size = token_batch.shape[0]
    keys = jax.random.split(key, batch_size)
    # Vectorize over batch
    augmented = jax.vmap(lambda t, k: augment_text(t, k, vocab_size))(
        token_batch, keys
    )
    return augmented

# Usage in training
key = jax.random.key(0)
for batch in data_loader:
    key, subkey = jax.random.split(key)
    augmented_tokens = augment_text_batch(
        batch["text_tokens"],
        subkey,
        vocab_size=10000
    )
    # Use augmented_tokens for training
Vocabulary Statistics¤
Computing Statistics¤
def compute_vocab_stats(dataset, vocab_size: int, pad_token_id: int = 0):
    """Compute vocabulary statistics for dataset.
    Args:
        dataset: Text dataset yielding samples with a "text_tokens" entry
        vocab_size: Vocabulary size used to build the dataset
        pad_token_id: Padding token ID
    Returns:
        Dictionary of statistics
    """
    all_tokens = set()
    sequence_lengths = []
    token_frequencies = {}
    for sample in dataset:
        # Convert once to a Python list instead of indexing the array per token
        token_list = sample["text_tokens"].tolist()
        # Collect unique tokens
        all_tokens.update(token_list)
        # Sequence length (excluding padding)
        real_tokens = [token for token in token_list if token != pad_token_id]
        sequence_lengths.append(len(real_tokens))
        # Token frequencies
        for token in real_tokens:
            token_frequencies[token] = token_frequencies.get(token, 0) + 1
    return {
        "unique_tokens": len(all_tokens),
        # vocab_size is passed in; the dataset object is not assumed to expose it
        "vocab_coverage": len(all_tokens) / vocab_size,
        "avg_sequence_length": sum(sequence_lengths) / len(sequence_lengths),
        "max_sequence_length": max(sequence_lengths),
        "min_sequence_length": min(sequence_lengths),
        "total_tokens": sum(token_frequencies.values()),
        "most_common": sorted(token_frequencies.items(), key=lambda x: x[1], reverse=True)[:10]
    }

# Usage
stats = compute_vocab_stats(text_dataset, vocab_size=10000)
print(f"Unique tokens: {stats['unique_tokens']}")
print(f"Vocab coverage: {stats['vocab_coverage']:.2%}")
print(f"Avg sequence length: {stats['avg_sequence_length']:.1f}")
print(f"Most common tokens: {stats['most_common']}")
Complete Examples¤
Example 1: Text Generation Dataset¤
import jax
import jax.numpy as jnp
from flax import nnx
from artifex.generative_models.modalities.text.datasets import create_text_dataset
# Setup
rngs = nnx.Rngs(0)
# Create datasets via factory
train_dataset = create_text_dataset(
"synthetic",
rngs=rngs,
dataset_size=100000,
pattern_type="random_sentences",
vocab_size=50000,
max_length=256,
)
val_dataset = create_text_dataset(
"synthetic",
rngs=rngs,
dataset_size=10000,
pattern_type="random_sentences",
vocab_size=50000,
max_length=256,
)
# Training loop (augment_text_batch is the helper defined under Text Augmentation)
key = jax.random.key(42)
for epoch in range(10):
    for batch in train_dataset:
        tokens = batch["text_tokens"]
        # Apply augmentation during training
        key, subkey = jax.random.split(key)
        # tokens[None] adds a batch axis because iteration yields single samples
        augmented_tokens = augment_text_batch(tokens[None], subkey, vocab_size=50000)
        # Training step
        # loss = train_step(model, augmented_tokens)
    # Validation (no augmentation)
    for val_batch in val_dataset:
        pass
        # val_loss = validate_step(model, val_batch["text_tokens"])
    print(f"Epoch {epoch + 1}/10 complete")
Example 2: Custom Text Dataset¤
import jax.numpy as jnp
from datarax.sources import MemorySource, MemorySourceConfig
from flax import nnx
from artifex.generative_models.modalities.text.datasets import simple_tokenize
# Load texts from file
# In practice: texts = Path("data/texts.txt").read_text().splitlines()
texts = ["Sample text 1", "Sample text 2", "Sample text 3"]
# Tokenize all texts
vocab_size = 10000
max_length = 512
token_arrays = []
for text in texts:
    tokens = simple_tokenize(text, vocab_size=vocab_size, max_length=max_length)
    token_arrays.append(tokens)
# Wrap in MemorySource
data = {
"text_tokens": jnp.stack(token_arrays),
"index": jnp.arange(len(texts)),
}
config = MemorySourceConfig(shuffle=True)
dataset = MemorySource(config, data, rngs=nnx.Rngs(0))
# Usage
batch = dataset.get_batch(2)
print(batch["text_tokens"].shape) # (2, 512)
Best Practices¤
DO¤
Tokenization
- Use consistent tokenization across train/val/test splits
- Handle special tokens properly (BOS, EOS, PAD, UNK)
- Choose appropriate vocabulary size for your task
- Preserve case if semantically important
- Validate tokenized sequences
- Cache tokenized data when possible
Sequence Handling
- Pad sequences to consistent length for batching
- Use attention masks to handle padding
- Truncate long sequences appropriately
- Add BOS/EOS tokens for generation tasks
- Handle variable-length sequences efficiently
Augmentation
- Apply augmentation only during training
- Don't mask special tokens
- Balance augmentation strength
- Use JIT compilation for speed
- Validate augmented sequences
DON'T¤
Common Mistakes
- Mix different tokenization schemes
- Forget to add special tokens
- Ignore padding in loss computation
- Apply augmentation during validation
- Use case-sensitive tokenization when case carries no meaning
- Exceed vocabulary size with token IDs
Performance Issues
- Tokenize on-the-fly during training
- Use Python loops for token processing
- Load entire corpus into memory
- Recompute masks every forward pass
Data Quality
- Skip sequence validation
- Mix different sequence lengths without padding
- Use inconsistent special token IDs
- Ignore out-of-vocabulary tokens
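One of the mistakes listed above, ignoring padding in the loss computation, is easy to make because a plain mean over all positions looks correct but lets padded positions dilute or distort the gradient. The following is a minimal sketch of a padding-aware cross-entropy in JAX. `masked_cross_entropy` is a hypothetical helper, not an Artifex API.

```python
import jax
import jax.numpy as jnp


def masked_cross_entropy(logits, targets, pad_token_id: int = 0):
    """Mean negative log-likelihood over non-padding target positions.

    logits: (batch, seq, vocab); targets: (batch, seq) integer token IDs.
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick the log-probability assigned to each target token
    token_nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    # 1.0 for real tokens, 0.0 for padding
    mask = (targets != pad_token_id).astype(log_probs.dtype)
    # Normalize by the number of real tokens; guard against an all-padding batch
    return jnp.sum(token_nll * mask) / jnp.maximum(jnp.sum(mask), 1.0)


logits = jnp.zeros((1, 4, 5))          # uniform prediction over 5 tokens
targets = jnp.array([[4, 3, 0, 0]])    # last two positions are padding
loss = masked_cross_entropy(logits, targets)
print(float(loss))  # ln(5), about 1.609
```

Changing the logits at the padded positions leaves this loss unchanged, which is the property a padding-aware loss should have.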
Summary¤
This guide covered:
- Text configuration - Vocabulary, sequence length, special tokens
- Text datasets - Synthetic and custom text datasets
- Tokenization - Token mapping, padding, truncation
- Preprocessing - Attention masks, batch padding
- Positional embeddings - Learned, RoPE, and sinusoidal encoding methods
- Augmentation - Token masking, replacement, shuffling
- Vocabulary stats - Computing coverage and frequency
- Complete examples - Training pipelines and custom datasets
- Best practices - DOs and DON'Ts for text data
Next Steps¤
- Learn about audio processing, spectrograms, and audio augmentation
- Working with multiple modalities and aligned multi-modal datasets
- Deep dive into image datasets, preprocessing, and augmentation
- Complete API documentation for all dataset classes and functions