[Feature] Integrate SM100 DeepGEMM support (#20087)

This commit is contained in:
Wentao Ye
2025-07-10 23:18:05 -04:00
committed by GitHub
parent 5b032352cc
commit e2de455c34
16 changed files with 397 additions and 114 deletions

View File

@@ -13,48 +13,18 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import cdiv
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
per_token_group_cast_to_fp8)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
import deep_gemm
BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout()
BLOCK_SIZE = [BLOCK_M, BLOCK_M]
BLOCK_SIZE = [128, 128]
requires_deep_gemm = pytest.mark.skipif(
not has_deep_gemm,
not has_deep_gemm(),
reason="Requires deep_gemm kernels",
)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights(
e: int,
n: int,
@@ -111,7 +81,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
"""
tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
_, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
# expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
@@ -155,17 +125,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
block_shape=block_size,
allow_deep_gemm=True,
)
base = out_triton.abs().mean()
atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3
rtol = 0.05
# ----- Compare -----
torch.testing.assert_close(
out_deepgemm.to(torch.float32),
out_triton.to(torch.float32),
rtol=rtol,
atol=float(atol),
)
diff = calc_diff(out_deepgemm, out_triton)
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
# Note: W1 has shape (E, 2N, K), so N = 512