[Perf] Use upstream CUTLASS for SM90 Block FP8 kernel (#23280)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
@@ -11,8 +11,8 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                        native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
-    w8a8_block_fp8_matmul)
+    cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
+    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
 from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
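The only import change is the new `cutlass_scaled_mm` entry point in `fp8_utils`, which routes block-FP8 GEMMs to the upstream CUTLASS kernel on SM90. For context on the scale semantics the new test relies on, here is a minimal dequantize-then-matmul sketch of a block-scaled W8A8 reference; it is written for this note (it is not vLLM's `native_w8a8_block_matmul`) and assumes `As` holds per-token-group scales of shape `(M, ceil(K / block_k))` and `Bs` holds per-tile scales of shape `(n_tiles, k_tiles)`:

```python
import torch


def ref_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype):
    # Illustrative sketch only: dequantize both operands to float32 using
    # their block scales, then run a plain matmul.
    block_n, block_k = block_size
    M, K = A_fp8.shape
    N = B_fp8.shape[0]
    # A: one scale per (token, K-group); expand each group scale across K.
    A = A_fp8.to(torch.float32) * As.repeat_interleave(block_k, dim=1)[:, :K]
    # B: one scale per (block_n x block_k) tile; expand across both dims.
    Bs_full = Bs.repeat_interleave(block_n, dim=0)[:N, :]
    Bs_full = Bs_full.repeat_interleave(block_k, dim=1)[:, :K]
    B = B_fp8.to(torch.float32) * Bs_full
    # B is stored (N, K), so the GEMM contracts over K via B^T.
    return (A @ B.T).to(out_dtype)
```

Dequantizing in float32 before the matmul keeps the reference exact up to rounding, which is what makes the tight relative tolerance in the test below meaningful.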
@@ -98,6 +98,54 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
     assert rel_diff < 0.001
 
 
+@torch.inference_mode()
+def test_w8a8_block_fp8_cutlass_matmul():
+    # Test simple case where weight.shape % 128 != 0,
+    # like in DSV3 kv_a_proj_with_mqa
+    M = 32
+    N = 576
+    K = 7168
+    block_size = [128, 128]
+    out_dtype = torch.bfloat16
+    seed = 0
+
+    torch.manual_seed(seed)
+    factor_for_scale = 1e-2
+    fp8_info = torch.finfo(torch.float8_e4m3fn)
+    fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+
+    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+    block_n, block_k = block_size[0], block_size[1]
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+
+    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
+    # Hopper requires row-major format for scales
+    Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(
+        90) else Bs
+
+    A_fp8, As = per_token_group_quant_fp8(A_fp32,
+                                          block_size[1],
+                                          column_major_scales=False)
+    # CUTLASS uses column-major format for scales
+    A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
+        A_fp32, block_size[1], column_major_scales=True)
+
+    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
+                                       out_dtype)
+    out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass,
+                            block_size, out_dtype)
+
+    rel_diff = (torch.mean(
+        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
+        torch.mean(torch.abs(ref_out.to(torch.float32))))
+    assert rel_diff < 0.001
+
+
 @pytest.mark.parametrize(
     "M,N,K,block_size,out_dtype,seed",
     itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
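Note that the test quantizes the activations twice because the two paths disagree on scale layout: the native reference reads row-major per-token-group scales, while the CUTLASS kernel expects them column-major (and, on Hopper, takes the transposed block scales for B). A minimal sketch of the layout relationship, assuming `column_major_scales=True` changes only the memory strides and not the scale values:

```python
import torch

M, K, group_k = 32, 7168, 128          # shapes used by the new test
As_row = torch.rand(M, K // group_k)   # row-major per-token-group scales
# Same values, column-major strides: transpose, make contiguous, transpose back.
As_col = As_row.t().contiguous().t()
assert torch.equal(As_row, As_col)     # contents are identical
assert As_col.stride() == (1, M)       # memory layout is column-major
```

Producing the column-major layout at quantization time (rather than transposing afterwards) keeps the layout conversion off the GEMM's critical path.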
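The pass criterion mirrors the existing `test_w8a8_block_fp8_matmul`: the mean absolute deviation from the reference, normalized by the reference's mean magnitude, must stay under 1e-3 (0.1%). As a hypothetical standalone helper, not part of the diff:

```python
import torch

def mean_rel_diff(out: torch.Tensor, ref: torch.Tensor) -> float:
    # Mean absolute error normalized by the mean magnitude of the reference;
    # the test asserts this stays below 1e-3.
    out_f, ref_f = out.to(torch.float32), ref.to(torch.float32)
    return (torch.mean(torch.abs(out_f - ref_f)) /
            torch.mean(torch.abs(ref_f))).item()
```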