[FlexAttention] allow custom mask mod (#37692)

Signed-off-by: Angel Li <liangel@meta.com>
Author: liangel-02
Date: 2026-03-24 16:03:24 -04:00
Committed by: GitHub
parent 54b0578ada
commit 8c47fdfdb1
2 changed files with 121 additions and 16 deletions


@@ -3,9 +3,10 @@
"""Attention layer with FlexAttention."""
import math
from collections.abc import Callable
from dataclasses import dataclass
from functools import cached_property
-from typing import ClassVar
+from typing import ClassVar, NamedTuple
import torch
import torch._dynamo.decorators
@@ -294,6 +295,27 @@ def causal_mask_mod(
return q_idx >= kv_idx
+# Type alias for the block sparsity hint callable signature.
+_block_sparsity_hint_signature = Callable[
+[torch.Tensor, torch.Tensor, int], torch.Tensor
+]
+class BlockSparsityHint(NamedTuple):
+"""Hint used to prune KV blocks from the BlockMask before the flex_attention
+kernel is invoked, so that fully masked blocks are never loaded.
+Use this with sparse custom mask_mods so the kernel does not
+iterate over all KV blocks unnecessarily.
+Attributes:
+hint_fn: (q_block_idx [num_tokens, 1], kv_block_idx [1, num_kv_blocks],
+block_size int) -> bool Tensor [num_tokens, num_kv_blocks].
+Returns True for block pairs that may contain non-masked elements.
+"""
+hint_fn: _block_sparsity_hint_signature
@dataclass
class FlexAttentionMetadata:
causal: bool
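For illustration, a custom hint matching this signature could look like the sketch below. The "sink + local window" pattern, the make_local_hint helper, and the window_blocks parameter are hypothetical and not part of this commit; BlockSparsityHint is assumed importable from this backend module.

import torch

def make_local_hint(window_blocks: int):
    # Hypothetical "sink + local window" hint: only KV block 0 and KV blocks
    # within window_blocks of the query's block may hold unmasked elements.
    def hint_fn(
        q_block_idx: torch.Tensor,   # [num_tokens, 1]
        kv_block_idx: torch.Tensor,  # [1, num_kv_blocks]
        block_size: int,
    ) -> torch.Tensor:               # bool, [num_tokens, num_kv_blocks]
        is_sink = kv_block_idx == 0
        is_local = (q_block_idx - kv_block_idx).abs() <= window_blocks
        return is_sink | is_local
    return hint_fn

hint = BlockSparsityHint(hint_fn=make_local_hint(window_blocks=2))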
@@ -335,6 +357,7 @@ class FlexAttentionMetadata:
transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
+block_sparsity_hint: BlockSparsityHint | None = None
@cached_property
def logical_block_ids(self):
@@ -378,7 +401,7 @@ class FlexAttentionMetadata:
return is_valid, logical_q_idx, logical_kv_idx
-def get_causal_mask_mod(self) -> _mask_mod_signature:
+def get_paged_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles:
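Roughly, the paged mask mod translates physical (paged) KV indices back to logical ones and then defers to the logical mask. A minimal sketch of that composition, assuming a to_logical helper with the (is_valid, logical_q_idx, logical_kv_idx) return shape seen above; the helper name and wrapper are illustrative, not the actual method body:

def paged_mask_mod_sketch(to_logical, logical_mask_mod):
    # Illustrative composition only; the real get_paged_mask_mod may differ.
    def mask_mod(b, h, q_idx, physical_kv_idx):
        # Map physical page-based indices back to logical token indices.
        is_valid, logical_q_idx, logical_kv_idx = to_logical(b, q_idx, physical_kv_idx)
        # Out-of-range physical slots stay masked; valid ones defer to the
        # logical mask (causal_mask_mod or a custom mask_mod).
        return is_valid & logical_mask_mod(b, h, logical_q_idx, logical_kv_idx)
    return mask_mod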
@@ -504,8 +527,9 @@ class FlexAttentionMetadata:
def get_mask_mod(self):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
-if self.causal:
-mask_mod = self.get_causal_mask_mod()
+has_custom_mask = self.logical_mask_mod is not causal_mask_mod
+if self.causal or has_custom_mask:
+mask_mod = self.get_paged_mask_mod()
else:
mask_mod = self.get_bidirectional_mask_mod()
# stage-2: add external mask_mod for special attention during
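As a concrete illustration of a custom logical mask_mod that now takes the paged path via the has_custom_mask check above, here is a prefix-LM style mask with the standard FlexAttention (b, h, q_idx, kv_idx) signature. The function name and PREFIX_LEN constant are examples only, not part of this commit.

PREFIX_LEN = 16  # example prefix length, for illustration only

def prefix_causal_mask_mod(b, h, q_idx, kv_idx):
    # Causal everywhere, plus full (bidirectional) visibility of the first
    # PREFIX_LEN tokens for every query.
    return (q_idx >= kv_idx) | (kv_idx < PREFIX_LEN)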
@@ -591,7 +615,9 @@ class FlexAttentionMetadata:
self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
]
-if self.sliding_window and self.causal:
+custom_hint = self.block_sparsity_hint is not None
+if self.sliding_window or custom_hint:
device = used_pages.device
assert self.doc_ids is not None
token_indices = torch.arange(
@@ -602,10 +628,24 @@ class FlexAttentionMetadata:
- self.query_start_loc[self.doc_ids]
+ self.decode_offset[self.doc_ids]
)
-min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
-min_block_idx = min_kv_idx // self.block_size
-sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
-used_pages.masked_fill_(~sliding_mask, 0)
+if self.sliding_window:
+assert self.sliding_window is not None
+min_kv_idx = torch.clamp(
+logical_q_idx - (self.sliding_window - 1), min=0
+)
+min_block_idx = min_kv_idx // self.block_size
+sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
+used_pages.masked_fill_(~sliding_mask, 0)
+if custom_hint:
+assert self.block_sparsity_hint is not None
+q_block_idx = logical_q_idx // self.block_size
+hint_mask = self.block_sparsity_hint.hint_fn(
+q_block_idx[:, None],
+self.logical_block_ids[None, :],
+self.block_size,
+)
+used_pages.masked_fill_(~hint_mask, 0)
used_pages_padded = pad_to_multiple(
used_pages, multiple=self.q_block_size, dim=0
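To see the sliding-window pruning arithmetic with concrete, made-up numbers (a simplified 1-D illustration, not the exact tensor shapes used here): with block_size=16 and sliding_window=16, a query at logical index 40 can only attend back to logical index 25, so logical KV block 0 is dropped from used_pages.

import torch

block_size = 16
sliding_window = 16
logical_q_idx = torch.tensor([40])        # one query token
logical_block_ids = torch.arange(4)       # logical KV blocks 0..3

min_kv_idx = torch.clamp(logical_q_idx - (sliding_window - 1), min=0)  # 25
min_block_idx = min_kv_idx // block_size                               # 1
sliding_mask = logical_block_ids >= min_block_idx[:, None]
print(sliding_mask)  # tensor([[False,  True,  True,  True]])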
@@ -660,11 +700,6 @@ class FlexAttentionMetadata:
self.mask_mod = self.get_mask_mod()
self.transformed_score_mod = self.get_transformed_score_mod()
-if self.direct_build and self.causal:
-self.block_mask = self._build_block_mask_direct()
-else:
-self.block_mask = self.build_block_mask()
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
def __init__(
@@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl):
alibi_slopes: torch.Tensor | None
logits_soft_cap: float | None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
+logical_mask_mod: _mask_mod_signature | None = None
+block_sparsity_hint: BlockSparsityHint | None = None
def __init__(
self,
@@ -907,8 +944,25 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
-if needs_rebuild_block_mask:
-if attn_metadata.direct_build and attn_metadata.causal:
+layer_mask_mod = getattr(layer, "logical_mask_mod", None)
+if (
+layer_mask_mod is not None
+and attn_metadata.logical_mask_mod is not layer_mask_mod
+):
+attn_metadata.logical_mask_mod = layer_mask_mod
+attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+needs_rebuild_block_mask = True
+layer_hint = getattr(layer, "block_sparsity_hint", None)
+if (
+layer_hint is not None
+and attn_metadata.block_sparsity_hint is not layer_hint
+):
+attn_metadata.block_sparsity_hint = layer_hint
+needs_rebuild_block_mask = True
+if needs_rebuild_block_mask or attn_metadata.block_mask is None:
+if attn_metadata.direct_build:
attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
else:
attn_metadata.block_mask = attn_metadata.build_block_mask()
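With this in place, a model can opt in per attention layer by exposing the two attributes that the getattr lookups above read from the layer object. A hedged usage sketch; my_mask_mod and my_hint_fn stand in for user-defined callables and are not part of this commit.

# Hypothetical opt-in from model code: attach a custom logical mask_mod and
# an optional block sparsity hint to the attention layer that is passed to
# FlexAttentionImpl. Attribute names match the getattr lookups above.
attn_layer.logical_mask_mod = my_mask_mod
attn_layer.block_sparsity_hint = BlockSparsityHint(hint_fn=my_hint_fn)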