[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
commit 9a161307f5 (parent 37e8182bfe), committed via GitHub
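What the fusion buys, as a minimal sketch (not part of this diff; names and shapes are illustrative, and quantize_fp8_static is a hypothetical helper, not vLLM API): without fusion, the attention kernel writes a high-precision output and a separate op applies static per-tensor FP8 quantization; with fusion, the kernel receives the static output_scale and writes FP8 directly, skipping one full round trip of the output through memory.

import torch

def quantize_fp8_static(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static per-tensor FP8 quantization: divide by the scale, saturate to
    # the representable range of float8_e4m3fn, then downcast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

# Unfused: attention writes bf16, then a second kernel quantizes it.
attn_out = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in for attention output
output_scale = torch.tensor(0.5)
fp8_out = quantize_fp8_static(attn_out.float(), output_scale)

# Fused: the attention kernel applies the same scale-and-cast in its
# epilogue, which is what plumbing output_scale down to the Triton kernel
# in the hunks below makes possible.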
@@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import (
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey, kFp8StaticTensorSym)
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (AttentionCGSupport,
@@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool:
 
 class TritonAttentionImpl(AttentionImpl):
 
+    def fused_output_quant_supported(self, quant_key: QuantKey):
+        return quant_key == kFp8StaticTensorSym
+
     def __init__(
         self,
         num_heads: int,
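How a compile-time pass might consult this new hook, as a hedged sketch; the actual torch.compile fusion pass in vLLM is more involved, and maybe_fuse_output_quant is a hypothetical helper, not vLLM API:

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kFp8StaticTensorSym)

def maybe_fuse_output_quant(attn_impl, quant_key: QuantKey) -> bool:
    # Rewrite attention + quant into one fused op only when the backend
    # advertises support for this exact scheme; per this commit,
    # TritonAttentionImpl accepts only static per-tensor symmetric FP8
    # (kFp8StaticTensorSym).
    hook = getattr(attn_impl, "fused_output_quant_supported", None)
    return hook is not None and hook(quant_key)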
@@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."
 
-        if output_scale is not None or output_block_scale is not None:
+        if output_block_scale is not None:
             raise NotImplementedError(
-                "fused output quantization is not yet supported"
+                "fused block_scale output quantization is not yet supported"
                 " for TritonAttentionImpl")
 
         if attn_metadata is None:
@@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl):
             alibi_slopes=self.alibi_slopes,
             sliding_window=self.sliding_window[0],
             sm_scale=self.scale,
+            output_scale=output_scale,
             sinks=self.sinks,
         )
 
@@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl):
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
                 sinks=self.sinks,
-            )
+                output_scale=output_scale)
 
         return output
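For intuition, the kind of epilogue the attention kernel gains once output_scale reaches it: a standalone hedged sketch, not the real vLLM kernel. FP8 element type names vary by Triton version and GPU vendor (tl.float8e4nv is the NVIDIA E4M3 type; some AMD parts use a different FP8 flavor), so treat the dtype choice as an assumption.

import triton
import triton.language as tl

@triton.jit
def attn_epilogue_fp8(x_ptr, out_ptr, scale_ptr, n_elements,
                      BLOCK: tl.constexpr):
    # Minimal scale-then-store epilogue: divide by the static per-tensor
    # scale, saturate to the FP8 E4M3 range, and store FP8 directly
    # instead of writing bf16 and quantizing in a second kernel.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    scale = tl.load(scale_ptr)  # single static scale for the whole tensor
    y = tl.minimum(tl.maximum(x / scale, -448.0), 448.0)
    tl.store(out_ptr + offs, y.to(tl.float8e4nv), mask=mask)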