[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import (
     fp8_gemm_nt,
-    get_col_major_tma_aligned_tensor,
+    get_tma_aligned_size,
     per_block_cast_to_fp8,
     should_use_deepgemm_for_fp8_linear,
 )
@@ -40,6 +40,8 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
 NUM_TOKENS = [7, 2050]
 D = [512, 4096, 5120, 13824]
 GROUP_SIZE = [64, 128, 512]
+COLUMN_MAJOR_SCALES = [True, False]
+TMA_ALIGNED_SCALES = [True, False]
 M = [1, 7, 8, 83, 84, 4096]
 N = [128, 512, 7168, 7748, 13824]
 K = [256, 3884, 4096, 13824, 16384]
@@ -63,20 +65,40 @@ def setup_cuda():
     reason="This platform supports e4m3fnuz, not e4m3fn.",
 )
 @pytest.mark.parametrize(
-    "num_tokens,d,dtype,group_size,seed",
-    itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
+    "num_tokens,d,dtype,group_size,column_major_scales,tma_aligned_scales,seed",
+    itertools.product(
+        NUM_TOKENS,
+        D,
+        DTYPES,
+        GROUP_SIZE,
+        COLUMN_MAJOR_SCALES,
+        TMA_ALIGNED_SCALES,
+        SEEDS,
+    ),
 )
 @torch.inference_mode()
-def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
+def test_per_token_group_quant_fp8(
+    num_tokens, d, dtype, group_size, column_major_scales, tma_aligned_scales, seed
+):
     torch.manual_seed(seed)
     x = torch.rand(num_tokens, d, dtype=dtype)
 
     ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
-    out, scale = per_token_group_quant_fp8(x, group_size)
+    out, scale = per_token_group_quant_fp8(
+        x,
+        group_size,
+        column_major_scales=column_major_scales,
+        tma_aligned_scales=tma_aligned_scales,
+    )
 
     assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
     assert torch.allclose(scale, ref_scale)
+
+    if column_major_scales:
+        assert scale.stride()[-2] == 1
+        if tma_aligned_scales:
+            assert scale.stride()[-1] == get_tma_aligned_size(num_tokens, 4)
 
 
 @pytest.mark.parametrize(
     "M,N,K,block_size,out_dtype,seed",
@@ -186,7 +208,9 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     ):
         pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
 
-    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
+    A_fp8, As_fp8 = per_token_group_quant_fp8(
+        A_fp32, block_size[1], column_major_scales=True, tma_aligned_scales=True
+    )
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
 
     As = As_fp8.to(torch.float32)
@@ -194,9 +218,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
 
     ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
 
-    # Transpose earlier so that the testing will not trigger transposing kernels
-    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
-
     out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
 
     assert As_fp8.shape == (M, (K + 127) // 128), (
Reference in New Issue
Block a user