[Kernel] Add FP8 support with FlashMLA backend (#22668)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
This commit is contained in:
@@ -631,8 +631,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
if self.aot_schedule:
|
||||
# align max_context_chunk to page_size by rounding down,
|
||||
# currently the `gather_cache` kernel cannot handle
|
||||
# `context_chunk_starts` that are not aligned to page_size
|
||||
# currently the `gather_and_maybe_dequant_cache` kernel
|
||||
# cannot handle `context_chunk_starts` that are not aligned
|
||||
# to page_size
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.page_size)
|
||||
|
||||
@@ -1005,6 +1006,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
):
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
@@ -1017,12 +1019,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
|
||||
ops.gather_cache(
|
||||
ops.gather_and_maybe_dequant_cache(
|
||||
src_cache=kv_c_and_k_pe_cache,
|
||||
dst=workspace,
|
||||
block_table=prefill_metadata.block_table,
|
||||
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=k_scale,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
)
|
||||
|
||||
@@ -1073,6 +1077,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.prefill is not None
|
||||
|
||||
@@ -1095,7 +1100,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
if has_context:
|
||||
suffix_output, suffix_lse = output
|
||||
context_output, context_lse = self._compute_prefill_context( \
|
||||
q, kv_c_and_k_pe_cache, attn_metadata)
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
||||
|
||||
output = torch.empty_like(suffix_output)
|
||||
merge_attn_states(
|
||||
@@ -1119,6 +1124,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1146,6 +1152,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
@@ -1180,10 +1188,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if fp8_attention:
|
||||
kv_cache = kv_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
if has_prefill:
|
||||
output[num_decode_tokens:] = self._forward_prefill(
|
||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||
attn_metadata)
|
||||
attn_metadata, layer._k_scale)
|
||||
|
||||
if has_decode:
|
||||
assert attn_metadata.decode is not None
|
||||
@@ -1196,7 +1207,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||
|
||||
if fp8_attention:
|
||||
ql_nope_shape = decode_ql_nope.shape
|
||||
decode_ql_nope, _ = ops.scaled_fp8_quant(
|
||||
decode_ql_nope.reshape([
|
||||
ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]
|
||||
]), layer._q_scale)
|
||||
decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
|
||||
q_pe_shape = decode_q_pe.shape
|
||||
decode_q_pe, _ = ops.scaled_fp8_quant(
|
||||
decode_q_pe.reshape(
|
||||
[q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
|
||||
layer._q_scale)
|
||||
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
|
||||
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
|
||||
|
||||
return output_padded
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import ClassVar, Optional
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
@@ -278,6 +278,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
if self._use_old_cutlass_mla:
|
||||
# TODO: Remove the old cutlass MLA kernel after more extensive
|
||||
|
||||
@@ -6,8 +6,7 @@ from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
@@ -166,16 +165,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
@@ -194,6 +190,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
num_splits=attn_metadata.decode.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
descale_q=layer._q_scale.reshape(1),
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import ClassVar, Optional
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv
|
||||
@@ -221,6 +222,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
@@ -127,6 +127,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
Reference in New Issue
Block a user