[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -36,6 +36,7 @@ class QuantFP8(CustomOp):
|
||||
group_shape: GroupShape,
|
||||
num_token_padding: int | None = None,
|
||||
column_major_scales: bool = False,
|
||||
tma_aligned_scales: bool = False,
|
||||
use_ue8m0: bool | None = None, # for Torch compile
|
||||
):
|
||||
"""
|
||||
@@ -44,6 +45,8 @@ class QuantFP8(CustomOp):
|
||||
PER_CHANNEL, or arbitrary block size)
|
||||
:param num_token_padding: Pad the token dimension of output to this
|
||||
size
|
||||
:param tma_aligned_scales: For group quantization, output scales in
|
||||
TMA-aligned layout
|
||||
:param column_major_scales: For group quantization, output scales in
|
||||
column major format
|
||||
"""
|
||||
@@ -53,6 +56,7 @@ class QuantFP8(CustomOp):
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
self.num_token_padding = num_token_padding
|
||||
self.column_major_scales = column_major_scales
|
||||
self.tma_aligned_scales = tma_aligned_scales
|
||||
self.use_ue8m0 = use_ue8m0
|
||||
|
||||
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
@@ -82,6 +86,7 @@ class QuantFP8(CustomOp):
|
||||
x,
|
||||
group_size=self.group_size,
|
||||
column_major_scales=self.column_major_scales,
|
||||
tma_aligned_scales=self.tma_aligned_scales,
|
||||
dtype=_FP8_DTYPE,
|
||||
use_ue8m0=self.use_ue8m0,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
fp8_gemm_nt,
|
||||
get_tma_aligned_size,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
@@ -378,6 +379,7 @@ class W8A8BlockFp8LinearOp:
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
tma_aligned_scales=True,
|
||||
use_ue8m0=self.use_deep_gemm_e8m0,
|
||||
)
|
||||
if self.is_deep_gemm_supported
|
||||
@@ -868,6 +870,7 @@ def per_token_group_quant_fp8(
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype | None = None,
|
||||
column_major_scales: bool = False,
|
||||
tma_aligned_scales: bool = False,
|
||||
out_q: torch.Tensor | None = None,
|
||||
use_ue8m0: bool | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -878,9 +881,10 @@ def per_token_group_quant_fp8(
|
||||
x: The input tensor with ndim >= 2.
|
||||
group_size: The group size used for quantization.
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
column_major_scales: Outputs scales in column major.
|
||||
tma_aligned_scales: Outputs scales in TMA-aligned layout.
|
||||
out_q: Optional output tensor. If not provided, function will create.
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
@@ -904,8 +908,24 @@ def per_token_group_quant_fp8(
|
||||
|
||||
# Allocate the scale tensor in either row- or column-major format.
|
||||
if column_major_scales:
|
||||
shape = (x.shape[-1] // group_size,) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
|
||||
if tma_aligned_scales:
|
||||
m = x.shape[-2]
|
||||
sf_k = x.shape[-1] // group_size
|
||||
tma_aligned_m = get_tma_aligned_size(m, 4)
|
||||
shape = x.shape[:-2] + (m, sf_k)
|
||||
stride = (
|
||||
(1, tma_aligned_m)
|
||||
if x.dim() == 2
|
||||
else (tma_aligned_m * sf_k, 1, tma_aligned_m)
|
||||
)
|
||||
x_s = torch.empty_strided(
|
||||
shape, stride, device=x.device, dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
shape = x.shape[:-2] + (x.shape[-1] // group_size, x.shape[-2])
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(
|
||||
-1, -2
|
||||
)
|
||||
else:
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
@@ -340,6 +340,11 @@ def _align(x: int, y: int) -> int:
|
||||
return cdiv(x, y) * y
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """Round ``x`` up to a multiple of ``16 // element_size``.

    The constant 16 is the alignment in bytes (per the upstream DeepGEMM
    helper this was taken from, it is the TMA alignment requirement), so
    the result is the smallest element count >= ``x`` whose total size is a
    whole number of 16-byte units.

    Args:
        x: Number of elements along the dimension to be aligned.
        element_size: Size of one element in bytes. Must be a positive
            divisor of 16 (1, 2, 4, 8 or 16).

    Returns:
        ``x`` rounded up to the next multiple of ``16 // element_size``.

    Raises:
        ValueError: If ``element_size`` is not a positive divisor of 16.
            Without this check a non-divisor (e.g. 3) would silently yield
            a wrong alignment, and values of 0 or above 16 would fail with
            an unrelated ``ZeroDivisionError``.
    """
    tma_alignment_bytes = 16
    if element_size <= 0 or tma_alignment_bytes % element_size != 0:
        raise ValueError(
            f"element_size must be a positive divisor of "
            f"{tma_alignment_bytes}, got {element_size}"
        )
    return _align(x, tma_alignment_bytes // element_size)
|
||||
|
||||
|
||||
DEFAULT_BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user