[Feature] Integrate new deepgemm (#19820)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-06-24 15:51:56 -04:00
committed by GitHub
parent 91f7d9d0b6
commit c6e3bba8e6
8 changed files with 230 additions and 264 deletions

View File

@@ -18,7 +18,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
dg_available = False
@@ -263,25 +264,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
assert rel_diff < 0.03
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@@ -299,10 +281,8 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
_, block_k = block_size[0], block_size[1]
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
A_fp8, As_fp8 = deep_gemm.utils.math.per_token_cast_to_fp8(A_fp32)
B_fp8, Bs_fp8 = deep_gemm.utils.math.per_block_cast_to_fp8(B_fp32)
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)
@@ -310,15 +290,12 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
deep_gemm.fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
@@ -382,16 +359,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
dtype=torch.bfloat16,
device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
deep_gemm.m_grouped_fp8_gemm_nt_contiguous((act_out_q, act_out_s),
(w2, w2_s), out, m_indices)
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
@@ -441,15 +418,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
w1_s = get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = get_col_major_tma_aligned_tensor(w2_s).contiguous()
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(E):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
w1[i], w1_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w2_bf16[i])
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
@@ -460,14 +437,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))