[Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode (#24705)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Matthew Bonanni authored 2025-09-12 17:45:53 -04:00, committed by GitHub
parent c89ed8de43
commit 7ba32aa60b
8 changed files with 23 additions and 10 deletions

@@ -584,7 +584,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 window_left=self._global_hyperparameters.window_left,
                 logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
                 q_data_type=self.model_config.dtype,
-                kv_data_type=self.kv_cache_spec.dtype,
             )
 
         # Prepare context prefills
@@ -605,7 +604,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                     logits_soft_cap=self._global_hyperparameters.
                     logits_soft_cap,
                     q_data_type=self.model_config.dtype,
-                    kv_data_type=self.kv_cache_spec.dtype,
                 )
 
             prefill.prefill_main = self._fi_prefill_main
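A note on the two deletions above (my reading of the change, not text from the commit): with an FP8 KV cache, `kv_cache_spec.dtype` is the FP8 storage dtype while `model_config.dtype` is the bf16/fp16 compute dtype, so the two plan arguments diverge once the cache is quantized; dropping the explicit `kv_data_type` leaves the choice to the prefill wrapper's default. A minimal sketch of the dtype split, with placeholder values:

```python
import torch

# Placeholder values for illustration only; the names mirror the attributes
# referenced in the hunks above.
model_config_dtype = torch.bfloat16        # q_data_type: the compute dtype
kv_cache_spec_dtype = torch.float8_e4m3fn  # storage dtype when the KV cache is FP8

# Before this commit both dtypes were forwarded to the prefill plan call;
# after it, only the query/compute dtype is.
assert model_config_dtype != kv_cache_spec_dtype
```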

@@ -6,8 +6,7 @@ from typing import Optional, Union
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
-from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
-                                              is_quantized_kv_cache)
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
@@ -69,11 +68,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
                                       "are not implemented for "
                                       "FlashInferMLAImpl")
 
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashInferMLA V1 with FP8 KV cache not yet supported")
-
         self._workspace_buffer = g_fi_workspace
 
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None
 
     def _forward_decode(
         self,
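The two lazily-initialized scales added above are the core of the FP8 enablement: instead of rejecting quantized KV caches, the per-tensor dequantization scales are folded into the two scalar multipliers the TRT-LLM MLA decode kernel already accepts. A small self-contained sketch of that algebra (function name and signature are mine, not from the diff):

```python
def fused_fp8_attention_scales(q_scale: float, k_scale: float, v_scale: float,
                               softmax_scale: float) -> tuple[float, float]:
    """Illustrative only: fold per-tensor FP8 dequant scales into the kernel's
    two GEMM multipliers.

    softmax((q * q_scale) @ (k * k_scale).T * softmax_scale) @ (v * v_scale)
      == softmax(q @ k.T * (q_scale * k_scale * softmax_scale)) @ v * v_scale
    """
    bmm1_scale = q_scale * k_scale * softmax_scale  # applied to the QK^T logits
    bmm2_scale = v_scale                            # applied to the P @ V output
    return bmm1_scale, bmm2_scale


# With unquantized q/k/v (scales of 1.0) this reduces to the old behaviour of
# passing just the softmax scale as bmm1_scale.
print(fused_fp8_attention_scales(1.0, 1.0, 1.0, 0.1352))  # -> (0.1352, 1.0)
```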
@@ -92,6 +89,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # trtllm API requires extra dimension q_len_per_request for MTP
         q = q.unsqueeze(1)
 
+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -102,7 +105,8 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             block_tables=attn_metadata.decode.block_table,
             seq_lens=attn_metadata.decode.seq_lens,
             max_seq_len=attn_metadata.max_seq_len,
-            bmm1_scale=self.scale,
+            bmm1_scale=self.bmm1_scale,
+            bmm2_scale=self.bmm2_scale,
         )
 
         # TODO: Return LSE pending support from Flashinfer API:
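
Finally, a hypothetical end-to-end snippet for exercising this decode path; the backend environment variable value, model choice, and flags are assumptions on my part, not taken from this commit:

```python
# Hypothetical usage sketch: run an MLA model with an FP8 KV cache so decode
# dispatches to the FlashInfer (TRT-LLM) MLA kernel. The env-var value and the
# model are assumptions, not taken from this diff.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", kv_cache_dtype="fp8")
outputs = llm.generate(["FP8 MLA decode test: "], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```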