[Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode (#24705)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@@ -584,7 +584,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            window_left=self._global_hyperparameters.window_left,
            logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
            q_data_type=self.model_config.dtype,
            kv_data_type=self.kv_cache_spec.dtype,
        )

        # Prepare context prefills
@@ -605,7 +604,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            logits_soft_cap=self._global_hyperparameters.
            logits_soft_cap,
            q_data_type=self.model_config.dtype,
            kv_data_type=self.kv_cache_spec.dtype,
        )

        prefill.prefill_main = self._fi_prefill_main
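In both plan() calls above, the query dtype comes from model_config.dtype while the cache dtype comes from kv_cache_spec.dtype; with an FP8 KV cache the two differ. A minimal sketch of that divergence (the resolve_kv_cache_dtype helper below is hypothetical, for illustration only, not vLLM's actual resolution logic):

```python
import torch

def resolve_kv_cache_dtype(kv_cache_dtype: str,
                           model_dtype: torch.dtype) -> torch.dtype:
    # Hypothetical helper: "auto" keeps the cache in the model dtype,
    # "fp8" variants store it as 8-bit floating point.
    if kv_cache_dtype == "auto":
        return model_dtype
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    raise ValueError(f"unsupported kv_cache_dtype: {kv_cache_dtype}")

q_data_type = torch.bfloat16                               # model_config.dtype
kv_data_type = resolve_kv_cache_dtype("fp8", q_data_type)  # kv_cache_spec.dtype
print(q_data_type, kv_data_type)  # torch.bfloat16 torch.float8_e4m3fn
```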
@@ -6,8 +6,7 @@ from typing import Optional, Union
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
-from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
-                                              is_quantized_kv_cache)
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
@@ -69,11 +68,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "are not implemented for "
                 "FlashInferMLAImpl")
 
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashInferMLA V1 with FP8 KV cache not yet supported")
-
         self._workspace_buffer = g_fi_workspace
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None
 
     def _forward_decode(
         self,
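The constructor drops the FP8 KV-cache NotImplementedError and instead records two lazily filled scales. As the decode-path hunks below show, bmm1_scale folds layer._q_scale_float, layer._k_scale_float, and the softmax scale into the Q @ K^T GEMM, while bmm2_scale carries layer._v_scale_float for the P @ V GEMM. A reference sketch of that folding with per-tensor scales (plain attention in PyTorch, purely illustrative, not the TRTLLM MLA kernel):

```python
import torch

def scaled_attention_reference(q_fp8: torch.Tensor, k_fp8: torch.Tensor,
                               v_fp8: torch.Tensor, q_scale: float,
                               k_scale: float, v_scale: float,
                               sm_scale: float) -> torch.Tensor:
    # Dequantizing q, k, v up front is equivalent to applying
    # bmm1_scale = q_scale * k_scale * sm_scale to the Q @ K^T GEMM and
    # bmm2_scale = v_scale to the softmax(P) @ V GEMM.
    bmm1_scale = q_scale * k_scale * sm_scale
    bmm2_scale = v_scale
    logits = (q_fp8.float() @ k_fp8.float().transpose(-2, -1)) * bmm1_scale
    probs = torch.softmax(logits, dim=-1)
    return (probs @ v_fp8.float()) * bmm2_scale
```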
@@ -92,6 +89,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # trtllm API requires extra dimension q_len_per_request for MTP
         q = q.unsqueeze(1)
 
+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -102,7 +105,8 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             block_tables=attn_metadata.decode.block_table,
             seq_lens=attn_metadata.decode.seq_lens,
             max_seq_len=attn_metadata.max_seq_len,
-            bmm1_scale=self.scale,
+            bmm1_scale=self.bmm1_scale,
+            bmm2_scale=self.bmm2_scale,
         )
 
         # TODO: Return LSE pending support from Flashinfer API:
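With the quantized-KV-cache check removed, the TRTLLM MLA decode path can be exercised against an FP8 KV cache. A hedged usage sketch (model name and backend override are illustrative assumptions; whether this configuration actually selects FlashInfer MLA depends on the hardware and installed FlashInfer build):

```python
import os

# Assumed env override for illustration; not required if vLLM picks the
# FlashInfer MLA backend automatically.
os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASHINFER_MLA")

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative MLA model
    kv_cache_dtype="fp8",                  # store the KV cache in FP8
)
outputs = llm.generate(["FlashInfer MLA decode with an FP8 KV cache:"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```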