[FlexAttention] allow custom mask mod (#37692)
Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
@@ -3,9 +3,10 @@
|
||||
"""Attention layer with FlexAttention."""
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._dynamo.decorators
|
||||
@@ -294,6 +295,27 @@ def causal_mask_mod(
|
||||
return q_idx >= kv_idx
|
||||
|
||||
|
||||
# Type alias for the block sparsity hint callable signature.
|
||||
_block_sparsity_hint_signature = Callable[
|
||||
[torch.Tensor, torch.Tensor, int], torch.Tensor
|
||||
]
|
||||
|
||||
|
||||
class BlockSparsityHint(NamedTuple):
|
||||
"""This prunes KV blocks from the BlockMask before the flex_attention kernel
|
||||
is invoked, so that blocks that are fully masked never get loaded.
|
||||
Use this with custom mask_mods that are sparse to avoid
|
||||
the kernel iterating over all KV blocks unnecessarily.
|
||||
|
||||
Attributes:
|
||||
hint_fn: (q_block_idx [num_tokens, 1], kv_block_idx [1, num_kv_blocks],
|
||||
block_size int) -> bool Tensor [num_tokens, num_kv_blocks].
|
||||
Returns True for block pairs that may contain non-masked elements.
|
||||
"""
|
||||
|
||||
hint_fn: _block_sparsity_hint_signature
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlexAttentionMetadata:
|
||||
causal: bool
|
||||
@@ -335,6 +357,7 @@ class FlexAttentionMetadata:
|
||||
transformed_score_mod: _score_mod_signature | None = None
|
||||
sliding_window: int | None = None
|
||||
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||
block_sparsity_hint: BlockSparsityHint | None = None
|
||||
|
||||
@cached_property
|
||||
def logical_block_ids(self):
|
||||
@@ -378,7 +401,7 @@ class FlexAttentionMetadata:
|
||||
|
||||
return is_valid, logical_q_idx, logical_kv_idx
|
||||
|
||||
def get_causal_mask_mod(self) -> _mask_mod_signature:
|
||||
def get_paged_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the mask_mod function for FlexAttention.
|
||||
|
||||
This function creates the combined mask mod function that handles:
|
||||
@@ -504,8 +527,9 @@ class FlexAttentionMetadata:
|
||||
def get_mask_mod(self):
|
||||
# Stage-1: initialize the base mask_mod
|
||||
# (causal mask for decoder or bidirectional mask for encoder)
|
||||
if self.causal:
|
||||
mask_mod = self.get_causal_mask_mod()
|
||||
has_custom_mask = self.logical_mask_mod is not causal_mask_mod
|
||||
if self.causal or has_custom_mask:
|
||||
mask_mod = self.get_paged_mask_mod()
|
||||
else:
|
||||
mask_mod = self.get_bidirectional_mask_mod()
|
||||
# stage-2: add external mask_mod for special attention during
|
||||
@@ -591,7 +615,9 @@ class FlexAttentionMetadata:
|
||||
self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
|
||||
]
|
||||
|
||||
if self.sliding_window and self.causal:
|
||||
custom_hint = self.block_sparsity_hint is not None
|
||||
|
||||
if self.sliding_window or custom_hint:
|
||||
device = used_pages.device
|
||||
assert self.doc_ids is not None
|
||||
token_indices = torch.arange(
|
||||
@@ -602,10 +628,24 @@ class FlexAttentionMetadata:
|
||||
- self.query_start_loc[self.doc_ids]
|
||||
+ self.decode_offset[self.doc_ids]
|
||||
)
|
||||
min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
|
||||
min_block_idx = min_kv_idx // self.block_size
|
||||
sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
|
||||
used_pages.masked_fill_(~sliding_mask, 0)
|
||||
|
||||
if self.sliding_window:
|
||||
assert self.sliding_window is not None
|
||||
min_kv_idx = torch.clamp(
|
||||
logical_q_idx - (self.sliding_window - 1), min=0
|
||||
)
|
||||
min_block_idx = min_kv_idx // self.block_size
|
||||
sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
|
||||
used_pages.masked_fill_(~sliding_mask, 0)
|
||||
if custom_hint:
|
||||
assert self.block_sparsity_hint is not None
|
||||
q_block_idx = logical_q_idx // self.block_size
|
||||
hint_mask = self.block_sparsity_hint.hint_fn(
|
||||
q_block_idx[:, None],
|
||||
self.logical_block_ids[None, :],
|
||||
self.block_size,
|
||||
)
|
||||
used_pages.masked_fill_(~hint_mask, 0)
|
||||
|
||||
used_pages_padded = pad_to_multiple(
|
||||
used_pages, multiple=self.q_block_size, dim=0
|
||||
@@ -660,11 +700,6 @@ class FlexAttentionMetadata:
|
||||
self.mask_mod = self.get_mask_mod()
|
||||
self.transformed_score_mod = self.get_transformed_score_mod()
|
||||
|
||||
if self.direct_build and self.causal:
|
||||
self.block_mask = self._build_block_mask_direct()
|
||||
else:
|
||||
self.block_mask = self.build_block_mask()
|
||||
|
||||
|
||||
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
|
||||
def __init__(
|
||||
@@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
alibi_slopes: torch.Tensor | None
|
||||
logits_soft_cap: float | None
|
||||
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||
logical_mask_mod: _mask_mod_signature | None = None
|
||||
block_sparsity_hint: BlockSparsityHint | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -907,8 +944,25 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
|
||||
needs_rebuild_block_mask = True
|
||||
|
||||
if needs_rebuild_block_mask:
|
||||
if attn_metadata.direct_build and attn_metadata.causal:
|
||||
layer_mask_mod = getattr(layer, "logical_mask_mod", None)
|
||||
if (
|
||||
layer_mask_mod is not None
|
||||
and attn_metadata.logical_mask_mod is not layer_mask_mod
|
||||
):
|
||||
attn_metadata.logical_mask_mod = layer_mask_mod
|
||||
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
|
||||
needs_rebuild_block_mask = True
|
||||
|
||||
layer_hint = getattr(layer, "block_sparsity_hint", None)
|
||||
if (
|
||||
layer_hint is not None
|
||||
and attn_metadata.block_sparsity_hint is not layer_hint
|
||||
):
|
||||
attn_metadata.block_sparsity_hint = layer_hint
|
||||
needs_rebuild_block_mask = True
|
||||
|
||||
if needs_rebuild_block_mask or attn_metadata.block_mask is None:
|
||||
if attn_metadata.direct_build:
|
||||
attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
|
||||
else:
|
||||
attn_metadata.block_mask = attn_metadata.build_block_mask()
|
||||
|
||||
Reference in New Issue
Block a user