[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-01-22 12:47:04 -08:00
committed by GitHub
parent f744810184
commit d08b356ee0
7 changed files with 75 additions and 17 deletions

View File

@@ -14,7 +14,6 @@ from vllm.triton_utils import triton
from vllm.utils.deep_gemm import (
calc_diff,
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8,
)
@@ -48,8 +47,9 @@ def benchmark_shape(
block_size = [128, 128]
# Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
A, block_size[1], column_major_scales=True, tma_aligned_scales=True
)
C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(

View File

@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import _ceil_to_ue8m0, is_deep_gemm_e8m0_used
from vllm.utils.math_utils import round_up
FP8_DTYPE = current_platform.fp8_dtype()
@@ -170,6 +171,8 @@ def native_per_token_group_quant_fp8(
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
if is_deep_gemm_e8m0_used():
x_s = _ceil_to_ue8m0(x_s)
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
get_tma_aligned_size,
per_block_cast_to_fp8,
should_use_deepgemm_for_fp8_linear,
)
@@ -40,6 +40,8 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 2050]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 512]
COLUMN_MAJOR_SCALES = [True, False]
TMA_ALIGNED_SCALES = [True, False]
M = [1, 7, 8, 83, 84, 4096]
N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384]
@@ -63,20 +65,40 @@ def setup_cuda():
reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
"num_tokens,d,dtype,group_size,column_major_scales,tma_aligned_scales,seed",
itertools.product(
NUM_TOKENS,
D,
DTYPES,
GROUP_SIZE,
COLUMN_MAJOR_SCALES,
TMA_ALIGNED_SCALES,
SEEDS,
),
)
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
def test_per_token_group_quant_fp8(
num_tokens, d, dtype, group_size, column_major_scales, tma_aligned_scales, seed
):
torch.manual_seed(seed)
x = torch.rand(num_tokens, d, dtype=dtype)
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(
x,
group_size,
column_major_scales=column_major_scales,
tma_aligned_scales=tma_aligned_scales,
)
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
assert torch.allclose(scale, ref_scale)
if column_major_scales:
assert scale.stride()[-2] == 1
if tma_aligned_scales:
assert scale.stride()[-1] == get_tma_aligned_size(num_tokens, 4)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
@@ -186,7 +208,9 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
):
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
A_fp8, As_fp8 = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=True, tma_aligned_scales=True
)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
As = As_fp8.to(torch.float32)
@@ -194,9 +218,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) // 128), (

View File

@@ -8,13 +8,16 @@ import torch
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize(
"shape", [(31, 128), (32, 128), (63, 256), (64, 256), (16, 512)]
)
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("tma_aligned", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(
shape, column_major: bool, scale_ue8m0: bool, group_size: int
shape, column_major: bool, tma_aligned: bool, scale_ue8m0: bool, group_size: int
):
device = "cuda"
@@ -28,6 +31,7 @@ def test_per_token_group_quant_fp8(
x,
group_size,
column_major_scales=column_major,
tma_aligned_scales=tma_aligned,
use_ue8m0=scale_ue8m0,
)

View File

@@ -36,6 +36,7 @@ class QuantFP8(CustomOp):
group_shape: GroupShape,
num_token_padding: int | None = None,
column_major_scales: bool = False,
tma_aligned_scales: bool = False,
use_ue8m0: bool | None = None, # for Torch compile
):
"""
@@ -44,6 +45,8 @@ class QuantFP8(CustomOp):
PER_CHANNEL, or arbitrary block size)
:param num_token_padding: Pad the token dimension of output to this
size
:param tma_aligned_scales: For group quantization, output scales in
TMA-aligned layout
:param column_major_scales: For group quantization, output scales in
column major format
"""
@@ -53,6 +56,7 @@ class QuantFP8(CustomOp):
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
self.num_token_padding = num_token_padding
self.column_major_scales = column_major_scales
self.tma_aligned_scales = tma_aligned_scales
self.use_ue8m0 = use_ue8m0
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
@@ -82,6 +86,7 @@ class QuantFP8(CustomOp):
x,
group_size=self.group_size,
column_major_scales=self.column_major_scales,
tma_aligned_scales=self.tma_aligned_scales,
dtype=_FP8_DTYPE,
use_ue8m0=self.use_ue8m0,
)

View File

@@ -35,6 +35,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt,
get_tma_aligned_size,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
@@ -378,6 +379,7 @@ class W8A8BlockFp8LinearOp:
False,
self.act_quant_group_shape,
column_major_scales=True,
tma_aligned_scales=True,
use_ue8m0=self.use_deep_gemm_e8m0,
)
if self.is_deep_gemm_supported
@@ -868,6 +870,7 @@ def per_token_group_quant_fp8(
eps: float = 1e-10,
dtype: torch.dtype | None = None,
column_major_scales: bool = False,
tma_aligned_scales: bool = False,
out_q: torch.Tensor | None = None,
use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -878,9 +881,10 @@ def per_token_group_quant_fp8(
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
tma_aligned_scales: Outputs scales in TMA-aligned layout.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
@@ -904,8 +908,24 @@ def per_token_group_quant_fp8(
# Allocate the scale tensor in either row- or column-major format.
if column_major_scales:
shape = (x.shape[-1] // group_size,) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
if tma_aligned_scales:
m = x.shape[-2]
sf_k = x.shape[-1] // group_size
tma_aligned_m = get_tma_aligned_size(m, 4)
shape = x.shape[:-2] + (m, sf_k)
stride = (
(1, tma_aligned_m)
if x.dim() == 2
else (tma_aligned_m * sf_k, 1, tma_aligned_m)
)
x_s = torch.empty_strided(
shape, stride, device=x.device, dtype=torch.float32
)
else:
shape = x.shape[:-2] + (x.shape[-1] // group_size, x.shape[-2])
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(
-1, -2
)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

View File

@@ -340,6 +340,11 @@ def _align(x: int, y: int) -> int:
return cdiv(x, y) * y
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """Round ``x`` up to the nearest multiple of ``16 // element_size``.

    Mirrors DeepGEMM's TMA alignment helper (see link above): with
    ``element_size`` in bytes, the result is the smallest count of elements
    >= ``x`` whose byte size is a multiple of 16 — presumably the TMA
    16-byte alignment requirement (TODO: confirm against DeepGEMM docs).

    :param x: Size (in elements) to align, e.g. the scale-factor row count.
    :param element_size: Size of one element in bytes (e.g. 4 for float32).
    :return: ``x`` rounded up via ``_align`` to a multiple of
        ``16 // element_size``.
    """
    return _align(x, 16 // element_size)
DEFAULT_BLOCK_SIZE = [128, 128]