[Feature] Integrate SM100 DeepGEMM support (#20087)

This commit is contained in:
Wentao Ye
2025-07-10 23:18:05 -04:00
committed by GitHub
parent 5b032352cc
commit e2de455c34
16 changed files with 397 additions and 114 deletions

View File

@@ -8,19 +8,15 @@ import pytest
import torch
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
per_block_cast_to_fp8)
native_w8a8_block_matmul)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
per_token_group_cast_to_fp8)
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -106,7 +102,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(not has_deep_gemm(),
reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
@@ -120,9 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
_, block_k = block_size[0], block_size[1]
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
As = As_fp8.to(torch.float32)
@@ -132,14 +127,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /