[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
37e8182bfe
commit
9a161307f5
@@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
|
||||
bias=bias)
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor) -> torch.Tensor:
|
||||
def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor) -> torch.Tensor:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
|
||||
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
|
||||
@@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
return output
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_fake(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor) -> torch.Tensor:
|
||||
def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
bias: torch.Tensor) -> torch.Tensor:
|
||||
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
|
||||
dtype=out_dtype)
|
||||
|
||||
@@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list) -> torch.Tensor:
|
||||
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
|
||||
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
||||
qinput, weight, out_dtype, scale_a, scale_b, bias)
|
||||
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
@@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list) -> torch.Tensor:
|
||||
output = torch._scaled_mm(qinput,
|
||||
weight,
|
||||
@@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
||||
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
@@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor, output_shape: list,
|
||||
output_shape: list,
|
||||
**kwargs) -> torch.Tensor:
|
||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
||||
# when using it.
|
||||
@@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
scale_b=scale_b.t(),
|
||||
bias=bias)
|
||||
|
||||
output = torch.narrow(output, 0, 0, input_2d.shape[0])
|
||||
output = torch.narrow(output, 0, 0, qinput.shape[0])
|
||||
output = output.view(*output_shape)
|
||||
return output
|
||||
|
||||
@@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list,
|
||||
**kwargs) -> torch.Tensor:
|
||||
# Use unfused DQ due to limitations with scaled_mm
|
||||
@@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
# Unpad (undo num_token_padding)
|
||||
output = torch.narrow(output, 0, 0, input_2d.shape[0])
|
||||
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])
|
||||
output = torch.narrow(output, 0, 0, qinput.shape[0])
|
||||
x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
|
||||
|
||||
# DQ
|
||||
# C = sw * sx * (X * W) + bias
|
||||
@@ -430,7 +431,6 @@ class Fp8LinearOp:
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
input_2d=input_2d,
|
||||
output_shape=output_shape)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user