[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -483,6 +483,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         attn_metadata: TorchSDPAMetadata,  # type: ignore
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.

@@ -497,7 +498,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for TorchSDPABackendImpl")

@@ -430,6 +430,7 @@ class FlashAttentionImpl(AttentionImpl):
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

@@ -447,7 +448,7 @@ class FlashAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for FlashAttentionImpl")

@@ -12,6 +12,7 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
+from flashinfer.utils import FP4Tensor

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -19,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape)
+    QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, is_pin_memory_available
 from vllm.utils.flashinfer import (supports_trtllm_attention,
@@ -40,6 +41,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

 FP8_DTYPE = current_platform.fp8_dtype()
+FP4_DTYPE = torch.uint8

 logger = init_logger(__name__)

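Note on the new constant: NVFP4 values are 4 bits wide, so the fused output is stored packed, two values per byte, in a plain uint8 tensor (hence FP4_DTYPE = torch.uint8), with a separate block-scale tensor carrying per-group scales. A minimal sketch of the assumed storage layout follows; the sizes and the 16-element group size are assumptions of this sketch, not taken from the diff.

import torch

# Hypothetical sizes, for illustration only.
num_tokens, hidden = 4, 128

# Packed NVFP4 payload: two 4-bit values per uint8 byte, so the last dim is halved.
packed_out = torch.empty(num_tokens, hidden // 2, dtype=torch.uint8)

# Block scales: assumed one fp8 (e4m3) scale per 16-element group.
block_scales = torch.empty(num_tokens, hidden // 16, dtype=torch.float8_e4m3fn)
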
@@ -653,14 +655,12 @@ class FlashInferImpl(AttentionImpl):
                 and num_heads % num_kv_heads == 0)
         self.bmm1_scale: Optional[float] = None
         self.bmm2_scale: Optional[float] = None
+        self.o_sf_scale: Optional[float] = None

-    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
-                                     group_shape: GroupShape):
-        supported_quant_type = (dtype == FP8_DTYPE and static
-                                and group_shape == GroupShape.PER_TENSOR)
+    def fused_output_quant_supported(self, quant_key: QuantKey):
         return (self.support_trtllm_attn
                 and self.kv_cache_dtype.startswith("fp8")
-                and supported_quant_type)
+                and quant_key in (kFp8StaticTensorSym, kNvfp4Quant))

     def forward(
         self,
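For context, a minimal sketch of how a caller might use the reworked hook; the helper below is illustrative and not the actual vLLM fusion-pass code. The backend now answers the question per QuantKey rather than per (dtype, static, group_shape) triple, accepting both per-tensor static FP8 and NVFP4 keys when TRT-LLM attention with an fp8 KV cache is in use.

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym, kNvfp4Quant)

def can_fuse_output_quant(attn_impl, quant_key) -> bool:
    """Illustrative caller-side check (hypothetical helper, not vLLM code)."""
    # Backends that do not expose the hook cannot fuse output quantization.
    if not hasattr(attn_impl, "fused_output_quant_supported"):
        return False
    # FlashInferImpl returns True for kFp8StaticTensorSym and kNvfp4Quant
    # when TRT-LLM attention is usable and the KV cache is fp8.
    return attn_impl.fused_output_quant_supported(quant_key)

# Example usage (assumed objects): can_fuse_output_quant(impl, kNvfp4Quant)
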
@@ -672,6 +672,7 @@ class FlashInferImpl(AttentionImpl):
         attn_metadata: FlashInferMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashInfer.

@@ -705,19 +706,32 @@ class FlashInferImpl(AttentionImpl):
         if output_scale is None:
             assert attn_metadata.q_data_type != FP8_DTYPE, \
                 "Query can only be FP8 if output fusion happened."
+            assert output_block_scale is None, "output_block_scale "\
+                "is not supported when fusion has not happened"
         else:
             assert attn_metadata.q_data_type == FP8_DTYPE, \
                 "Query must be FP8 when attn+quant fusion happened."
             assert (attn_metadata.prefill_use_trtllm and
                     attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
-            assert output.dtype == FP8_DTYPE, \
-                "Output must be FP8 when attn+quant fusion happened."

-            # TRTLLM attn kernel requires o scale as a host scalar, store the
-            # o scale to host scalar in warmup run with cuda graph not enabled
+            if output.dtype == FP8_DTYPE:
+                assert output_block_scale is None, \
+                    "output_block_scale should not be provided for fp8 output"
+            elif output.dtype == FP4_DTYPE:
+                assert output_block_scale is not None, \
+                    "output_block_scale is required for nvfp4 output"
+            else:
+                raise ValueError(f"Unsupported output dtype: {output.dtype}")
+
+            # TRTLLM attn kernel requires o scale to pass as a host scalar,
+            # store the o scale as a host scalar in warmup run with cuda graph
+            # not enabled
             if layer._o_scale_float is None:
                 layer._o_scale_float = output_scale.cpu().item()
-                self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
+                if output.dtype == FP8_DTYPE:
+                    self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
+                elif output.dtype == FP4_DTYPE:
+                    self.o_sf_scale = layer._o_scale_float

         # Insert FP8 quant for query
         num_tokens, num_heads, head_size = query.shape
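A compact sketch of the scale bookkeeping introduced above (illustrative only, not the vLLM implementation): for FP8 output the host-side output scale is folded into the second-GEMM scale, while for NVFP4 output it is forwarded unchanged to the kernel as o_sf_scale alongside the block-scale tensor.

def plan_output_scales(bmm2_scale: float, o_scale: float, out_dtype: str):
    """Hypothetical helper mirroring the dispatch above."""
    if out_dtype == "fp8":
        # Fused FP8 output: fold 1/o_scale into the bmm2 (PV) scale.
        return bmm2_scale / o_scale, None
    if out_dtype == "nvfp4":
        # Fused NVFP4 output: keep the bmm2 scale, pass o_scale as o_sf_scale.
        return bmm2_scale, o_scale
    raise ValueError(f"unsupported output dtype: {out_dtype}")

# e.g. plan_output_scales(0.5, 2.0, "fp8") -> (0.25, None)
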
@@ -818,6 +832,16 @@ class FlashInferImpl(AttentionImpl):
             assert block_tables_prefill.is_contiguous()
             assert seq_lens_prefill.is_contiguous()

+            if output.dtype == FP4_DTYPE:
+                assert self.o_sf_scale is not None
+                out = FP4Tensor(data=output[num_decode_tokens:],
+                                scale=output_block_scale,
+                                scale_start_index=num_decode_tokens,
+                                original_shape=prefill_query.shape)
+            else:
+                assert self.o_sf_scale is None
+                out = output[num_decode_tokens:]
+
             trtllm_batch_context_with_kv_cache(
                 query=prefill_query,
                 kv_cache=kv_cache_permute,
@@ -833,7 +857,8 @@ class FlashInferImpl(AttentionImpl):
                 cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
                 window_left=self.window_left,
                 sinks=self.sinks,
-                out=output[num_decode_tokens:],
+                o_sf_scale=self.o_sf_scale,
+                out=out,
             )

             if num_decode_tokens > 0:
@@ -870,6 +895,16 @@ class FlashInferImpl(AttentionImpl):
             assert block_tables_decode.is_contiguous()
             assert seq_lens_decode.is_contiguous()

+            if output.dtype == FP4_DTYPE:
+                assert self.o_sf_scale is not None
+                out = FP4Tensor(data=output[:num_decode_tokens],
+                                scale=output_block_scale,
+                                scale_start_index=0,
+                                original_shape=decode_query.shape)
+            else:
+                assert self.o_sf_scale is None
+                out = output[:num_decode_tokens]
+
             trtllm_batch_decode_with_kv_cache(
                 query=decode_query,
                 kv_cache=kv_cache_permute,
@@ -881,7 +916,8 @@ class FlashInferImpl(AttentionImpl):
                 bmm2_scale=self.bmm2_scale,
                 window_left=self.window_left,
                 sinks=self.sinks,
-                out=output[:num_decode_tokens],
+                o_sf_scale=self.o_sf_scale,
+                out=out,
             )

         return output_padded

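To make the FP4 path above easier to follow, here is a hedged sketch of how the shared output buffers are sliced and handed to the TRT-LLM kernels; the buffer shapes are assumptions of this sketch, and the helper is hypothetical. Both the prefill and decode calls write into views of the same packed uint8 output and the same block-scale tensor, with scale_start_index telling the kernel where this slice's scales begin.

from flashinfer.utils import FP4Tensor

def wrap_fp4_slice(output, output_block_scale, start, end, logical_shape):
    """Illustrative helper: wrap one token range of the shared buffers."""
    return FP4Tensor(
        data=output[start:end],          # packed fp4 payload for this slice
        scale=output_block_scale,        # shared block-scale buffer
        scale_start_index=start,         # where this slice's scales begin
        original_shape=logical_shape,    # logical (unpacked) query shape
    )

# Decode tokens occupy [0, num_decode_tokens); prefill tokens follow, e.g.:
#   wrap_fp4_slice(output, output_block_scale, 0, num_decode_tokens,
#                  decode_query.shape)
#   wrap_fp4_slice(output, output_block_scale, num_decode_tokens, None,
#                  prefill_query.shape)
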
@@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl):
         attn_metadata: FlexAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FLexAttention.

@@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
         assert output is not None, "Output tensor must be provided."
-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for FlexAttentionImpl")

@@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for MLACommonImpl")

@@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         attn_metadata: PallasMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.

@@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for PallasAttentionBackendImpl")

@@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         attn_metadata: AiterFlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with AiterFlashAttention.

@@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for FlashAttentionImpl")

@@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl):
         attn_metadata: TreeAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with TreeAttention.

@@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for TreeAttentionImpl")

@@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl):
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

@@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl):
         """
        assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for TritonAttentionImpl")

@@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl):
         attn_metadata: XFormersAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with XFormers.

@@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
+        if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
                 " for XFormersAttentionImpl")