[Attention] FlashAttn MLA (#14258)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Lucas Wilkinson authored on 2025-09-04 05:47:59 -04:00, committed by GitHub
parent 2c301ee2eb, commit 402759d472
22 changed files with 480 additions and 200 deletions

View File

@@ -317,7 +317,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
- split_decodes_and_prefills(common_attn_metadata)
+ split_decodes_and_prefills(common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold)
page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
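The recurring change across these builders is that split_decodes_and_prefills now receives the builder's own reorder_batch_threshold instead of a hard-coded decode_threshold. A minimal sketch of how such a threshold-based split can be computed from cumulative query lengths (the helper name and logic below are illustrative assumptions, not the vLLM implementation):

import torch

def split_by_decode_threshold(query_start_loc: torch.Tensor,
                              decode_threshold: int = 1):
    # query_start_loc: cumulative query lengths, shape [num_reqs + 1].
    # Assumes the batch is already reordered so that every request with
    # query_len <= decode_threshold (a "decode") precedes the prefills.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    num_decodes = int((query_lens <= decode_threshold).sum())
    num_prefills = query_lens.numel() - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Three 1-token decode requests followed by one 7-token prefill -> (3, 1, 3, 7).
print(split_by_decode_threshold(torch.tensor([0, 1, 2, 3, 10])))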

View File

@@ -52,8 +52,9 @@ class LinearAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(common_attn_metadata,
- decode_threshold=1))
+ split_decodes_and_prefills(
+ common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold))
attn_metadata = LinearAttentionMetadata(
num_prefills=num_prefills,

View File

@@ -50,8 +50,9 @@ class Mamba1AttentionMetadataBuilder(
query_start_loc.device)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(common_attn_metadata,
- decode_threshold=1))
+ split_decodes_and_prefills(
+ common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold))
has_initial_states = None
padded_decodes = num_decodes

View File

@@ -115,8 +115,9 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(common_attn_metadata,
- decode_threshold=1))
+ split_decodes_and_prefills(
+ common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold))
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:

View File

@@ -578,11 +578,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks
- def _build_decode(self, block_table_tensor: torch.Tensor,
- seq_lens: torch.Tensor):
+ def _build_decode(
+ self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
+ seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
+ query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
- seq_lens=seq_lens,
+ seq_lens=seq_lens_device,
)
def build_for_cudagraph_capture(
@@ -618,6 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
+ seq_lens_cpu = common_attn_metadata.seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@@ -625,7 +628,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_seq_lens_cpu)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
- split_decodes_and_prefills(common_attn_metadata)
+ split_decodes_and_prefills(common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
@@ -725,7 +729,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
- seq_lens=seq_lens[:num_decodes],
+ seq_lens_cpu=seq_lens_cpu[:num_decodes],
+ seq_lens_device=seq_lens[:num_decodes],
+ query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
+ query_start_loc_device=query_start_loc[:num_decodes + 1],
)
attn_metadata = self.metadata_cls(
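The _build_decode hook now receives both CPU and device copies of seq_lens and query_start_loc, so subclasses can do scalar reductions on the host copies while handing the device copies to the kernels. A small sketch of why that split matters (variable names are illustrative, not vLLM code): reducing a CUDA tensor to a Python scalar forces a host-device synchronization, whereas the same reduction on the CPU copy stays on the host.

import torch

seq_lens_cpu = torch.tensor([5, 9, 3])                 # host-side copy
seq_lens_device = (seq_lens_cpu.to("cuda")
                   if torch.cuda.is_available() else seq_lens_cpu)

max_seq_len = seq_lens_cpu.max().item()  # pure host work, no GPU sync
# seq_lens_device is what the attention kernels actually consume.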

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import ClassVar, Optional

import torch

from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                           get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonDecodeMetadata,
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

logger = init_logger(__name__)


class FlashAttnMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_MLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
        return FlashAttnMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        return FlashAttnMLAImpl


@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    query_start_loc: torch.Tensor
    max_query_len: int
    max_seq_len: int
    scheduler_metadata: Optional[torch.Tensor] = None


@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    pass


class FlashAttnMLAMetadataBuilder(
        MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    reorder_batch_threshold: ClassVar[int] = 512

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashAttnMLAMetadata)
        self.fa_aot_schedule = (get_flash_attn_version() == 3)

    def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                         max_seq_len, causal):
        if self.fa_aot_schedule:
            return get_scheduler_metadata(
                batch_size=num_reqs,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads,
                num_heads_kv=1,
                headdim=self.mla_dims.qk_rope_head_dim,
                cache_seqlens=seqlens,
                qkv_dtype=self.kv_cache_spec.dtype,
                headdim_v=self.mla_dims.kv_lora_rank,
                page_size=self.page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
            )
        return None

    def _build_decode(
            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
            query_start_loc_device: torch.Tensor
    ) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_cpu.max().item()

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
        )

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
        )


class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert flash_attn_supports_mla(), \
            "FlashAttnMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttnMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

        o = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=attn_metadata.decode.max_query_len,
            cu_seqlens_q=attn_metadata.decode.query_start_loc,
            max_seqlen_k=attn_metadata.decode.max_seq_len,
            seqused_k=attn_metadata.decode.seq_lens,
            block_table=attn_metadata.decode.block_table,
            softmax_scale=self.scale,
            causal=True,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
        )

        return self._v_up_proj(o)
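A hypothetical usage sketch (the environment-variable override is an existing vLLM knob, but the model name and automatic backend selection here are assumptions): the backend can be forced by the name returned from get_name().

import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"  # name from get_name()

from vllm import LLM

# Any MLA-based model (e.g. the DeepSeek family) would exercise this backend
# on hardware where flash_attn_supports_mla() returns True.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")
print(llm.generate("Hello"))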

View File

@@ -85,11 +85,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device=self.device,
dtype=torch.int32)
- def _build_decode(self, block_table_tensor: torch.Tensor,
- seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
+ def _build_decode(
+ self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
+ seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
+ query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
- seq_lens,
+ seq_lens_device,
self.num_q_heads,
1, # MQA for the decode path
)
@@ -123,7 +125,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
- seq_lens=seq_lens,
+ seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
)

View File

@@ -104,12 +104,14 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
dtype=torch.int32,
device=device)
- def _build_decode(self, block_table_tensor: torch.Tensor,
- seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
+ def _build_decode(
+ self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
+ seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
+ query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
- block_table_bounds = (seq_lens + page_size - 1) // page_size
+ block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device
- num_reqs = seq_lens.size(0)
+ num_reqs = seq_lens_device.size(0)
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
@@ -117,7 +119,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
- paged_kv_last_page_len = seq_lens % page_size
+ paged_kv_last_page_len = seq_lens_device % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
@@ -156,7 +158,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
- seq_lens=seq_lens,
+ seq_lens=seq_lens_device,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
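A worked example of the last-page-length arithmetic above: seq_len % page_size gives the number of valid tokens in the final KV-cache page, and a remainder of 0 means that page is completely full, so it is remapped to page_size.

import torch

page_size = 16
seq_lens = torch.tensor([5, 48, 33])
last_page_len = seq_lens % page_size
last_page_len = torch.where(last_page_len == 0,
                            torch.full_like(last_page_len, page_size),
                            last_page_len)
print(last_page_len)  # tensor([ 5, 16,  1])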

View File

@@ -58,8 +58,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(common_attn_metadata,
- decode_threshold=1))
+ split_decodes_and_prefills(
+ common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold))
has_initial_states = None
if num_prefills > 0:
#[batch,]
@@ -78,4 +79,4 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
)
- return attn_metadata
+ return attn_metadata

View File

@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING, ClassVar, Optional
import torch
@@ -197,6 +197,8 @@ class XFormersAttentionMetadata:
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]):
+ reorder_batch_threshold: ClassVar[int] = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -212,9 +214,10 @@ class XFormersAttentionMetadataBuilder(
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
- return reorder_batch_to_split_decodes_and_prefills(input_batch,
- scheduler_output,
- decode_threshold=1)
+ return reorder_batch_to_split_decodes_and_prefills(
+ input_batch,
+ scheduler_output,
+ decode_threshold=self.reorder_batch_threshold)
def build(
self,
@@ -223,8 +226,9 @@ class XFormersAttentionMetadataBuilder(
fast_build: bool = False,
) -> XFormersAttentionMetadata:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
- split_decodes_and_prefills(common_attn_metadata,
- decode_threshold=1))
+ split_decodes_and_prefills(
+ common_attn_metadata,
+ decode_threshold=self.reorder_batch_threshold))
num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc