[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Gregory Shtrasberg
Date: 2025-09-10 16:59:55 -04:00
Committed by: GitHub
Commit: 9a161307f5 (parent 37e8182bfe)

8 changed files with 249 additions and 135 deletions
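
The fusion only fires when the attention-fusion pass is enabled in the compilation config. Below is a minimal sketch of turning it on, assuming the PassConfig flag spellings (enable_attn_fusion, enable_noop) in the reader's vLLM version match:

# Hedged sketch: serve an FP8 model with the attention-output fusion pass on.
from vllm import LLM
from vllm.config import CompilationConfig, PassConfig

llm = LLM(
    model="my-fp8-checkpoint",  # hypothetical FP8-quantized model
    compilation_config=CompilationConfig(
        pass_config=PassConfig(
            enable_attn_fusion=True,  # fuse attention output -> FP8 quant
            enable_noop=True,  # noop elimination helps the pattern match
        ),
    ),
)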

vllm/v1/attention/backends/triton_attn.py

@@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import (
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey, kFp8StaticTensorSym)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (AttentionCGSupport,
@@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool:
 class TritonAttentionImpl(AttentionImpl):
 
+    def fused_output_quant_supported(self, quant_key: QuantKey):
+        return quant_key == kFp8StaticTensorSym
+
     def __init__(
         self,
         num_heads: int,
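
The fused_output_quant_supported hook added above is what a compile-time fusion pass can query before folding a following quant op into the attention call. An illustrative sketch of that handshake follows; only the hook name and kFp8StaticTensorSym come from this diff, the rest is assumed:

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kFp8StaticTensorSym)

def try_fuse_output_quant(impl, quant_key: QuantKey) -> bool:
    # TritonAttentionImpl answers True only for static, per-tensor,
    # symmetric FP8; any other scheme stays as a separate quant op.
    if impl.fused_output_quant_supported(quant_key):
        # Rewrite step (sketch): delete the standalone quant node and
        # hand its scale to the attention op as output_scale.
        return True
    return False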
@@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."
 
-        if output_scale is not None or output_block_scale is not None:
+        if output_block_scale is not None:
             raise NotImplementedError(
-                "fused output quantization is not yet supported"
+                "fused block_scale output quantization is not yet supported"
                 " for TritonAttentionImpl")
 
         if attn_metadata is None:
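
For reference, the op being folded into the kernel computes static per-tensor symmetric FP8 quantization. A sketch of the unfused equivalent (the function name is illustrative; on ROCm MI300 the FP8 dtype is torch.float8_e4m3fnuz rather than torch.float8_e4m3fn):

import torch

def quant_fp8_static(out: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # One precomputed scale for the whole tensor (static, per-tensor) and
    # no zero point (symmetric). Fusion moves this math into the attention
    # kernel's epilogue, saving a full extra pass over the output in memory.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (out / scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn)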
@@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl):
                 alibi_slopes=self.alibi_slopes,
                 sliding_window=self.sliding_window[0],
                 sm_scale=self.scale,
+                output_scale=output_scale,
                 sinks=self.sinks,
             )
@@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl):
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
                 sinks=self.sinks,
-            )
+                output_scale=output_scale)
 
         return output
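
With output_scale threaded through both call sites, the fused contract is: the caller allocates the output buffer in FP8 and the kernel quantizes as it stores. A hedged caller-side sketch (static_scale and the buffer handling are illustrative; current_platform.fp8_dtype() is vLLM's helper for picking the platform FP8 dtype):

fp8 = current_platform.fp8_dtype()  # float8_e4m3fn; e4m3fnuz on MI300 ROCm
output = torch.empty(query.shape, dtype=fp8, device=query.device)
impl.forward(layer, query, key, value, kv_cache, attn_metadata,
             output=output,
             output_scale=static_scale)  # scale from the absorbed quant op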