[torch.compile][ROCm] Fuse quantization onto attention using a torch.compile pass (#16756)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
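Every touched backend adopts the same two-part pattern: forward() gains an optional output_scale tensor, and any backend that does not yet implement a fused quantization epilogue rejects a non-None scale with NotImplementedError rather than silently returning an unquantized tensor. A minimal sketch of the shared shape (the class name and the parameters ahead of kv_cache are illustrative, following the usual AttentionImpl.forward signature; they are not part of the hunks below):

from typing import Optional

import torch


class ExampleAttentionImpl:
    """Illustrative stand-in for the backends touched in this commit."""

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: Optional[object],
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,  # new in this commit
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."

        # Fail fast: a caller asking for fused output quantization must
        # not accidentally get a full-precision tensor back.
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for ExampleAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # ... backend-specific attention kernel writes into `output` ...
        return output

The Pallas backend below is the one exception to the assert: it may allocate the output itself when the KV cache is empty.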
@@ -569,6 +569,7 @@ class FlashAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

@@ -586,6 +587,11 @@ class FlashAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashAttentionImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output
||||
@@ -547,6 +547,7 @@ class FlashInferImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashInfer.

@@ -561,6 +562,11 @@ class FlashInferImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlashInferImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output
@@ -414,6 +414,7 @@ class FlexAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlexAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlexAttention.

@@ -427,6 +428,12 @@ class FlexAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for FlexAttentionImpl")

         enable_gqa = self.num_kv_heads != self.num_heads

         if attn_metadata is None:
||||
@@ -865,10 +865,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         kv_cache: torch.Tensor,
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for MLACommonImpl")
+
         if attn_metadata is None:
             # The zero fill is required when used with DP + EP
             # to ensure all ranks within a DP group compute the
||||
@@ -161,6 +161,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: PallasMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.

@@ -173,6 +174,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for PallasAttentionBackendImpl")
+
         # For determine_available_memory case.
         if kv_cache.numel() == 0:
             if output is None:
||||
@@ -142,6 +142,7 @@ class TritonAttentionImpl(AttentionImpl):
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

@@ -156,6 +157,11 @@ class TritonAttentionImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for TritonAttentionImpl")
+
         if attn_metadata is None:
             # Profiling run.
             return output
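For context on what the torch.compile pass in the title actually buys: without fusion, the attention kernel writes a full-precision output to HBM and a separate kernel reads it back to quantize it; with fusion, the scale is handed to the attention kernel via output_scale and quantization happens in the kernel epilogue, saving one memory round trip. A self-contained sketch of the two graph shapes the pass rewrites between (quantize_fp8 is an illustrative static per-tensor fp8 quantization, not vLLM's actual op):

import torch
import torch.nn.functional as F

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max


def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Static per-tensor fp8 quantization: the op the pass fuses away."""
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)


def attn_then_quant(q, k, v, scale):
    # Pattern the pass matches: attention output immediately quantized
    # by a separate kernel.
    out = F.scaled_dot_product_attention(q, k, v)
    return quantize_fp8(out, scale)


def attn_fused(q, k, v, scale):
    # Replacement the pass emits, conceptually: the backend receives
    # output_scale and applies the epilogue inside its kernel. Emulated
    # here in plain PyTorch for illustration only.
    out = F.scaled_dot_product_attention(q, k, v)
    return quantize_fp8(out, scale)

Note that every hunk shown above is the opt-out side of this change: the backends that do implement the fused epilogue (the ROCm path named in the title) live elsewhere in the commit.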