[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Gregory Shtrasberg
2025-09-10 16:59:55 -04:00
committed by GitHub
parent 37e8182bfe
commit 9a161307f5
8 changed files with 249 additions and 135 deletions

View File

@@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
bias=bias)
def rocm_per_tensor_w8a8_scaled_mm_impl(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor) -> torch.Tensor:
def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
return output
def rocm_per_tensor_w8a8_scaled_mm_fake(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor) -> torch.Tensor:
def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
dtype=out_dtype)
@@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
qinput, weight, out_dtype, scale_a, scale_b, bias)
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
direct_register_custom_op(
@@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
output = torch._scaled_mm(qinput,
weight,
@@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
@@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor, output_shape: list,
output_shape: list,
**kwargs) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
@@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b=scale_b.t(),
bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
output = torch.narrow(output, 0, 0, qinput.shape[0])
output = output.view(*output_shape)
return output
@@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list,
**kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
@@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
output = torch.narrow(output, 0, 0, qinput.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
@@ -430,7 +431,6 @@ class Fp8LinearOp:
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
input_2d=input_2d,
output_shape=output_shape)