[Attention] Move MLA forward from backend to layer (#33284)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni
2026-01-30 22:30:00 -05:00
committed by GitHub
parent 010ec0c30e
commit aaa901ad55
13 changed files with 753 additions and 535 deletions

View File

@@ -67,7 +67,7 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def get_impl_cls() -> type["AttentionImpl"]:
def get_impl_cls() -> type["AttentionImplBase"]:
raise NotImplementedError
@staticmethod
@@ -594,7 +594,14 @@ class AttentionLayer(Protocol):
) -> torch.Tensor: ...
class AttentionImpl(ABC, Generic[T]):
class AttentionImplBase(ABC, Generic[T]):
"""Base class for attention implementations.
Contains common attributes and initialization logic shared by both
standard AttentionImpl and MLAAttentionImpl. Does not define a forward
method; subclasses define their own forward interfaces.
"""
# Required attributes that all impls should have
num_heads: int
head_size: int
@@ -662,6 +669,13 @@ class AttentionImpl(ABC, Generic[T]):
)
return self
def process_weights_after_loading(self, act_dtype: torch.dtype):
pass
class AttentionImpl(AttentionImplBase[T], Generic[T]):
"""Standard attention implementation with forward method."""
@abstractmethod
def __init__(
self,
@@ -704,11 +718,10 @@ class AttentionImpl(ABC, Generic[T]):
"""
return False
def process_weights_after_loading(self, act_dtype: torch.dtype):
pass
class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""MLA attention implementation with forward_mqa and forward_mha methods."""
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
def __init__(
self,
@@ -731,22 +744,78 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
v_head_dim: int,
kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
q_pad_num_heads: int | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
def forward_mha(
self,
layer: AttentionLayer,
hidden_states_or_cq: torch.Tensor,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
k_scale: torch.Tensor,
output: torch.Tensor,
) -> None:
"""MHA-style prefill forward pass."""
raise NotImplementedError
@abstractmethod
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""MQA-style decode forward pass."""
raise NotImplementedError
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
"""Sparse MLA attention implementation with only forward_mqa method.
Sparse MLA implementations only support decode (MQA-style) attention.
They do not support prefill (MHA-style) attention.
"""
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None,
attn_type: str,
kv_sharing_target_layer_name: str | None,
# MLA Specific Arguments
q_lora_rank: int | None,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
q_pad_num_heads: int | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""MQA-style decode forward pass."""
raise NotImplementedError
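Taken together, the hunks above split the old base class into AttentionImplBase (shared attributes and the weight-loading hook), a standard AttentionImpl that keeps forward, an MLAAttentionImpl exposing forward_mha (prefill) and forward_mqa (decode), and a SparseMLAAttentionImpl that exposes forward_mqa only. Below is a minimal sketch of a backend targeting the new MLA interface; the class name and kernel bodies are placeholders, and the import path is an assumption based on the vllm.v1.attention.backend imports that appear later in this commit, not code from the diff.

# Illustrative sketch only (not part of this diff). The MLA-specific __init__
# required by the interface is omitted for brevity.
import torch

from vllm.v1.attention.backend import AttentionLayer, MLAAttentionImpl


class ToyMLAImpl(MLAAttentionImpl):
    def forward_mha(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata,
        k_scale: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        # MHA-style prefill: run attention over the decompressed KV and write
        # the result into the preallocated `output` buffer.
        output.zero_()  # placeholder for the real prefill kernel

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # MQA-style decode: q may arrive pre-split as (ql_nope, q_pe).
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)
        attn_out = torch.zeros_like(q)  # placeholder for the real decode kernel
        return attn_out, None  # second element is the optional LSE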

View File

@@ -244,7 +244,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
return out, lse
def _forward_decode(
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,

View File

@@ -293,7 +293,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
)
def _forward_decode(
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,

View File

@@ -150,7 +150,7 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None
def _forward_decode(
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,

View File

@@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
def _forward_decode(
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,

View File

@@ -11,7 +11,6 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims,
)
from vllm.platforms import current_platform
@@ -25,6 +24,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
@@ -686,7 +686,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
return metadata
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
@staticmethod
def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
# FP8 decode kernel only supports h_q = 64 or 128
@@ -710,19 +710,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
indexer: "Indexer | None" = None,
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
@@ -974,78 +967,39 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
output = output[:, : self.num_heads, :]
return output
def forward(
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor | None]:
# NOTE(lucas): the sparse FlashMLA kernels want to use the
# MQA 576/512 approach for both prefill and decode
assert output is not None, "Output tensor must be provided."
# Concatenate q if it's a tuple (ql_nope, q_pe)
if isinstance(q, tuple):
q = torch.cat(q, dim=-1)
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for MLACommonImpl"
)
num_actual_toks = q.shape[0]
if attn_metadata is None:
# Dummy run - no need to allocate buffers
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
# Get topk indices
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
if not use_fp8_cache:
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
attn_out = self._forward_bf16_kv(
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
elif attn_metadata.fp8_use_mixed_batch:
attn_out = self._forward_fp8_kv_mixed_batch(
q, kv_cache, topk_indices, attn_metadata
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
else:
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_cache, topk_indices, attn_metadata
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
)
self._v_up_proj(attn_out, out=output[:num_actual_toks])
return output
return attn_out, None
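The forward-to-forward_mqa rewrite above strips the KV-cache write (ops.concat_and_cache_mla), the q_nope-to-latent up-projection, and the trailing _v_up_proj out of the backend; per the commit title, those steps now belong to the layer, and the impl receives a query that is already in latent space. For reference, the removed query projection amounts to the following self-contained sketch; shapes are illustrative, and 512 + 64 matches the MQA 576/512 layout mentioned in the note above.

# Illustrative, runnable sketch of the removed query up-projection.
import torch

B, N, P, L, R = 4, 16, 128, 512, 64  # tokens, heads, qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim
q_nope = torch.randn(B, N, P)
q_pe = torch.randn(B, N, R)
W_UK_T = torch.randn(N, P, L)        # per-head absorbed weight, as in the removed code

# (B, N, P) -> (N, B, P), bmm against (N, P, L), back to (B, N, L)
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)
q_latent = torch.cat([ql_nope, q_pe], dim=-1)  # (B, N, L + R): the 576-dim MQA query
print(q_latent.shape)                          # torch.Size([4, 16, 576])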

View File

@@ -241,7 +241,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
return output
def _forward_decode(
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,

View File

@@ -7,12 +7,10 @@ from typing import TYPE_CHECKING, ClassVar
import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims,
)
from vllm.triton_utils import tl, triton
@@ -23,6 +21,7 @@ from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
@@ -269,7 +268,7 @@ def reference_mla_sparse_prefill(
return (result, lse)
class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]):
def __init__(
self,
num_heads: int,
@@ -287,23 +286,15 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
indexer: "Indexer | None" = None,
**mla_args,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**mla_args,
)
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def _forward_bf16_kv(
self,
@@ -342,56 +333,23 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
return output[:, : self.num_heads, :]
def forward(
def forward_mqa(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: ROCMAiterMLASparseMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
layer: AttentionLayer,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# NOTE(lucas): the sparse FlashMLA kernels want to use the
# MQA 576/512 approach for both prefill and decode
assert output is not None, "Output tensor must be provided."
# Concatenate q if it's a tuple (ql_nope, q_pe)
if isinstance(q, tuple):
q = torch.cat(q, dim=-1)
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for ROCMAiterMLASparse"
)
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
if self.is_fp8bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
ql_nope = rocm_aiter_ops.triton_fp8_bmm(
q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
else:
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
num_actual_toks = q.shape[0]
# Get topk indices
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
@@ -403,22 +361,8 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)
q = torch.cat([ql_nope, q_pe], dim=-1)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
attn_out = self._forward_bf16_kv(
q, kv_cache, topk_indices_global, attn_metadata
q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
)
self._v_up_proj(attn_out, out=output[:num_actual_toks])
return output
return attn_out, None
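The same restructuring applies to the ROCm AITER sparse impl: forward_mqa now returns the latent-space attention output (plus an optional LSE) instead of up-projecting and writing into `output` itself. A small self-contained sketch of the caller-side value up-projection that replaces the removed self._v_up_proj(attn_out, out=output[:num_actual_toks]) call is below; all names and shapes here are illustrative.

# Illustrative, runnable sketch of the caller-side v up-projection.
import torch

B, N, L, V = 4, 16, 512, 128     # tokens, heads, kv_lora_rank, v_head_dim
attn_out = torch.randn(B, N, L)  # as returned by forward_mqa(...)[0]
W_UV = torch.randn(N, L, V)      # per-head latent -> value up-projection weight

# (B, N, L) -> (N, B, L), bmm against (N, L, V), back to (B, N, V)
out = torch.bmm(attn_out.transpose(0, 1), W_UV).transpose(0, 1)
print(out.shape)                 # torch.Size([4, 16, 128])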

View File

@@ -110,7 +110,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
**kwargs,
)
def _forward_decode(
def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
kv_c_and_k_pe_cache: torch.Tensor,