[SM100] Resubmit FMHA FP8 prefill for MLA (#31195)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from vllm.v1.attention.backend import CommonAttentionMetadata
|
|||||||
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
|
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
|
||||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
from vllm.v1.kv_cache_interface import MLAAttentionSpec
|
||||||
|
|
||||||
BACKENDS_TO_TEST = [
|
BACKENDS_TO_TEST = [
|
||||||
AttentionBackendEnum.CUTLASS_MLA,
|
AttentionBackendEnum.CUTLASS_MLA,
|
||||||
@@ -512,7 +512,7 @@ class MockMLAAttentionLayer(AttentionLayerBase):
|
|||||||
|
|
||||||
def run_attention_backend(
|
def run_attention_backend(
|
||||||
backend: AttentionBackendEnum,
|
backend: AttentionBackendEnum,
|
||||||
kv_cache_spec: FullAttentionSpec,
|
kv_cache_spec: MLAAttentionSpec,
|
||||||
layer_names: list[str],
|
layer_names: list[str],
|
||||||
vllm_config,
|
vllm_config,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
@@ -989,7 +989,7 @@ def test_backend_correctness(
|
|||||||
kv_cache = kv_cache_per_block_size[block_size]
|
kv_cache = kv_cache_per_block_size[block_size]
|
||||||
|
|
||||||
# Create kv_cache_spec with the correct block_size for this backend
|
# Create kv_cache_spec with the correct block_size for this backend
|
||||||
backend_kv_cache_spec = FullAttentionSpec(
|
backend_kv_cache_spec = MLAAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
|
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
|
||||||
vllm_config.parallel_config
|
vllm_config.parallel_config
|
||||||
@@ -997,6 +997,7 @@ def test_backend_correctness(
|
|||||||
head_size=vllm_config.model_config.get_head_size(),
|
head_size=vllm_config.model_config.get_head_size(),
|
||||||
dtype=vllm_config.model_config.dtype,
|
dtype=vllm_config.model_config.dtype,
|
||||||
sliding_window=vllm_config.model_config.get_sliding_window(),
|
sliding_window=vllm_config.model_config.get_sliding_window(),
|
||||||
|
cache_dtype_str=vllm_config.cache_config.cache_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
backend_output = run_attention_backend(
|
backend_output = run_attention_backend(
|
||||||
|
|||||||
@@ -43,6 +43,9 @@ class AttentionConfig:
|
|||||||
disable_flashinfer_q_quantization: bool = False
|
disable_flashinfer_q_quantization: bool = False
|
||||||
"""If set, when using fp8 kv, do not quantize Q to fp8."""
|
"""If set, when using fp8 kv, do not quantize Q to fp8."""
|
||||||
|
|
||||||
|
use_prefill_query_quantization: bool = False
|
||||||
|
"""If set, quantize query for attention in prefill."""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
Provide a hash that uniquely identifies all the configs
|
Provide a hash that uniquely identifies all the configs
|
||||||
|
|||||||
@@ -1052,6 +1052,7 @@ class MLACommonPrefillMetadata:
|
|||||||
query_seq_lens: torch.Tensor | None = None
|
query_seq_lens: torch.Tensor | None = None
|
||||||
workspace_buffer: torch.Tensor | None = None
|
workspace_buffer: torch.Tensor | None = None
|
||||||
q_data_type: torch.dtype | None = None
|
q_data_type: torch.dtype | None = None
|
||||||
|
output_dtype: torch.dtype | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -1145,6 +1146,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
|
|||||||
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
|
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
def use_flashinfer_prefill() -> bool:
|
def use_flashinfer_prefill() -> bool:
|
||||||
# For blackwell default to flashinfer prefill if it's available since
|
# For blackwell default to flashinfer prefill if it's available since
|
||||||
# it is faster than FA2.
|
# it is faster than FA2.
|
||||||
@@ -1162,6 +1164,7 @@ def use_flashinfer_prefill() -> bool:
|
|||||||
return is_deepseek_r1_mla_compatible(vllm_config)
|
return is_deepseek_r1_mla_compatible(vllm_config)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
def use_cudnn_prefill() -> bool:
|
def use_cudnn_prefill() -> bool:
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
@@ -1174,6 +1177,7 @@ def use_cudnn_prefill() -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
def use_trtllm_ragged_deepseek_prefill() -> bool:
|
def use_trtllm_ragged_deepseek_prefill() -> bool:
|
||||||
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
|
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
@@ -1210,6 +1214,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def backend_supports_prefill_query_quantization() -> bool:
    """Report whether the active MLA prefill path accepts a quantized query.

    Prefill paths that support query quantization:
        - FlashInfer prefill
        - TRT-LLM ragged DeepSeek prefill

    Not supported:
        - cuDNN prefill
        - FlashAttention
        - Any device outside the capability-100 family

    The result is memoized by ``functools.cache``, so the platform and
    backend checks below run only on the first call.
    """
    # The device check comes first: outside the capability-100 family the
    # answer is always False and the backend selectors are never consulted.
    on_capability_100 = current_platform.is_device_capability_family(100)
    return (
        (use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill())
        if on_capability_100
        else False
    )
|
||||||
|
|
||||||
|
|
||||||
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||||
"""
|
"""
|
||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
@@ -1262,6 +1287,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
|
|
||||||
return chunked_prefill_workspace_size
|
return chunked_prefill_workspace_size
|
||||||
|
|
||||||
|
@staticmethod
def determine_prefill_query_data_type(
    vllm_config: VllmConfig,
    model_dtype: torch.dtype,
) -> torch.dtype:
    """Pick the dtype used for queries on the prefill path.

    The platform FP8 dtype is returned only when all of the following hold:
    the KV-cache dtype string starts with ``"fp8"``,
    ``attention_config.use_prefill_query_quantization`` is set, and
    ``backend_supports_prefill_query_quantization()`` is true. In every other
    case ``model_dtype`` is returned unchanged.

    Args:
        vllm_config: Configuration providing ``cache_config.cache_dtype`` and
            ``attention_config.use_prefill_query_quantization``.
        model_dtype: Dtype to fall back to when FP8 prefill is not used.

    Returns:
        ``current_platform.fp8_dtype()`` or ``model_dtype``.
    """
    cache_is_fp8 = vllm_config.cache_config.cache_dtype.startswith("fp8")
    quantization_requested = (
        vllm_config.attention_config.use_prefill_query_quantization
    )

    # The backend probe stays last in the chain so it is only evaluated when
    # the two cheap config checks have already passed.
    if (
        cache_is_fp8
        and quantization_requested
        and backend_supports_prefill_query_quantization()
    ):
        quantized_dtype = current_platform.fp8_dtype()
        logger.info_once(
            "FP8 prefill attention enabled: query data type is FP8",
            scope="local",
        )
        return quantized_dtype

    if quantization_requested:
        # The user asked for quantization but a precondition failed; say so
        # once, then fall through to the model dtype.
        logger.info_once(
            "Unable to perform FP8 prefill attention when"
            " use_prefill_query_quantization is enabled. Please"
            " ensure that --kv-cache-dtype is set to fp8 and your prefill"
            " backend is compatible with FP8 attention.",
            scope="local",
        )

    return model_dtype
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
kv_cache_spec: AttentionSpec,
|
kv_cache_spec: AttentionSpec,
|
||||||
@@ -1285,6 +1344,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||||
self.mla_dims = get_mla_dims(self.model_config)
|
self.mla_dims = get_mla_dims(self.model_config)
|
||||||
self.aot_schedule = current_platform.is_cuda()
|
self.aot_schedule = current_platform.is_cuda()
|
||||||
|
|
||||||
|
self.kv_cache_spec = kv_cache_spec
|
||||||
|
self.q_data_type = self.determine_prefill_query_data_type(
|
||||||
|
vllm_config, self.model_config.dtype
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.dcp_world_size = get_dcp_group().world_size
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
self.dcp_rank = get_dcp_group().rank_in_group
|
self.dcp_rank = get_dcp_group().rank_in_group
|
||||||
@@ -1325,7 +1390,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
self.chunked_prefill_workspace_size,
|
self.chunked_prefill_workspace_size,
|
||||||
self.model_config.get_head_size(),
|
self.model_config.get_head_size(),
|
||||||
),
|
),
|
||||||
dtype=self.model_config.dtype,
|
dtype=self.q_data_type,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1435,7 +1500,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
sm_scale=self._global_hyperparameters.sm_scale,
|
sm_scale=self._global_hyperparameters.sm_scale,
|
||||||
window_left=self._global_hyperparameters.window_left,
|
window_left=self._global_hyperparameters.window_left,
|
||||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||||
q_data_type=self.model_config.dtype,
|
q_data_type=self.q_data_type,
|
||||||
|
o_data_type=prefill.output_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare context prefills
|
# Prepare context prefills
|
||||||
@@ -1454,7 +1520,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
sm_scale=self._global_hyperparameters.sm_scale,
|
sm_scale=self._global_hyperparameters.sm_scale,
|
||||||
window_left=self._global_hyperparameters.window_left,
|
window_left=self._global_hyperparameters.window_left,
|
||||||
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
|
||||||
q_data_type=self.model_config.dtype,
|
q_data_type=self.q_data_type,
|
||||||
|
o_data_type=prefill.output_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefill.prefill_main = self._fi_prefill_main
|
prefill.prefill_main = self._fi_prefill_main
|
||||||
@@ -1709,6 +1776,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
query_start_loc=prefill_query_start_loc,
|
query_start_loc=prefill_query_start_loc,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
chunked_context=chunked_context_metadata,
|
chunked_context=chunked_context_metadata,
|
||||||
|
output_dtype=self.model_config.dtype,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._use_cudnn_prefill:
|
if self._use_cudnn_prefill:
|
||||||
@@ -1894,7 +1963,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
self.kv_b_proj = kv_b_proj
|
self.kv_b_proj = kv_b_proj
|
||||||
self.indexer = indexer
|
self.indexer = indexer
|
||||||
self.q_pad_num_heads = q_pad_num_heads
|
self.q_pad_num_heads = q_pad_num_heads
|
||||||
|
|
||||||
self.supports_quant_query_input = True
|
self.supports_quant_query_input = True
|
||||||
|
|
||||||
# Use flashinfer's optimized concat_mla_k kernel when available.
|
# Use flashinfer's optimized concat_mla_k kernel when available.
|
||||||
@@ -2129,6 +2197,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
|
|
||||||
assert prefill.query_seq_lens is not None
|
assert prefill.query_seq_lens is not None
|
||||||
assert prefill.workspace_buffer is not None
|
assert prefill.workspace_buffer is not None
|
||||||
|
# allocate BF16 / FP16 output tensor for TRT-LLM ragged attention
|
||||||
|
out = torch.empty(
|
||||||
|
q.shape[0],
|
||||||
|
q.shape[1],
|
||||||
|
v.shape[2],
|
||||||
|
device=q.device,
|
||||||
|
dtype=prefill.output_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
ret = trtllm_ragged_attention_deepseek(
|
ret = trtllm_ragged_attention_deepseek(
|
||||||
query=q,
|
query=q,
|
||||||
@@ -2148,6 +2224,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
enable_pdl=False,
|
enable_pdl=False,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
return_lse=return_softmax_lse,
|
return_lse=return_softmax_lse,
|
||||||
|
out=out,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(ret, tuple):
|
if isinstance(ret, tuple):
|
||||||
@@ -2170,7 +2247,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
q.shape[1],
|
q.shape[1],
|
||||||
v.shape[2],
|
v.shape[2],
|
||||||
device=q.device,
|
device=q.device,
|
||||||
dtype=q.dtype,
|
dtype=prefill.output_dtype,
|
||||||
)
|
)
|
||||||
prefill.workspace_buffer.fill_(0)
|
prefill.workspace_buffer.fill_(0)
|
||||||
|
|
||||||
@@ -2240,11 +2317,18 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = attn_metadata.prefill
|
||||||
assert prefill_metadata.chunked_context is not None
|
assert prefill_metadata.chunked_context is not None
|
||||||
|
|
||||||
|
use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
|
||||||
|
|
||||||
output = None
|
output = None
|
||||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||||
workspace = prefill_metadata.chunked_context.workspace
|
workspace = prefill_metadata.chunked_context.workspace
|
||||||
|
|
||||||
|
if use_fp8_prefill:
|
||||||
|
q = q.to(prefill_metadata.q_data_type)
|
||||||
|
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||||
|
if not use_fp8_prefill:
|
||||||
ops.gather_and_maybe_dequant_cache(
|
ops.gather_and_maybe_dequant_cache(
|
||||||
src_cache=kv_c_and_k_pe_cache,
|
src_cache=kv_c_and_k_pe_cache,
|
||||||
dst=workspace,
|
dst=workspace,
|
||||||
@@ -2256,13 +2340,36 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
scale=k_scale,
|
scale=k_scale,
|
||||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# FP8 path: gather cache without dequantization
|
||||||
|
ops.cp_gather_cache(
|
||||||
|
src_cache=kv_c_and_k_pe_cache,
|
||||||
|
dst=workspace,
|
||||||
|
block_table=prefill_metadata.block_table,
|
||||||
|
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
|
||||||
|
batch_size=attn_metadata.num_prefills,
|
||||||
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract kv_c_normed from workspace
|
||||||
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
|
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
|
||||||
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
|
# When FP8 weights are used without FP8 prefill, kv_b_proj expects
|
||||||
|
# model dtype input and will quantize internally.
|
||||||
|
if (
|
||||||
|
use_fp8_prefill
|
||||||
|
or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
|
||||||
|
):
|
||||||
|
kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
|
||||||
|
|
||||||
|
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
|
||||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Use the epilogue of kv_b_proj to generate FP8 kv_nope directly.
|
||||||
|
if use_fp8_prefill:
|
||||||
|
kv_nope = kv_nope.to(prefill_metadata.q_data_type)
|
||||||
|
k_pe = k_pe.to(prefill_metadata.q_data_type)
|
||||||
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
|
||||||
k = self._concat_k_nope_k_pe(k_nope, k_pe)
|
k = self._concat_k_nope_k_pe(k_nope, k_pe)
|
||||||
@@ -2412,16 +2519,27 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
assert self.dcp_world_size != -1
|
assert self.dcp_world_size != -1
|
||||||
|
|
||||||
has_context = attn_metadata.prefill.chunked_context is not None
|
prefill_metadata = attn_metadata.prefill
|
||||||
|
use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
# Convert q to FP8 if FP8 prefill attention is enabled
|
||||||
|
if use_fp8_prefill:
|
||||||
|
q = q.to(prefill_metadata.q_data_type)
|
||||||
|
|
||||||
|
has_context = prefill_metadata.chunked_context is not None
|
||||||
|
|
||||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
|
||||||
)
|
)
|
||||||
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
|
||||||
k = self._concat_k_nope_k_pe(k_nope, k_pe)
|
k = self._concat_k_nope_k_pe(k_nope, k_pe)
|
||||||
|
|
||||||
|
if use_fp8_prefill:
|
||||||
|
k = k.to(prefill_metadata.q_data_type)
|
||||||
|
v = v.to(prefill_metadata.q_data_type)
|
||||||
|
|
||||||
output_prefill = self._run_prefill_new_tokens(
|
output_prefill = self._run_prefill_new_tokens(
|
||||||
prefill=attn_metadata.prefill,
|
prefill=prefill_metadata,
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
|
|||||||
Reference in New Issue
Block a user