[Perf][Kernels] Enable FlashInfer DeepGEMM swapAB on SM90 (for W8A8 Linear Op) (#29213)
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
This commit is contained in:
@@ -24,6 +24,10 @@ from vllm.utils.deep_gemm import (
|
||||
per_block_cast_to_fp8,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_fp8_blockscale_gemm,
|
||||
has_flashinfer_fp8_blockscale_gemm,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
@@ -205,3 +209,50 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_fp8_fnuz(),
    reason="This platform supports e4m3fnuz, not e4m3fn.",
)
@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_fp8_flashinfer_matmul(M, N, K, block_size, out_dtype, seed):
    """Compare the FlashInfer FP8 block-scale GEMM with the native reference.

    Random bf16 activations (M x K) and weights (N x K) are generated. The
    weights are block-quantized to FP8 and handed to
    ``flashinfer_fp8_blockscale_gemm`` together with the *unquantized* bf16
    activations (``input_scale=None``). The reference result comes from
    ``native_w8a8_block_matmul`` on explicitly quantized activations and
    weights. The test passes when the mean relative difference between both
    outputs is below 1e-3.

    The test is skipped when the FlashInfer block-scale GEMM is unavailable,
    or when the shape is not aligned (``K`` not a multiple of 128 or ``N``
    not a multiple of 64).
    """
    if not has_flashinfer_fp8_blockscale_gemm():
        pytest.skip(
            "FlashInfer block GEMM not available (requires SM90+ and FlashInfer)"
        )
    # only aligned sizes
    if K % 128 != 0 or N % 64 != 0:
        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")

    torch.manual_seed(seed)
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max

    # Uniform values in [-fp8_max, fp8_max) so the full FP8 range is exercised.
    A_bf16 = (torch.rand(M, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
    B_bf16 = (torch.rand(N, K, dtype=torch.bfloat16) - 0.5) * 2 * fp8_max

    # Quantize both operands for the reference path: activations per token
    # group (group size block_size[1]), weights per block.
    A_fp8, As_fp8 = per_token_group_quant_fp8(A_bf16, block_size[1], use_ue8m0=False)
    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_bf16, block_size, use_ue8m0=False)

    As = As_fp8.to(torch.float32)
    Bs = Bs_fp8.to(torch.float32)

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    # The kernel under test receives the bf16 activations and no input scale.
    # NOTE(review): presumably it quantizes the activations internally in a
    # way that matches per_token_group_quant_fp8 -- confirm against FlashInfer.
    out = flashinfer_fp8_blockscale_gemm(
        input=A_bf16,
        weight=B_fp8,
        input_scale=None,
        weight_scale=Bs,
        out_dtype=out_dtype,
    )

    # Compute the error metric in float32, as the DeepGEMM test in this file
    # does: bfloat16 carries only ~3 significant decimal digits, which is too
    # coarse to resolve a 1e-3 relative-error threshold reliably.
    out_f32 = out.to(torch.float32)
    ref_f32 = ref_out.to(torch.float32)
    rel_diff = torch.mean(torch.abs(out_f32 - ref_f32)) / torch.mean(
        torch.abs(ref_f32)
    )
    assert rel_diff < 0.001
|
||||
|
||||
Reference in New Issue
Block a user