[Performance][ROCm] Add skinny gemms for unquantized linear on ROCm (#15830)
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform

@@ -17,6 +18,7 @@ TORCH_DEVICE_IDENTITY = None
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm()
                               and torch.__version__[0:3] >= "2.7"
                               and current_platform.has_device_capability(94))

@@ -131,6 +133,159 @@ def maybe_create_device_identity():
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
                           out_dtype: torch.dtype, scale_a: torch.Tensor,
                           scale_b: torch.Tensor, bias: torch.Tensor,
                           output_shape: List, **kwargs) -> torch.Tensor:

    # Fused GEMM_DQ
    output = ops.cutlass_scaled_mm(qinput,
                                   weight,
                                   out_dtype=out_dtype,
                                   scale_a=scale_a,
                                   scale_b=scale_b,
                                   bias=bias)
    return output.view(*output_shape)


def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor, bias: torch.Tensor,
                                   input_2d: torch.Tensor,
                                   output_shape: List) -> torch.Tensor:
    if envs.VLLM_ROCM_USE_SKINNY_GEMM and qinput.shape[
            0] == 1 and qinput.shape[1] % 16 == 0:
        output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
                               current_platform.get_cu_count())
    else:
        output = torch._scaled_mm(qinput,
                                  weight,
                                  out_dtype=out_dtype,
                                  scale_a=scale_a,
                                  scale_b=scale_b,
                                  bias=bias)

    return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
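Aside (illustrative, not part of the commit): the ROCm branch above only takes the wvSplitKQ skinny-GEMM path for a single-row quantized input whose K dimension is a multiple of 16. Below is a minimal sketch of that shape gate; takes_skinny_path is a hypothetical helper standing in for the inline condition.

import torch

def takes_skinny_path(qinput: torch.Tensor, use_skinny_gemm: bool) -> bool:
    # Mirrors the condition above: the skinny kernel is used only when the
    # env flag is on, there is exactly one activation row (e.g. one decode
    # token), and the hidden/K dimension is divisible by 16.
    return (use_skinny_gemm and qinput.shape[0] == 1
            and qinput.shape[1] % 16 == 0)

print(takes_skinny_path(torch.empty(1, 4096), True))   # True  -> ops.wvSplitKQ
print(takes_skinny_path(torch.empty(8, 4096), True))   # False -> torch._scaled_mm
print(takes_skinny_path(torch.empty(1, 4100), True))   # False -> K % 16 != 0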


def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    weight: torch.Tensor,
                                    out_dtype: torch.dtype,
                                    scale_a: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
                                    input_2d: torch.Tensor,
                                    output_shape: List) -> torch.Tensor:
    output = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=scale_a,
                              scale_b=scale_b,
                              bias=bias)
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(output) is tuple and len(output) == 2:
        output = output[0]

    return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)


def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                   weight: torch.Tensor,
                                   out_dtype: torch.dtype,
                                   scale_a: torch.Tensor,
                                   scale_b: torch.Tensor, bias: torch.Tensor,
                                   input_2d: torch.Tensor,
                                   output_shape: List) -> torch.Tensor:
    # Note: callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
    # before using it.
    # For now it has only been validated on the ROCm platform.
    # fp8 rowwise scaling in torch._scaled_mm was introduced in
    # https://github.com/pytorch/pytorch/pull/144432 using
    # hipBLASLt and ROCm 6.3, and only exists in torch 2.7 and above.
    #
    # For the CUDA platform, please validate that torch._scaled_mm supports
    # rowwise scaled GEMM before using it.

    # Fused GEMM_DQ Rowwise GEMM
    output = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=scale_a,
                              scale_b=scale_b.t(),
                              bias=bias)

    output = torch.narrow(output, 0, 0, input_2d.shape[0])
    output = output.view(*output_shape)
    return output
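Aside (illustrative, not part of the commit): the rowwise path above pairs per-token activation scales with per-channel weight scales. The shapes below are an assumption about those layouts (one scale per row of qinput, one per output channel of the weight), with scale_b.t() producing the row-vector layout handed to torch._scaled_mm.

import torch

M, K, N = 4, 64, 32
x_scale = torch.rand(M, 1)        # assumed per-token scales: one per activation row
weight_scale = torch.rand(N, 1)   # assumed per-channel scales: one per output column
print(x_scale.shape)              # torch.Size([4, 1])
print(weight_scale.t().shape)     # torch.Size([1, 32]) -- layout passed as scale_b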


def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                     weight: torch.Tensor,
                                     out_dtype: torch.dtype,
                                     scale_a: torch.Tensor,
                                     scale_b: torch.Tensor, bias: torch.Tensor,
                                     input_2d: torch.Tensor,
                                     output_shape: List,
                                     **kwargs) -> torch.Tensor:
    # Use unfused DQ due to limitations with scaled_mm

    # Symmetric quantized GEMM by definition computes the following:
    #   C = (s_x * X) (s_w * W) + bias
    # This is equivalent to dequantizing the weights and activations
    # before applying a GEMM.
    #
    # In order to compute quantized operands, a quantized kernel
    # will rewrite the above like so:
    #   C = s_w * s_x * (X * W) + bias
    #
    # For the scaled_mm fallback case, we break this down, since it
    # does not support s_w being a vector.

    # GEMM
    # This computes C = (X * W).
    # Output in fp32 to allow subsequent ops to happen in-place
    output = torch._scaled_mm(qinput,
                              weight,
                              scale_a=TORCH_DEVICE_IDENTITY,
                              scale_b=TORCH_DEVICE_IDENTITY,
                              out_dtype=torch.float32)
    # A fix for discrepancy in scaled_mm which returns tuple
    # for torch < 2.5 and a single value in torch >= 2.5
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    # Unpad (undo num_token_padding)
    output = torch.narrow(output, 0, 0, input_2d.shape[0])
    x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0])

    # DQ
    # C = sw * sx * (X * W) + bias
    output = output * x_scale * scale_b.t()
    if bias is not None:
        output = output + bias
    return output.to(out_dtype).view(*output_shape)
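Aside (illustrative, not part of the commit): a quick numeric check of the identity the comment above relies on, (s_x * X)(s_w * W) = s_x * s_w * (X W), done here with plain fp32 tensors rather than fp8 operands.

import torch

torch.manual_seed(0)
M, K, N = 4, 8, 3
X = torch.randn(M, K)
W = torch.randn(K, N)
s_x = torch.rand(M, 1)   # per-token activation scales
s_w = torch.rand(1, N)   # per-output-channel weight scales

# Scaling the operands before the GEMM ...
lhs = (s_x * X) @ (W * s_w)
# ... equals scaling the unscaled GEMM result afterwards.
rhs = (X @ W) * s_x * s_w
print(torch.allclose(lhs, rhs, atol=1e-5))  # True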


def dispatch_w8a8_scaled_mm(
        cutlass_fp8_supported: bool, per_tensor_weights: bool,
        per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool]
) -> Callable[..., torch.Tensor]:

    if cutlass_fp8_supported:
        return cutlass_w8a8_scaled_mm
    if per_tensor_weights and per_tensor_activations:
        if current_platform.is_rocm():
            return rocm_per_tensor_w8a8_scaled_mm
        return torch_per_tensor_w8a8_scaled_mm
    # torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    if (use_per_token_if_dynamic and not per_tensor_weights
            and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM):
        return torch_per_token_w8a8_scaled_mm
    return torch_channelwise_w8a8_scaled_mm
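Aside (illustrative, not part of the commit): dispatch_w8a8_scaled_mm above picks an implementation in priority order — cutlass first, then the per-tensor path (with a ROCm-specific variant), then rowwise torch._scaled_mm when available, and finally the unfused channelwise fallback. Below is a small standalone mirror of that decision table; pick_w8a8_impl is hypothetical and returns names instead of callables, with platform flags passed in explicitly so it runs without vLLM.

def pick_w8a8_impl(cutlass_fp8_supported: bool, per_tensor_weights: bool,
                   per_tensor_activations: bool, use_per_token_if_dynamic: bool,
                   is_rocm: bool, use_rowwise_torch_scaled_mm: bool) -> str:
    # Same precedence as dispatch_w8a8_scaled_mm above.
    if cutlass_fp8_supported:
        return "cutlass_w8a8_scaled_mm"
    if per_tensor_weights and per_tensor_activations:
        return ("rocm_per_tensor_w8a8_scaled_mm"
                if is_rocm else "torch_per_tensor_w8a8_scaled_mm")
    if (use_per_token_if_dynamic and not per_tensor_weights
            and not per_tensor_activations and use_rowwise_torch_scaled_mm):
        return "torch_per_token_w8a8_scaled_mm"
    return "torch_channelwise_w8a8_scaled_mm"

print(pick_w8a8_impl(False, True, True, False, True, False))
# -> rocm_per_tensor_w8a8_scaled_mm (per-tensor scales on ROCm, no cutlass)
print(pick_w8a8_impl(False, False, False, True, True, True))
# -> torch_per_token_w8a8_scaled_mm (dynamic per-token scales, rowwise available)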


# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:

@@ -156,7 +311,8 @@ class Fp8LinearOp:
        if pad_output is None:
            config = get_current_vllm_config().compilation_config
            pad_output = config.level < CompilationLevel.PIECEWISE
        self.output_padding = 17 if pad_output else None
        self.output_padding = 17 if (
            pad_output and not current_platform.is_rocm()) else None

    def apply(
        self,

@@ -195,18 +351,6 @@ class Fp8LinearOp:
                input_scale,
                scale_ub=input_scale_ub,
                use_per_token_if_dynamic=use_per_token_if_dynamic)

            # Fused GEMM_DQ
            output = ops.cutlass_scaled_mm(qinput,
                                           weight,
                                           out_dtype=out_dtype,
                                           scale_a=x_scale,
                                           scale_b=weight_scale,
                                           bias=bias)
            return output.view(*output_shape)

        # torch.scaled_mm supports per tensor weights + activations only
        # so fallback to naive if per channel or per token
        else:
            if input.dtype != current_platform.fp8_dtype():
                # Maybe apply padding to output, see comment in __init__
@@ -218,84 +362,21 @@ class Fp8LinearOp:
            else:
                qinput, x_scale = input_2d, input_scale

            per_tensor_weights = (weight_scale.numel() == 1)
            per_tensor_activations = (x_scale.numel() == 1)
        per_tensor_weights = (weight_scale.numel() == 1)
        per_tensor_activations = (x_scale.numel() == 1)

            if per_tensor_weights and per_tensor_activations:
                # Fused GEMM_DQ
                output = torch._scaled_mm(qinput,
                                          weight,
                                          out_dtype=out_dtype,
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          bias=bias)
                # A fix for discrepancy in scaled_mm which returns tuple
                # for torch < 2.5 and a single value in torch >= 2.5
                if type(output) is tuple and len(output) == 2:
                    output = output[0]
        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
            self.cutlass_fp8_supported, per_tensor_weights,
            per_tensor_activations, use_per_token_if_dynamic)

                return torch.narrow(output, 0, 0,
                                    input_2d.shape[0]).view(*output_shape)

            elif (use_per_token_if_dynamic and not per_tensor_weights
                  and not per_tensor_activations
                  and USE_ROWWISE_TORCH_SCALED_MM):
                # For now validated on ROCm platform
                # fp8 rowwise scaling in torch._scaled_mm is introduced in
                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
                # and ROCm 6.3, which only exists in torch 2.7 and above.
                # For CUDA platform please validate if the
                # torch._scaled_mm support rowwise scaled GEMM
                # Fused GEMM_DQ Rowwise GEMM
                output = torch._scaled_mm(qinput,
                                          weight,
                                          out_dtype=out_dtype,
                                          scale_a=x_scale,
                                          scale_b=weight_scale.t(),
                                          bias=bias)

                output = torch.narrow(output, 0, 0, input_2d.shape[0])
                output = output.view(*output_shape)
                return output

            else:
                # Fallback for channelwise case, where we use unfused DQ
                # due to limitations with scaled_mm

                # Symmetric quantized GEMM by definition computes the following:
                # C = (s_x * X) (s_w * W) + bias
                # This is equivalent to dequantizing the weights and activations
                # before applying a GEMM.
                #
                # In order to compute quantized operands, a quantized kernel
                # will rewrite the above like so:
                # C = s_w * s_x * (X * W) + bias
                #
                # For the scaled_mm fallback case, we break this down, since it
                # does not support s_w being a vector.

                # GEMM
                # This computes C = (X * W).
                # Output in fp32 to allow subsequent ops to happen in-place
                output = torch._scaled_mm(qinput,
                                          weight,
                                          scale_a=TORCH_DEVICE_IDENTITY,
                                          scale_b=TORCH_DEVICE_IDENTITY,
                                          out_dtype=torch.float32)
                # A fix for discrepancy in scaled_mm which returns tuple
                # for torch < 2.5 and a single value in torch >= 2.5
                if type(output) is tuple and len(output) == 2:
                    output = output[0]
                # Unpad (undo num_token_padding)
                output = torch.narrow(output, 0, 0, input_2d.shape[0])
                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

                # DQ
                # C = sw * sx * (X * W) + bias
                output = output * x_scale * weight_scale.t()
                if bias is not None:
                    output = output + bias
                return output.to(dtype=input.dtype).view(*output_shape)
        return w8a8_scaled_mm_func(qinput=qinput,
                                   weight=weight,
                                   out_dtype=input.dtype,
                                   scale_a=x_scale,
                                   scale_b=weight_scale,
                                   bias=bias,
                                   input_2d=input_2d,
                                   output_shape=output_shape)


def normalize_e4m3fn_to_e4m3fnuz(
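Aside (illustrative, not part of the commit): the skinny-GEMM branch added above is gated by vLLM's VLLM_ROCM_USE_SKINNY_GEMM environment variable (read through vllm.envs, as referenced in the diff). The exact default and parsing are not shown here, so the snippet below is only an assumed way to opt out before vLLM reads the flag.

import os

# Assumed usage: disable the ROCm skinny-GEMM kernels (wvSplitKQ above) by
# setting the flag before vLLM is imported / the engine is constructed.
os.environ["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"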