[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: elvischenv
Committed by: GitHub
Date: 2025-08-23 06:09:05 +08:00
Parent: cc7ae5e7ca
Commit: 24d0c9e6ed
27 changed files with 596 additions and 200 deletions
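Every attention backend's AttentionImpl.forward() gains an optional output_block_scale tensor alongside output_scale; only the FlashInfer TRT-LLM path consumes it, while the remaining backends reject it exactly as they already reject output_scale. A minimal sketch of the shared signature and guard, with backend-specific names left as placeholders:

    from typing import Optional

    import torch

    def forward(
        self,
        layer,                    # owning Attention layer (placeholder)
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,            # backend-specific metadata (placeholder)
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,  # added by this commit
    ) -> torch.Tensor:
        # Backends without fused output quantization reject both scale arguments.
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for this backend")
        ...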

View File

@@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -497,7 +498,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl")

View File

@@ -430,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -447,7 +448,7 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
+ from flashinfer.utils import FP4Tensor
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
- GroupShape)
+ QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.platforms import current_platform
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention,
@@ -40,6 +41,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
FP8_DTYPE = current_platform.fp8_dtype()
+ FP4_DTYPE = torch.uint8
logger = init_logger(__name__)
@@ -653,14 +655,12 @@ class FlashInferImpl(AttentionImpl):
and num_heads % num_kv_heads == 0)
self.bmm1_scale: Optional[float] = None
self.bmm2_scale: Optional[float] = None
+ self.o_sf_scale: Optional[float] = None
- def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
- group_shape: GroupShape):
- supported_quant_type = (dtype == FP8_DTYPE and static
- and group_shape == GroupShape.PER_TENSOR)
+ def fused_output_quant_supported(self, quant_key: QuantKey):
return (self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
- and supported_quant_type)
+ and quant_key in (kFp8StaticTensorSym, kNvfp4Quant))
def forward(
self,
@@ -672,6 +672,7 @@ class FlashInferImpl(AttentionImpl):
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashInfer.
@@ -705,19 +706,32 @@ class FlashInferImpl(AttentionImpl):
if output_scale is None:
assert attn_metadata.q_data_type != FP8_DTYPE, \
"Query can only be FP8 if output fusion happened."
+ assert output_block_scale is None, "output_block_scale "\
+ "is not supported when fusion has not happened"
else:
assert attn_metadata.q_data_type == FP8_DTYPE, \
"Query must be FP8 when attn+quant fusion happened."
assert (attn_metadata.prefill_use_trtllm and
attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
- assert output.dtype == FP8_DTYPE, \
- "Output must be FP8 when attn+quant fusion happened."
- # TRTLLM attn kernel requires o scale as a host scalar, store the
- # o scale to host scalar in warmup run with cuda graph not enabled
+ if output.dtype == FP8_DTYPE:
+ assert output_block_scale is None, \
+ "output_block_scale should not be provided for fp8 output"
+ elif output.dtype == FP4_DTYPE:
+ assert output_block_scale is not None, \
+ "output_block_scale is required for nvfp4 output"
+ else:
+ raise ValueError(f"Unsupported output dtype: {output.dtype}")
+ # TRTLLM attn kernel requires o scale to pass as a host scalar,
+ # store the o scale as a host scalar in warmup run with cuda graph
+ # not enabled
if layer._o_scale_float is None:
layer._o_scale_float = output_scale.cpu().item()
- self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
+ if output.dtype == FP8_DTYPE:
+ self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
+ elif output.dtype == FP4_DTYPE:
+ self.o_sf_scale = layer._o_scale_float
# Insert FP8 quant for query
num_tokens, num_heads, head_size = query.shape
@@ -818,6 +832,16 @@ class FlashInferImpl(AttentionImpl):
assert block_tables_prefill.is_contiguous()
assert seq_lens_prefill.is_contiguous()
+ if output.dtype == FP4_DTYPE:
+ assert self.o_sf_scale is not None
+ out = FP4Tensor(data=output[num_decode_tokens:],
+ scale=output_block_scale,
+ scale_start_index=num_decode_tokens,
+ original_shape=prefill_query.shape)
+ else:
+ assert self.o_sf_scale is None
+ out = output[num_decode_tokens:]
trtllm_batch_context_with_kv_cache(
query=prefill_query,
kv_cache=kv_cache_permute,
@@ -833,7 +857,8 @@ class FlashInferImpl(AttentionImpl):
cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
window_left=self.window_left,
sinks=self.sinks,
- out=output[num_decode_tokens:],
+ o_sf_scale=self.o_sf_scale,
+ out=out,
)
if num_decode_tokens > 0:
@@ -870,6 +895,16 @@ class FlashInferImpl(AttentionImpl):
assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous()
+ if output.dtype == FP4_DTYPE:
+ assert self.o_sf_scale is not None
+ out = FP4Tensor(data=output[:num_decode_tokens],
+ scale=output_block_scale,
+ scale_start_index=0,
+ original_shape=decode_query.shape)
+ else:
+ assert self.o_sf_scale is None
+ out = output[:num_decode_tokens]
trtllm_batch_decode_with_kv_cache(
query=decode_query,
kv_cache=kv_cache_permute,
@@ -881,7 +916,8 @@ class FlashInferImpl(AttentionImpl):
bmm2_scale=self.bmm2_scale,
window_left=self.window_left,
sinks=self.sinks,
- out=output[:num_decode_tokens],
+ o_sf_scale=self.o_sf_scale,
+ out=out,
)
return output_padded
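To summarize the FlashInfer path above: the per-tensor output scale is read once on the host during warmup and is either folded into bmm2_scale (FP8 output) or forwarded as o_sf_scale (NVFP4 output), while the per-block scale factors land in output_block_scale through FlashInfer's FP4Tensor view. A condensed sketch of the decode-side handling, reusing the diff's variable names (illustrative only, not a drop-in snippet):

    # Condensed from FlashInferImpl.forward() above; buffers such as output,
    # output_scale and output_block_scale are assumed to come from the fusion pass.
    if layer._o_scale_float is None:
        layer._o_scale_float = output_scale.cpu().item()
        if output.dtype == FP8_DTYPE:
            # FP8 output: fold 1/o_scale into the second-GEMM scale.
            self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
        elif output.dtype == FP4_DTYPE:
            # NVFP4 output: the TRT-LLM kernel takes the scale separately.
            self.o_sf_scale = layer._o_scale_float

    if output.dtype == FP4_DTYPE:
        # Packed FP4 payload plus per-block scales, wrapped for the kernel.
        out = FP4Tensor(data=output[:num_decode_tokens],
                        scale=output_block_scale,
                        scale_start_index=0,
                        original_shape=decode_query.shape)
    else:
        out = output[:num_decode_tokens]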

View File

@@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata: FlexAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FLexAttention.
@@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlexAttentionImpl")

View File

@@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLACommonImpl")

View File

@@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionBackendImpl")

View File

@@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
attn_metadata: AiterFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with AiterFlashAttention.
@@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl):
attn_metadata: TreeAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with TreeAttention.
@@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TreeAttentionImpl")

View File

@@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TritonAttentionImpl")

View File

@@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl):
attn_metadata: XFormersAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
+ output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with XFormers.
@@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
- if output_scale is not None:
+ if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersAttentionImpl")