[SM100] Resubmit FMHA FP8 prefill for MLA (#31195)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2026-02-10 13:18:43 -08:00
committed by GitHub
parent 9615575afc
commit 578977bb5e
3 changed files with 145 additions and 23 deletions

View File

@@ -27,7 +27,7 @@ from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import MLAAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.CUTLASS_MLA,
@@ -512,7 +512,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
def run_attention_backend( def run_attention_backend(
backend: AttentionBackendEnum, backend: AttentionBackendEnum,
kv_cache_spec: FullAttentionSpec, kv_cache_spec: MLAAttentionSpec,
layer_names: list[str], layer_names: list[str],
vllm_config, vllm_config,
device: torch.device, device: torch.device,
@@ -989,7 +989,7 @@ def test_backend_correctness(
kv_cache = kv_cache_per_block_size[block_size] kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend # Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec( backend_kv_cache_spec = MLAAttentionSpec(
block_size=block_size, block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads( num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config vllm_config.parallel_config
@@ -997,6 +997,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(), head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype, dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(), sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
) )
backend_output = run_attention_backend( backend_output = run_attention_backend(

View File

@@ -43,6 +43,9 @@ class AttentionConfig:
disable_flashinfer_q_quantization: bool = False disable_flashinfer_q_quantization: bool = False
"""If set, when using fp8 kv, do not quantize Q to fp8.""" """If set, when using fp8 kv, do not quantize Q to fp8."""
use_prefill_query_quantization: bool = False
"""If set, quantize query for attention in prefill."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
Provide a hash that uniquely identifies all the configs Provide a hash that uniquely identifies all the configs

View File

@@ -1052,6 +1052,7 @@ class MLACommonPrefillMetadata:
query_seq_lens: torch.Tensor | None = None query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None q_data_type: torch.dtype | None = None
output_dtype: torch.dtype | None = None
@dataclass @dataclass
@@ -1145,6 +1146,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128 return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
@functools.cache
def use_flashinfer_prefill() -> bool: def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if it's available since # For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2. # it is faster than FA2.
@@ -1162,6 +1164,7 @@ def use_flashinfer_prefill() -> bool:
return is_deepseek_r1_mla_compatible(vllm_config) return is_deepseek_r1_mla_compatible(vllm_config)
@functools.cache
def use_cudnn_prefill() -> bool: def use_cudnn_prefill() -> bool:
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
@@ -1174,6 +1177,7 @@ def use_cudnn_prefill() -> bool:
) )
@functools.cache
def use_trtllm_ragged_deepseek_prefill() -> bool: def use_trtllm_ragged_deepseek_prefill() -> bool:
"""Check if TRT-LLM ragged DeepSeek prefill should be used.""" """Check if TRT-LLM ragged DeepSeek prefill should be used."""
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
@@ -1210,6 +1214,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
) )
@functools.cache
def backend_supports_prefill_query_quantization() -> bool:
"""Check if the selected MLA backend supports prefill query quantization.
Currently supported backends:
- FlashInfer prefill
- TRT-LLM ragged DeepSeek prefill
Not supported:
- cuDNN Prefill
- FlashAttention
- Non-GB200 devices (FP8 prefill requires device capability 100)
"""
# FP8 prefill query quantization requires GB200 (device capability 100)
# for the necessary FP8 kernels at the moment.
if not current_platform.is_device_capability_family(100):
return False
return use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill()
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
@@ -1262,6 +1287,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return chunked_prefill_workspace_size return chunked_prefill_workspace_size
@staticmethod
def determine_prefill_query_data_type(
vllm_config: VllmConfig,
model_dtype: torch.dtype,
) -> torch.dtype:
"""
Determine the query data type for prefill queries.
Return FP8 dtype if cache is FP8 and prefill query quantization
is enabled, else model dtype.
"""
use_fp8 = (
vllm_config.cache_config.cache_dtype.startswith("fp8")
and vllm_config.attention_config.use_prefill_query_quantization
and backend_supports_prefill_query_quantization()
)
if use_fp8:
fp8_dtype = current_platform.fp8_dtype()
logger.info_once(
"FP8 prefill attention enabled: query data type is FP8", scope="local"
)
return fp8_dtype
elif vllm_config.attention_config.use_prefill_query_quantization:
logger.info_once(
"Unable to perform FP8 prefill attention when"
" use_prefill_query_quantization is enabled. Please"
" ensure that --kv-cache-dtype is set to fp8 and your prefill"
" backend is compatible with FP8 attention.",
scope="local",
)
return model_dtype
return model_dtype
def __init__( def __init__(
self, self,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
@@ -1285,6 +1344,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config) self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda() self.aot_schedule = current_platform.is_cuda()
self.kv_cache_spec = kv_cache_spec
self.q_data_type = self.determine_prefill_query_data_type(
vllm_config, self.model_config.dtype
)
try: try:
self.dcp_world_size = get_dcp_group().world_size self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group self.dcp_rank = get_dcp_group().rank_in_group
@@ -1325,7 +1390,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.chunked_prefill_workspace_size, self.chunked_prefill_workspace_size,
self.model_config.get_head_size(), self.model_config.get_head_size(),
), ),
dtype=self.model_config.dtype, dtype=self.q_data_type,
device=device, device=device,
) )
@@ -1435,7 +1500,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype, q_data_type=self.q_data_type,
o_data_type=prefill.output_dtype,
) )
# Prepare context prefills # Prepare context prefills
@@ -1454,7 +1520,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale, sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left, window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap, logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype, q_data_type=self.q_data_type,
o_data_type=prefill.output_dtype,
) )
prefill.prefill_main = self._fi_prefill_main prefill.prefill_main = self._fi_prefill_main
@@ -1709,6 +1776,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc, query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len, max_query_len=max_query_len,
chunked_context=chunked_context_metadata, chunked_context=chunked_context_metadata,
output_dtype=self.model_config.dtype,
q_data_type=self.q_data_type,
) )
if self._use_cudnn_prefill: if self._use_cudnn_prefill:
@@ -1894,7 +1963,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.indexer = indexer self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads self.q_pad_num_heads = q_pad_num_heads
self.supports_quant_query_input = True self.supports_quant_query_input = True
# Use flashinfer's optimized concat_mla_k kernel when available. # Use flashinfer's optimized concat_mla_k kernel when available.
@@ -2129,6 +2197,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert prefill.query_seq_lens is not None assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None assert prefill.workspace_buffer is not None
# allocate BF16 / FP16 output tensor for TRT-LLM ragged attention
out = torch.empty(
q.shape[0],
q.shape[1],
v.shape[2],
device=q.device,
dtype=prefill.output_dtype,
)
ret = trtllm_ragged_attention_deepseek( ret = trtllm_ragged_attention_deepseek(
query=q, query=q,
@@ -2148,6 +2224,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
enable_pdl=False, enable_pdl=False,
is_causal=True, is_causal=True,
return_lse=return_softmax_lse, return_lse=return_softmax_lse,
out=out,
) )
if isinstance(ret, tuple): if isinstance(ret, tuple):
@@ -2170,7 +2247,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q.shape[1], q.shape[1],
v.shape[2], v.shape[2],
device=q.device, device=q.device,
dtype=q.dtype, dtype=prefill.output_dtype,
) )
prefill.workspace_buffer.fill_(0) prefill.workspace_buffer.fill_(0)
@@ -2240,11 +2317,18 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None assert prefill_metadata.chunked_context is not None
use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
output = None output = None
iters = len(prefill_metadata.chunked_context.seq_tot) iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace workspace = prefill_metadata.chunked_context.workspace
if use_fp8_prefill:
q = q.to(prefill_metadata.q_data_type)
for i in range(iters): for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i] toks = prefill_metadata.chunked_context.seq_tot[i]
if not use_fp8_prefill:
ops.gather_and_maybe_dequant_cache( ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache, src_cache=kv_c_and_k_pe_cache,
dst=workspace, dst=workspace,
@@ -2256,13 +2340,36 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale=k_scale, scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i], seq_starts=prefill_metadata.chunked_context.starts[i],
) )
else:
# FP8 path: gather cache without dequantization
ops.cp_gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
# Extract kv_c_normed from workspace
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) # When FP8 weights are used without FP8 prefill, kv_b_proj expects
# model dtype input and will quantize internally.
if (
use_fp8_prefill
or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
):
kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
) )
# To Do: Use epilogue of kv_b_proj to generate fp8 kv_nope.
if use_fp8_prefill:
kv_nope = kv_nope.to(prefill_metadata.q_data_type)
k_pe = k_pe.to(prefill_metadata.q_data_type)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe) k = self._concat_k_nope_k_pe(k_nope, k_pe)
@@ -2412,16 +2519,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
assert self.dcp_world_size != -1 assert self.dcp_world_size != -1
has_context = attn_metadata.prefill.chunked_context is not None prefill_metadata = attn_metadata.prefill
use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
# Convert q to FP8 if FP8 prefill attention is enabled
if use_fp8_prefill:
q = q.to(prefill_metadata.q_data_type)
has_context = prefill_metadata.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
) )
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = self._concat_k_nope_k_pe(k_nope, k_pe) k = self._concat_k_nope_k_pe(k_nope, k_pe)
if use_fp8_prefill:
k = k.to(prefill_metadata.q_data_type)
v = v.to(prefill_metadata.q_data_type)
output_prefill = self._run_prefill_new_tokens( output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill, prefill=prefill_metadata,
q=q, q=q,
k=k, k=k,
v=v, v=v,