[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -14,7 +14,6 @@ from vllm.triton_utils import triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
calc_diff,
|
||||
fp8_gemm_nt,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
per_block_cast_to_fp8,
|
||||
)
|
||||
|
||||
@@ -48,8 +47,9 @@ def benchmark_shape(
|
||||
block_size = [128, 128]
|
||||
|
||||
# Pre-quantize A for all implementations
|
||||
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
|
||||
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
|
||||
A, block_size[1], column_major_scales=True, tma_aligned_scales=True
|
||||
)
|
||||
C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
group_broadcast,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import _ceil_to_ue8m0, is_deep_gemm_e8m0_used
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -170,6 +171,8 @@ def native_per_token_group_quant_fp8(
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
if is_deep_gemm_e8m0_used():
|
||||
x_s = _ceil_to_ue8m0(x_s)
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
||||
|
||||
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_tma_aligned_size,
|
||||
per_block_cast_to_fp8,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
@@ -40,6 +40,8 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 2050]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
GROUP_SIZE = [64, 128, 512]
|
||||
COLUMN_MAJOR_SCALES = [True, False]
|
||||
TMA_ALIGNED_SCALES = [True, False]
|
||||
M = [1, 7, 8, 83, 84, 4096]
|
||||
N = [128, 512, 7168, 7748, 13824]
|
||||
K = [256, 3884, 4096, 13824, 16384]
|
||||
@@ -63,20 +65,40 @@ def setup_cuda():
|
||||
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,d,dtype,group_size,seed",
|
||||
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
|
||||
"num_tokens,d,dtype,group_size,column_major_scales,tma_aligned_scales,seed",
|
||||
itertools.product(
|
||||
NUM_TOKENS,
|
||||
D,
|
||||
DTYPES,
|
||||
GROUP_SIZE,
|
||||
COLUMN_MAJOR_SCALES,
|
||||
TMA_ALIGNED_SCALES,
|
||||
SEEDS,
|
||||
),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
|
||||
def test_per_token_group_quant_fp8(
|
||||
num_tokens, d, dtype, group_size, column_major_scales, tma_aligned_scales, seed
|
||||
):
|
||||
torch.manual_seed(seed)
|
||||
x = torch.rand(num_tokens, d, dtype=dtype)
|
||||
|
||||
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
|
||||
out, scale = per_token_group_quant_fp8(x, group_size)
|
||||
out, scale = per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major_scales,
|
||||
tma_aligned_scales=tma_aligned_scales,
|
||||
)
|
||||
|
||||
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
|
||||
assert torch.allclose(scale, ref_scale)
|
||||
|
||||
if column_major_scales:
|
||||
assert scale.stride()[-2] == 1
|
||||
if tma_aligned_scales:
|
||||
assert scale.stride()[-1] == get_tma_aligned_size(num_tokens, 4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
@@ -186,7 +208,9 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
):
|
||||
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
||||
|
||||
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
|
||||
A_fp8, As_fp8 = per_token_group_quant_fp8(
|
||||
A_fp32, block_size[1], column_major_scales=True, tma_aligned_scales=True
|
||||
)
|
||||
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
|
||||
|
||||
As = As_fp8.to(torch.float32)
|
||||
@@ -194,9 +218,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
|
||||
|
||||
out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
|
||||
|
||||
assert As_fp8.shape == (M, (K + 127) // 128), (
|
||||
|
||||
@@ -8,13 +8,16 @@ import torch
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
||||
@pytest.mark.parametrize(
|
||||
"shape", [(31, 128), (32, 128), (63, 256), (64, 256), (16, 512)]
|
||||
)
|
||||
@pytest.mark.parametrize("column_major", [False, True])
|
||||
@pytest.mark.parametrize("tma_aligned", [False, True])
|
||||
@pytest.mark.parametrize("scale_ue8m0", [False, True])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_per_token_group_quant_fp8(
|
||||
shape, column_major: bool, scale_ue8m0: bool, group_size: int
|
||||
shape, column_major: bool, tma_aligned: bool, scale_ue8m0: bool, group_size: int
|
||||
):
|
||||
device = "cuda"
|
||||
|
||||
@@ -28,6 +31,7 @@ def test_per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major,
|
||||
tma_aligned_scales=tma_aligned,
|
||||
use_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ class QuantFP8(CustomOp):
|
||||
group_shape: GroupShape,
|
||||
num_token_padding: int | None = None,
|
||||
column_major_scales: bool = False,
|
||||
tma_aligned_scales: bool = False,
|
||||
use_ue8m0: bool | None = None, # for Torch compile
|
||||
):
|
||||
"""
|
||||
@@ -44,6 +45,8 @@ class QuantFP8(CustomOp):
|
||||
PER_CHANNEL, or arbitrary block size)
|
||||
:param num_token_padding: Pad the token dimension of output to this
|
||||
size
|
||||
:param tma_aligned_scales: For group quantization, output scales in
|
||||
TMA-aligned layout
|
||||
:param column_major_scales: For group quantization, output scales in
|
||||
column major format
|
||||
"""
|
||||
@@ -53,6 +56,7 @@ class QuantFP8(CustomOp):
|
||||
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
|
||||
self.num_token_padding = num_token_padding
|
||||
self.column_major_scales = column_major_scales
|
||||
self.tma_aligned_scales = tma_aligned_scales
|
||||
self.use_ue8m0 = use_ue8m0
|
||||
|
||||
self.use_aiter = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
@@ -82,6 +86,7 @@ class QuantFP8(CustomOp):
|
||||
x,
|
||||
group_size=self.group_size,
|
||||
column_major_scales=self.column_major_scales,
|
||||
tma_aligned_scales=self.tma_aligned_scales,
|
||||
dtype=_FP8_DTYPE,
|
||||
use_ue8m0=self.use_ue8m0,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
fp8_gemm_nt,
|
||||
get_tma_aligned_size,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
@@ -378,6 +379,7 @@ class W8A8BlockFp8LinearOp:
|
||||
False,
|
||||
self.act_quant_group_shape,
|
||||
column_major_scales=True,
|
||||
tma_aligned_scales=True,
|
||||
use_ue8m0=self.use_deep_gemm_e8m0,
|
||||
)
|
||||
if self.is_deep_gemm_supported
|
||||
@@ -868,6 +870,7 @@ def per_token_group_quant_fp8(
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype | None = None,
|
||||
column_major_scales: bool = False,
|
||||
tma_aligned_scales: bool = False,
|
||||
out_q: torch.Tensor | None = None,
|
||||
use_ue8m0: bool | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -878,9 +881,10 @@ def per_token_group_quant_fp8(
|
||||
x: The input tensor with ndim >= 2.
|
||||
group_size: The group size used for quantization.
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
column_major_scales: Outputs scales in column major.
|
||||
tma_aligned_scales: Outputs scales in TMA-aligned layout.
|
||||
out_q: Optional output tensor. If not provided, function will create.
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
@@ -904,8 +908,24 @@ def per_token_group_quant_fp8(
|
||||
|
||||
# Allocate the scale tensor in either row- or column-major format.
|
||||
if column_major_scales:
|
||||
shape = (x.shape[-1] // group_size,) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
|
||||
if tma_aligned_scales:
|
||||
m = x.shape[-2]
|
||||
sf_k = x.shape[-1] // group_size
|
||||
tma_aligned_m = get_tma_aligned_size(m, 4)
|
||||
shape = x.shape[:-2] + (m, sf_k)
|
||||
stride = (
|
||||
(1, tma_aligned_m)
|
||||
if x.dim() == 2
|
||||
else (tma_aligned_m * sf_k, 1, tma_aligned_m)
|
||||
)
|
||||
x_s = torch.empty_strided(
|
||||
shape, stride, device=x.device, dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
shape = x.shape[:-2] + (x.shape[-1] // group_size, x.shape[-2])
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(
|
||||
-1, -2
|
||||
)
|
||||
else:
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
@@ -340,6 +340,11 @@ def _align(x: int, y: int) -> int:
|
||||
return cdiv(x, y) * y
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
|
||||
def get_tma_aligned_size(x: int, element_size: int):
|
||||
return _align(x, 16 // element_size)
|
||||
|
||||
|
||||
DEFAULT_BLOCK_SIZE = [128, 128]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user