Support encoder_only attention for FlexAttention (#22273)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
committed by
GitHub
parent
41b67f4263
commit
f825c6bd22
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
|
||||
|
||||
@dataclass
|
||||
class FlexAttentionMetadata:
|
||||
causal: bool
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
|
||||
num_blocks = 0
|
||||
block_mask: Optional[BlockMask] = None
|
||||
score_mod: Optional[_score_mod_signature] = None
|
||||
mask_mod: Optional[_mask_mod_signature] = None
|
||||
logical_mask_mod: _mask_mod_signature = causal_mask_mod
|
||||
|
||||
def get_mask_mod(self) -> _mask_mod_signature:
|
||||
def get_causal_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the mask_mod function for FlexAttention.
|
||||
|
||||
This function creates the combined mask mod function that handles:
|
||||
@@ -233,14 +233,39 @@ class FlexAttentionMetadata:
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the encoder mask_mod function for FlexAttention.
|
||||
|
||||
Since the encoder bidirectional attention doesn't run with
|
||||
KV cache, this function creates a mask based on the
|
||||
packed query sequences.
|
||||
"""
|
||||
# Create a lookup mapping from query indices -> request number
|
||||
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
|
||||
def final_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return request_lookup[q_idx] == request_lookup[kv_idx]
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def build_block_mask(self) -> BlockMask:
|
||||
assert self.mask_mod is not None
|
||||
if self.causal:
|
||||
mask_mod = self.get_causal_mask_mod()
|
||||
kv_len = self.total_cache_tokens
|
||||
else:
|
||||
mask_mod = self.get_bidirectional_mask_mod()
|
||||
kv_len = self.num_actual_tokens
|
||||
return create_block_mask_compiled(
|
||||
self.mask_mod,
|
||||
mask_mod,
|
||||
None,
|
||||
None,
|
||||
self.num_actual_tokens,
|
||||
self.total_cache_tokens,
|
||||
kv_len,
|
||||
device=self.block_table.device,
|
||||
)
|
||||
|
||||
@@ -251,7 +276,6 @@ class FlexAttentionMetadata:
|
||||
assert self.prefix_kv_lens is None, "Not implemented yet."
|
||||
assert self.suffix_kv_lens is None, "Not implemented yet."
|
||||
self.num_blocks = self.total_cache_tokens // self.block_size
|
||||
self.mask_mod = self.get_mask_mod()
|
||||
self.block_mask = self.build_block_mask()
|
||||
|
||||
|
||||
@@ -306,6 +330,7 @@ class FlexAttentionMetadataBuilder(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
out = FlexAttentionMetadata(
|
||||
causal=common_attn_metadata.causal,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
@@ -350,6 +375,12 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.attn_type = attn_type
|
||||
|
||||
if attn_type not in (AttentionType.ENCODER_ONLY,
|
||||
AttentionType.DECODER):
|
||||
raise NotImplementedError(
|
||||
f"FlexAttention does not support {attn_type} attention")
|
||||
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError(
|
||||
@@ -425,26 +456,38 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if not attn_metadata.causal:
|
||||
assert self.attn_type == AttentionType.ENCODER_ONLY
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
query, key_tensor, value_tensor = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key, value),
|
||||
)
|
||||
|
||||
else:
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# View out the block_size dim
|
||||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(-1, self.num_kv_heads,
|
||||
self.head_size)
|
||||
query, key_tensor, value_tensor = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key_cache, value_cache),
|
||||
)
|
||||
|
||||
# View out the block_size dim
|
||||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
query, key_cache, value_cache = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key_cache, value_cache),
|
||||
)
|
||||
query = query[:, :, :num_actual_tokens, :]
|
||||
# Doesn't work for now -> constraint violation
|
||||
# torch._dynamo.try_mark_dynamic(query, 2)
|
||||
@@ -465,8 +508,8 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
|
||||
out = flex_attention_compiled(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
attn_metadata.score_mod,
|
||||
attn_metadata.block_mask,
|
||||
self.scale,
|
||||
|
||||
Reference in New Issue
Block a user