
Flash Attention

Module: generative_models.core.layers.flash_attention

Source: generative_models/core/layers/flash_attention.py

Overview

Flash-attention-style helpers for Flax NNX.

This page documents the single JAX fallback implementation that the repository retains. It does not describe backend switches, Triton-specific runtime guarantees, or any performance characteristics beyond those of the code that actually ships in this repository.
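For orientation, flash-attention-style kernels are generally built on the recurrence summarized below. This page does not say how closely the retained JAX fallback follows it. For one query row $q_i$ with head dimension $d$, the keys are processed in blocks $B_1$ to $B_T$. The kernel keeps a running maximum $m$, a running denominator $\ell$, and an unnormalized output accumulator $a$. Because of this, the full row of softmax probabilities never has to be stored:

```latex
s_{ij} = \frac{q_i^\top k_j}{\sqrt{d}}, \qquad
o_i = \sum_j \frac{e^{s_{ij}}}{\sum_{j'} e^{s_{ij'}}}\, v_j

m^{(0)} = -\infty, \quad \ell^{(0)} = 0, \quad a^{(0)} = 0

m^{(t)} = \max\!\Big(m^{(t-1)},\ \max_{j \in B_t} s_{ij}\Big)

\ell^{(t)} = e^{m^{(t-1)} - m^{(t)}}\, \ell^{(t-1)} + \sum_{j \in B_t} e^{s_{ij} - m^{(t)}}

a^{(t)} = e^{m^{(t-1)} - m^{(t)}}\, a^{(t-1)} + \sum_{j \in B_t} e^{s_{ij} - m^{(t)}}\, v_j

o_i = a^{(T)} / \ell^{(T)}
```

After the last block, $a^{(T)} = \sum_j e^{s_{ij} - m^{(T)}} v_j$ and $\ell^{(T)} = \sum_j e^{s_{ij} - m^{(T)}}$. The common factor $e^{-m^{(T)}}$ cancels in the ratio, so the result equals ordinary softmax attention.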

Based on:

Classes

AttentionMask

class AttentionMask

FlashAttentionConfig

class FlashAttentionConfig

FlashMultiHeadAttention

class FlashMultiHeadAttention
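This page lists FlashMultiHeadAttention without details. The sketch below is a hedged illustration of the shape flow that a multi-head wrapper typically follows:

1. Project the inputs.
2. Split them into heads.
3. Run a single-head kernel on every head with `jax.vmap`.
4. Merge the heads and apply the output projection.

The sketch uses plain JAX arrays in place of Flax NNX parameters. The names `split_heads`, `merge_heads`, `dense_attention`, and `multi_head_attention` are hypothetical and are not this module's API.

```python
import jax
import jax.numpy as jnp


def split_heads(x, num_heads):
    # [L, H * Dh] -> [H, L, Dh]
    length, width = x.shape
    return x.reshape(length, num_heads, width // num_heads).transpose(1, 0, 2)


def merge_heads(x):
    # [H, L, Dh] -> [L, H * Dh]
    h, length, dh = x.shape
    return x.transpose(1, 0, 2).reshape(length, h * dh)


def dense_attention(q, k, v):
    # Plain softmax attention for one head: q [Lq, Dh], k and v [Lk, Dh]
    s = (q @ k.T) * q.shape[-1] ** -0.5
    return jax.nn.softmax(s, axis=-1) @ v


def multi_head_attention(x, wq, wk, wv, wo, num_heads):
    # Hypothetical self-attention forward pass: x [L, D], every weight [D, D]
    q, k, v = (split_heads(x @ w, num_heads) for w in (wq, wk, wv))
    heads = jax.vmap(dense_attention)(q, k, v)  # map the per-head kernel over H
    return merge_heads(heads) @ wo
```

Keeping the per-head kernel separate means it can be replaced by any single-head kernel with the same `(q, k, v)` signature without touching the projection and reshaping code.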

Functions

__call__

def __call__()

__init__

def __init__()

create_attention_mask

def create_attention_mask()
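This page does not show the arguments of `create_attention_mask` or the layout of `AttentionMask`. The sketch below is a hedged example of the kind of mask such a helper commonly builds. It returns a boolean `[q_len, kv_len]` array in which `True` means the query may attend to the key. It combines an optional causal rule with optional key padding. The function name `make_causal_padding_mask`, its parameters, and the `True`-means-attend convention are all assumptions.

```python
import jax.numpy as jnp


def make_causal_padding_mask(q_len, kv_len, kv_valid_len=None, causal=True):
    """Hypothetical helper: bool mask [q_len, kv_len], True = may attend.

    Query i is assumed to sit at absolute position kv_len - q_len + i, so the
    causal rule still holds when the queries are the tail of a longer key
    sequence (for example, when decoding against a cache).
    """
    q_pos = jnp.arange(q_len)[:, None] + (kv_len - q_len)  # [q_len, 1]
    k_pos = jnp.arange(kv_len)[None, :]                     # [1, kv_len]
    mask = jnp.ones((q_len, kv_len), dtype=bool)
    if causal:
        mask = mask & (k_pos <= q_pos)  # no attending to future positions
    if kv_valid_len is not None:
        mask = mask & (k_pos < kv_valid_len)  # hide padded key slots
    return mask
```

With `q_len == kv_len` and `causal=True`, the result is the usual lower-triangular mask.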

flash_attention

def flash_attention()
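This page does not show the signature of `flash_attention`. The hypothetical `blockwise_attention` below is a hedged, plain-JAX sketch of the flash-attention-style technique, under these assumptions:

* Inputs are single-head, with shape `[L, D]`.
* The key length is a multiple of the block size.
* The mask is boolean, and `True` means "may attend".

The function uses `jax.lax.scan` to step over key blocks while carrying a running max, denominator, and output accumulator. As a result, it only ever holds one `[Lq, block_k]` score tile rather than all attention probabilities at once. For simplicity, the mask is passed as a dense array. This is not the repository's implementation, which may handle blocking, masking, and batching differently.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, block_k=128, mask=None):
    """Hypothetical sketch: q [Lq, D], k [Lk, D], v [Lk, Dv], mask bool [Lq, Lk]."""
    lq, d = q.shape
    lk, dv = k.shape[0], v.shape[-1]
    assert lk % block_k == 0, "sketch assumes Lk is a multiple of block_k"
    scale = d ** -0.5
    if mask is None:
        mask = jnp.ones((lq, lk), dtype=bool)
    n_blocks = lk // block_k
    k_blocks = k.reshape(n_blocks, block_k, d)
    v_blocks = v.reshape(n_blocks, block_k, dv)
    mask_blocks = mask.reshape(lq, n_blocks, block_k).transpose(1, 0, 2)

    def step(carry, block):
        m_prev, l_prev, acc = carry
        kb, vb, mb = block
        s = jnp.where(mb, (q @ kb.T) * scale, -jnp.inf)  # [Lq, block_k] tile
        m_new = jnp.maximum(m_prev, s.max(axis=-1))
        # Rows with no visible key so far keep m = -inf; shift by 0 instead
        # so that exp(-inf - (-inf)) never produces NaN.
        m_safe = jnp.where(jnp.isfinite(m_new), m_new, 0.0)
        p = jnp.exp(s - m_safe[:, None])
        correction = jnp.exp(m_prev - m_safe)  # rescales earlier partial sums
        l_new = correction * l_prev + p.sum(axis=-1)
        acc_new = correction[:, None] * acc + p @ vb
        return (m_new, l_new, acc_new), None

    init = (
        jnp.full((lq,), -jnp.inf, q.dtype),
        jnp.zeros((lq,), q.dtype),
        jnp.zeros((lq, dv), q.dtype),
    )
    (_, l, acc), _ = jax.lax.scan(step, init, (k_blocks, v_blocks, mask_blocks))
    # Fully masked rows have l == 0 and acc == 0; return zeros for them.
    return acc / jnp.where(l == 0, 1.0, l)[:, None]
```

The result matches `jax.nn.softmax(q @ k.T / sqrt(D), axis=-1) @ v` up to floating-point rounding. Returning zeros for rows with no visible keys is a choice made only in this sketch.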

init_cache

def init_cache()
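This page lists `init_cache` without describing the cache it creates. The sketch below assumes the common design used for autoregressive decoding. Key and value buffers are preallocated at a fixed maximum length, and new entries are written in place with `jax.lax.dynamic_update_slice`, so array shapes stay static under `jax.jit`. The dict layout, the `[heads, max_len, head_dim]` buffer shape, and the names `init_kv_cache` and `append_kv` are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_kv_cache(max_len, num_heads, head_dim, dtype=jnp.float32):
    # Hypothetical layout: zero-filled key/value buffers plus a write index
    shape = (num_heads, max_len, head_dim)
    return {
        "k": jnp.zeros(shape, dtype),
        "v": jnp.zeros(shape, dtype),
        "index": jnp.array(0, jnp.int32),
    }


def append_kv(cache, k_new, v_new):
    # k_new and v_new: [H, T, Dh], written starting at the current index.
    # The caller must keep index + T <= max_len: dynamic_update_slice clamps
    # out-of-range start indices instead of raising an error.
    i = cache["index"]
    zero = jnp.zeros((), i.dtype)
    k = jax.lax.dynamic_update_slice(cache["k"], k_new, (zero, i, zero))
    v = jax.lax.dynamic_update_slice(cache["v"], v_new, (zero, i, zero))
    return {"k": k, "v": v, "index": i + k_new.shape[1]}
```

Slots at and beyond `index` still hold zeros. Attention over the cache therefore needs a key-padding mask whose valid length is `index`.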

Module Statistics

  • Classes: 3
  • Functions: 5