[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Xin Yang
2026-01-22 12:47:04 -08:00
committed by GitHub
parent f744810184
commit d08b356ee0
7 changed files with 75 additions and 17 deletions
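For context on what "TMA-aligned" means here: the per-group activation scales are stored column-major, and the token dimension's stride is padded so that every column of fp32 scales starts on a 16-byte boundary, the granularity DeepGEMM's TMA loads expect. A minimal illustrative sketch follows; the helper name, sizes, and offsets below are illustrative only and not taken from this commit.

# Illustrative only: map a (token, group) scale index to its byte offset in a
# column-major buffer whose token dimension is padded for TMA.
def align(x: int, y: int) -> int:
    return (x + y - 1) // y * y

m, num_groups, elem_size = 9, 4, 4     # 9 tokens, 4 scale groups, fp32 scales
tma_m = align(m, 16 // elem_size)      # 9 tokens padded to 12

def offset_bytes(token: int, group: int) -> int:
    # Tokens are contiguous within a column; columns sit tma_m elements apart.
    return (token + group * tma_m) * elem_size

# Every column (fixed group index) starts on a 16-byte boundary.
assert all(offset_bytes(0, g) % 16 == 0 for g in range(num_groups))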


@@ -36,6 +36,7 @@ class QuantFP8(CustomOp):
group_shape: GroupShape,
num_token_padding: int | None = None,
column_major_scales: bool = False,
tma_aligned_scales: bool = False,
use_ue8m0: bool | None = None, # for Torch compile
):
"""
@@ -44,6 +45,8 @@ class QuantFP8(CustomOp):
PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param tma_aligned_scales: For group quantization, output scales in
TMA-aligned layout
:param column_major_scales: For group quantization, output scales in
column major format
"""
@@ -53,6 +56,7 @@ class QuantFP8(CustomOp):
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales
self.tma_aligned_scales = tma_aligned_scales
self.use_ue8m0 = use_ue8m0
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
@@ -82,6 +86,7 @@ class QuantFP8(CustomOp):
x,
group_size=self.group_size,
column_major_scales=self.column_major_scales,
tma_aligned_scales=self.tma_aligned_scales,
dtype=_FP8_DTYPE,
use_ue8m0=self.use_ue8m0,
)
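For orientation, a hypothetical construction of QuantFP8 with the new flag, mirroring the call site changed further down in this diff. The first positional argument, the GroupShape value, the import locations, and the op's return signature are assumptions, not taken from this commit.

# Hypothetical usage sketch, not part of this commit; assumes QuantFP8 and
# GroupShape are importable from vLLM's quantization utilities.
quant_fp8 = QuantFP8(
    False,                    # assumed: dynamic (non-static) activation quant
    GroupShape(1, 128),       # assumed: per-token groups of 128 elements
    column_major_scales=True,
    tma_aligned_scales=True,  # new flag: pad the scale stride for TMA
    use_ue8m0=False,
)
x_q, x_s = quant_fp8(x)       # assumed: returns (fp8 activations, group scales)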


@@ -35,6 +35,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt,
get_tma_aligned_size,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
@@ -378,6 +379,7 @@ class W8A8BlockFp8LinearOp:
False,
self.act_quant_group_shape,
column_major_scales=True,
tma_aligned_scales=True,
use_ue8m0=self.use_deep_gemm_e8m0,
)
if self.is_deep_gemm_supported
@@ -868,6 +870,7 @@ def per_token_group_quant_fp8(
eps: float = 1e-10,
dtype: torch.dtype | None = None,
column_major_scales: bool = False,
tma_aligned_scales: bool = False,
out_q: torch.Tensor | None = None,
use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -878,9 +881,10 @@ def per_token_group_quant_fp8(
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum value to avoid division by zero.
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column-major layout.
tma_aligned_scales: Outputs scales in TMA-aligned layout.
out_q: Optional output tensor. If not provided, the function will allocate one.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
@@ -904,8 +908,24 @@ def per_token_group_quant_fp8(
    # Allocate the scale tensor in either row- or column-major format.
    if column_major_scales:
        if tma_aligned_scales:
            # Pad the token-dimension stride so every column of fp32 scales
            # starts on a 16-byte (TMA-aligned) boundary.
            m = x.shape[-2]
            sf_k = x.shape[-1] // group_size
            tma_aligned_m = get_tma_aligned_size(m, 4)
            shape = x.shape[:-2] + (m, sf_k)
            stride = (
                (1, tma_aligned_m)
                if x.dim() == 2
                else (tma_aligned_m * sf_k, 1, tma_aligned_m)
            )
            x_s = torch.empty_strided(
                shape, stride, device=x.device, dtype=torch.float32
            )
        else:
            shape = x.shape[:-2] + (x.shape[-1] // group_size, x.shape[-2])
            x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(
                -1, -2
            )
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
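A hypothetical end-to-end check of the new layout (sizes are illustrative; assumes per_token_group_quant_fp8 is imported from vLLM's fp8 utilities and a CUDA device is available):

import torch

m, k, group_size = 9, 512, 128
x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)

x_q, x_s = per_token_group_quant_fp8(
    x,
    group_size,
    column_major_scales=True,
    tma_aligned_scales=True,
)
# The scales keep their logical (tokens, groups) shape, but the group stride
# is the TMA-aligned token count: 9 tokens are padded to 12 for fp32 scales.
assert x_s.shape == (m, k // group_size)
assert x_s.stride() == (1, 12)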


@@ -340,6 +340,11 @@ def _align(x: int, y: int) -> int:
    return cdiv(x, y) * y


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
def get_tma_aligned_size(x: int, element_size: int):
    return _align(x, 16 // element_size)


DEFAULT_BLOCK_SIZE = [128, 128]
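As a quick sanity check of the alignment rule (values computed by hand, not part of the diff): for fp32 scales, element_size is 4 bytes, so sizes are rounded up to a multiple of 16 // 4 = 4 elements, i.e. to a 16-byte boundary.

assert get_tma_aligned_size(8, 4) == 8      # already a multiple of 4
assert get_tma_aligned_size(9, 4) == 12     # 9 -> next multiple of 4
assert get_tma_aligned_size(301, 4) == 304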