[SM100] Resubmit FMHA FP8 prefill for MLA (#31195)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@@ -1052,6 +1052,7 @@ class MLACommonPrefillMetadata:
query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
output_dtype: torch.dtype | None = None


@dataclass
@@ -1145,6 +1146,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128


@functools.cache
def use_flashinfer_prefill() -> bool:
# For Blackwell, default to FlashInfer prefill if it's available, since
# it is faster than FA2.
@@ -1162,6 +1164,7 @@ def use_flashinfer_prefill() -> bool:
return is_deepseek_r1_mla_compatible(vllm_config)


@functools.cache
def use_cudnn_prefill() -> bool:
from vllm.config import get_current_vllm_config

@@ -1174,6 +1177,7 @@ def use_cudnn_prefill() -> bool:
)


@functools.cache
def use_trtllm_ragged_deepseek_prefill() -> bool:
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
from vllm.config import get_current_vllm_config
@@ -1210,6 +1214,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
)


@functools.cache
def backend_supports_prefill_query_quantization() -> bool:
"""Check if the selected MLA backend supports prefill query quantization.

Currently supported backends:
- FlashInfer prefill
- TRT-LLM ragged DeepSeek prefill

Not supported:
- cuDNN prefill
- FlashAttention
- Non-GB200 devices (FP8 prefill requires device capability 100)
"""
# FP8 prefill query quantization requires GB200 (device capability 100)
# for the necessary FP8 kernels at the moment.
if not current_platform.is_device_capability_family(100):
return False

return use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill()
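Note on the capability gate above: device capability family 100 corresponds to CUDA compute capability 10.x (SM100, i.e. Blackwell-class parts such as GB200). A rough standalone equivalent of the check, shown purely as an illustration and not as vLLM's actual helper, would be:

import torch

def is_capability_family_100() -> bool:
    # Illustrative stand-in for current_platform.is_device_capability_family(100):
    # compute capability 10.x corresponds to SM100 (Blackwell-class GPUs).
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10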


class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
@@ -1262,6 +1287,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):

return chunked_prefill_workspace_size

@staticmethod
def determine_prefill_query_data_type(
vllm_config: VllmConfig,
model_dtype: torch.dtype,
) -> torch.dtype:
"""
Determine the data type for prefill queries.
Return the FP8 dtype if the KV cache is FP8 and prefill query
quantization is enabled; otherwise return the model dtype.
"""
use_fp8 = (
vllm_config.cache_config.cache_dtype.startswith("fp8")
and vllm_config.attention_config.use_prefill_query_quantization
and backend_supports_prefill_query_quantization()
)

if use_fp8:
fp8_dtype = current_platform.fp8_dtype()
logger.info_once(
"FP8 prefill attention enabled: query data type is FP8", scope="local"
)
return fp8_dtype
elif vllm_config.attention_config.use_prefill_query_quantization:
logger.info_once(
"use_prefill_query_quantization is enabled, but FP8 prefill"
" attention cannot be used. Please ensure that --kv-cache-dtype"
" is set to fp8 and that your prefill backend is compatible"
" with FP8 attention.",
scope="local",
)
return model_dtype

return model_dtype
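To make the three outcomes of this helper concrete, here is a toy analogue (illustration only; the argument names are hypothetical and stand in for vLLM's config objects):

import torch

def pick_prefill_query_dtype(
    cache_dtype: str,
    use_prefill_query_quantization: bool,
    backend_supports_fp8: bool,
    model_dtype: torch.dtype,
) -> torch.dtype:
    # FP8 queries only when the KV cache is FP8, the user opted in,
    # and the selected prefill backend can consume FP8 queries.
    if (
        cache_dtype.startswith("fp8")
        and use_prefill_query_quantization
        and backend_supports_fp8
    ):
        return torch.float8_e4m3fn  # assuming an e4m3-style FP8 platform dtype
    return model_dtype

# FP8 cache + capable backend -> FP8 queries.
assert pick_prefill_query_dtype("fp8_e4m3", True, True, torch.bfloat16) == torch.float8_e4m3fn
# Opt-in without a capable backend -> fall back to the model dtype.
assert pick_prefill_query_dtype("fp8_e4m3", True, False, torch.bfloat16) == torch.bfloat16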

def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -1285,6 +1344,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()

self.kv_cache_spec = kv_cache_spec
self.q_data_type = self.determine_prefill_query_data_type(
vllm_config, self.model_config.dtype
)

try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -1325,7 +1390,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.chunked_prefill_workspace_size,
self.model_config.get_head_size(),
),
dtype=self.model_config.dtype,
dtype=self.q_data_type,
device=device,
)
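The dtype change above means the chunked-prefill workspace now follows the prefill query dtype instead of always using the model dtype, so the FP8 path gathers context KV into an FP8 buffer. A minimal sketch of the effect, with hypothetical sizes:

import torch

# Hypothetical sizes; the real values come from the scheduler and model config.
workspace_size, head_size = 32768, 576
q_data_type = torch.float8_e4m3fn  # or the model dtype when FP8 prefill is off

device = "cuda" if torch.cuda.is_available() else "cpu"
workspace = torch.empty((workspace_size, head_size), dtype=q_data_type, device=device)
# FP8 halves the workspace footprint: 1 byte/element vs. 2 for BF16/FP16.
print(workspace.element_size())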

@@ -1435,7 +1500,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
o_data_type=prefill.output_dtype,
)

# Prepare context prefills
@@ -1454,7 +1520,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
q_data_type=self.model_config.dtype,
q_data_type=self.q_data_type,
o_data_type=prefill.output_dtype,
)

prefill.prefill_main = self._fi_prefill_main
@@ -1709,6 +1776,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
output_dtype=self.model_config.dtype,
q_data_type=self.q_data_type,
)

if self._use_cudnn_prefill:
@@ -1894,7 +1963,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads

self.supports_quant_query_input = True

# Use flashinfer's optimized concat_mla_k kernel when available.
@@ -2129,6 +2197,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):

assert prefill.query_seq_lens is not None
assert prefill.workspace_buffer is not None
# allocate BF16 / FP16 output tensor for TRT-LLM ragged attention
out = torch.empty(
q.shape[0],
q.shape[1],
v.shape[2],
device=q.device,
dtype=prefill.output_dtype,
)

ret = trtllm_ragged_attention_deepseek(
query=q,
@@ -2148,6 +2224,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
enable_pdl=False,
is_causal=True,
return_lse=return_softmax_lse,
out=out,
)

if isinstance(ret, tuple):
@@ -2170,7 +2247,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q.shape[1],
v.shape[2],
device=q.device,
dtype=q.dtype,
dtype=prefill.output_dtype,
)
prefill.workspace_buffer.fill_(0)
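Both output-dtype changes above address the same point: once the query may be FP8, the attention output buffer can no longer inherit q.dtype and is instead allocated from the BF16/FP16 output dtype carried in the prefill metadata. A small sketch with hypothetical shapes:

import torch

# Hypothetical shapes: [num_tokens, num_heads, head_dim].
num_tokens, num_heads, qk_head_dim, v_head_dim = 8, 16, 192, 128
output_dtype = torch.bfloat16  # what prefill.output_dtype would carry

# Queries may be quantized to FP8 for the attention kernel...
q = torch.randn(num_tokens, num_heads, qk_head_dim).to(torch.float8_e4m3fn)

# ...but the attention output stays in higher precision, so the output
# buffer is sized from output_dtype rather than q.dtype.
out = torch.empty(num_tokens, num_heads, v_head_dim, dtype=output_dtype)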

@@ -2240,29 +2317,59 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None

use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()

output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace

if use_fp8_prefill:
q = q.to(prefill_metadata.q_data_type)

for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
if not use_fp8_prefill:
ops.gather_and_maybe_dequant_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
kv_cache_dtype=self.kv_cache_dtype,
scale=k_scale,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
else:
# FP8 path: gather cache without dequantization
ops.cp_gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)

# Extract kv_c_normed from workspace
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
# When FP8 weights are used without FP8 prefill, kv_b_proj expects
# model dtype input and will quantize internally.
if (
use_fp8_prefill
or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
):
kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)

k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)

# TODO: Use the epilogue of kv_b_proj to generate FP8 kv_nope.
if use_fp8_prefill:
kv_nope = kv_nope.to(prefill_metadata.q_data_type)
k_pe = k_pe.to(prefill_metadata.q_data_type)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k = self._concat_k_nope_k_pe(k_nope, k_pe)
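To summarize the FP8 chunked-context path above: the FP8 KV cache is gathered into the workspace without dequantization, kv_c_normed is cast up to the dtype kv_b_proj expects, and the up-projected K/V (plus q, converted earlier) are cast down to the FP8 prefill dtype for the attention kernel. A minimal, self-contained sketch of that dtype flow, with hypothetical shapes and no dependence on vLLM:

import torch

fp8 = torch.float8_e4m3fn
toks, num_heads = 64, 16
kv_lora_rank, rope_dim, nope_dim, v_dim = 512, 64, 128, 128

# The workspace holds the gathered, still-quantized FP8 KV cache entries.
workspace = torch.zeros(toks, kv_lora_rank + rope_dim, dtype=fp8)

# kv_b_proj runs in the model/weight dtype, so its input is cast up first...
kv_c_normed = workspace[..., :kv_lora_rank].to(torch.bfloat16)
kv_b_proj = torch.nn.Linear(
    kv_lora_rank, num_heads * (nope_dim + v_dim), bias=False, dtype=torch.bfloat16
)
kv_nope = kv_b_proj(kv_c_normed).view(toks, num_heads, nope_dim + v_dim)

# ...and its outputs, k_pe, and q are cast back down to FP8 for the kernel.
kv_nope = kv_nope.to(fp8)
k_pe = workspace[..., kv_lora_rank:].unsqueeze(1).to(fp8)
q = torch.randn(toks, num_heads, nope_dim + rope_dim).to(fp8)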
@@ -2412,16 +2519,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert attn_metadata.prefill is not None
assert self.dcp_world_size != -1

has_context = attn_metadata.prefill.chunked_context is not None
prefill_metadata = attn_metadata.prefill
use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()

# Convert q to FP8 if FP8 prefill attention is enabled
if use_fp8_prefill:
q = q.to(prefill_metadata.q_data_type)

has_context = prefill_metadata.chunked_context is not None

kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k = self._concat_k_nope_k_pe(k_nope, k_pe)

if use_fp8_prefill:
k = k.to(prefill_metadata.q_data_type)
v = v.to(prefill_metadata.q_data_type)

output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
prefill=prefill_metadata,
q=q,
k=k,
v=v,
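Finally, the non-chunked path mirrors the same pattern: q is quantized up front, and k/v are cast to the FP8 prefill dtype only when FP8 prefill is active. A tiny illustrative guard (hypothetical helper, not part of this PR):

import torch

def maybe_cast_qkv(q, k, v, q_data_type):
    # Cast q/k/v to the FP8 prefill dtype only when FP8 prefill is active;
    # otherwise leave them in the model dtype.
    if q_data_type in (torch.float8_e4m3fn, torch.float8_e5m2):
        return q.to(q_data_type), k.to(q_data_type), v.to(q_data_type)
    return q, k, v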