[Attention] Move MLA forward from backend to layer (#33284)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -67,7 +67,7 @@ class AttentionBackend(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_impl_cls() -> type["AttentionImpl"]:
|
||||
def get_impl_cls() -> type["AttentionImplBase"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@@ -594,7 +594,14 @@ class AttentionLayer(Protocol):
|
||||
) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class AttentionImpl(ABC, Generic[T]):
|
||||
class AttentionImplBase(ABC, Generic[T]):
|
||||
"""Base class for attention implementations.
|
||||
|
||||
Contains common attributes and initialization logic shared by both
|
||||
standard AttentionImpl and MLAAttentionImpl. Does not define a forward
|
||||
method - subclasses define their own forward interfaces.
|
||||
"""
|
||||
|
||||
# Required attributes that all impls should have
|
||||
num_heads: int
|
||||
head_size: int
|
||||
@@ -662,6 +669,13 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
)
|
||||
return self
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
pass
|
||||
|
||||
|
||||
class AttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""Standard attention implementation with forward method."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
@@ -704,11 +718,10 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
"""
|
||||
return False
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
pass
|
||||
|
||||
class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""MLA attention implementation with forward_mqa and forward_mha methods."""
|
||||
|
||||
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
@@ -731,22 +744,78 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
|
||||
v_head_dim: int,
|
||||
kv_b_proj: "ColumnParallelLinear",
|
||||
indexer: object | None = None,
|
||||
q_pad_num_heads: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
def forward_mha(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_cq: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
k_scale: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
"""MHA-style prefill forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""MQA-style decode forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""Sparse MLA attention implementation with only forward_mqa method.
|
||||
|
||||
Sparse MLA implementations only support decode (MQA-style) attention.
|
||||
They do not support prefill (MHA-style) attention.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: "ColumnParallelLinear",
|
||||
indexer: object | None = None,
|
||||
q_pad_num_heads: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""MQA-style decode forward pass."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -244,7 +244,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
return out, lse
|
||||
|
||||
def _forward_decode(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
@@ -293,7 +293,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
|
||||
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
@@ -150,7 +150,7 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
self.bmm1_scale: float | None = None
|
||||
self.bmm2_scale: float | None = None
|
||||
|
||||
def _forward_decode(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
@@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
"FlashMLAImpl"
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
@@ -11,7 +11,6 @@ from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBaseImpl,
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -25,6 +24,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
reshape_attn_output_for_spec_decode,
|
||||
@@ -686,7 +686,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
|
||||
@staticmethod
|
||||
def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
|
||||
# FP8 decode kernel only supports h_q = 64 or 128
|
||||
@@ -710,19 +710,12 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
@@ -974,78 +967,39 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
output = output[:, : self.num_heads, :]
|
||||
return output
|
||||
|
||||
def forward(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata | None,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
# Concatenate q if it's a tuple (ql_nope, q_pe)
|
||||
if isinstance(q, tuple):
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for MLACommonImpl"
|
||||
)
|
||||
num_actual_toks = q.shape[0]
|
||||
|
||||
if attn_metadata is None:
|
||||
# Dummy run - no need to allocate buffers
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
# Get topk indices
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
ql_nope = ql_nope.transpose(0, 1)
|
||||
|
||||
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"
|
||||
|
||||
q = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if not use_fp8_cache:
|
||||
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
|
||||
attn_out = self._forward_bf16_kv(
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
elif attn_metadata.fp8_use_mixed_batch:
|
||||
attn_out = self._forward_fp8_kv_mixed_batch(
|
||||
q, kv_cache, topk_indices, attn_metadata
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
else:
|
||||
attn_out = self._forward_fp8_kv_separate_prefill_decode(
|
||||
q, kv_cache, topk_indices, attn_metadata
|
||||
q, kv_c_and_k_pe_cache, topk_indices, attn_metadata
|
||||
)
|
||||
|
||||
self._v_up_proj(attn_out, out=output[:num_actual_toks])
|
||||
return output
|
||||
return attn_out, None
|
||||
|
||||
@@ -241,7 +241,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
return output
|
||||
|
||||
def _forward_decode(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
@@ -7,12 +7,10 @@ from typing import TYPE_CHECKING, ClassVar
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonBaseImpl,
|
||||
get_mla_dims,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -23,6 +21,7 @@ from vllm.v1.attention.backend import (
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
SparseMLAAttentionImpl,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
triton_convert_req_index_to_global_index,
|
||||
@@ -269,7 +268,7 @@ def reference_mla_sparse_prefill(
|
||||
return (result, lse)
|
||||
|
||||
|
||||
class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
|
||||
class ROCMAiterMLASparseImpl(SparseMLAAttentionImpl[ROCMAiterMLASparseMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
@@ -287,23 +286,15 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
**mla_args,
|
||||
)
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.kv_lora_rank: int = mla_args["kv_lora_rank"]
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
|
||||
self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self,
|
||||
@@ -342,56 +333,23 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
|
||||
|
||||
return output[:, : self.num_heads, :]
|
||||
|
||||
def forward(
|
||||
def forward_mqa(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: ROCMAiterMLASparseMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
# Concatenate q if it's a tuple (ql_nope, q_pe)
|
||||
if isinstance(q, tuple):
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for ROCMAiterMLASparse"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
if self.is_fp8bmm_enabled:
|
||||
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
||||
ql_nope = rocm_aiter_ops.triton_fp8_bmm(
|
||||
q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
|
||||
)
|
||||
else:
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
ql_nope = ql_nope.transpose(0, 1)
|
||||
num_actual_toks = q.shape[0]
|
||||
|
||||
# Get topk indices
|
||||
assert self.topk_indices_buffer is not None
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
@@ -403,22 +361,8 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
|
||||
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
|
||||
)
|
||||
|
||||
q = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
attn_out = self._forward_bf16_kv(
|
||||
q, kv_cache, topk_indices_global, attn_metadata
|
||||
q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
|
||||
)
|
||||
|
||||
self._v_up_proj(attn_out, out=output[:num_actual_toks])
|
||||
return output
|
||||
return attn_out, None
|
||||
|
||||
@@ -110,7 +110,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _forward_decode(
|
||||
def forward_mqa(
|
||||
self,
|
||||
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user