[Feature] Integrate new deepgemm (#19820)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -18,7 +18,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
dg_available = False
|
||||
@@ -263,25 +264,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
assert rel_diff < 0.03
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(deep_gemm.ceil_div(m, 128) * 128,
|
||||
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||
@@ -299,10 +281,8 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
|
||||
_, block_k = block_size[0], block_size[1]
|
||||
|
||||
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
|
||||
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
|
||||
A_fp8, As_fp8 = deep_gemm.utils.math.per_token_cast_to_fp8(A_fp32)
|
||||
B_fp8, Bs_fp8 = deep_gemm.utils.math.per_block_cast_to_fp8(B_fp32)
|
||||
|
||||
As = As_fp8.to(torch.float32)
|
||||
Bs = Bs_fp8.to(torch.float32)
|
||||
@@ -310,15 +290,12 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
|
||||
|
||||
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
|
||||
|
||||
assert As_fp8.shape == (M, (K + 127) //
|
||||
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
||||
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||
deep_gemm.fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
@@ -382,16 +359,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
||||
dtype=torch.bfloat16,
|
||||
device=a.device)
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
|
||||
inter_out, m_indices)
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a_q, a_s), (w1, w1_s),
|
||||
inter_out, m_indices)
|
||||
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
|
||||
|
||||
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous((act_out_q, act_out_s),
|
||||
(w2, w2_s), out, m_indices)
|
||||
|
||||
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
|
||||
|
||||
@@ -441,15 +418,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
||||
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
||||
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
||||
|
||||
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
|
||||
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
|
||||
w1_s = get_col_major_tma_aligned_tensor(w1_s).contiguous()
|
||||
w2_s = get_col_major_tma_aligned_tensor(w2_s).contiguous()
|
||||
|
||||
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
|
||||
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
|
||||
|
||||
for i in range(E):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
||||
w1[i], w1_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w1_bf16[i])
|
||||
w2[i], w2_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w2_bf16[i])
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
@@ -460,14 +437,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
|
||||
topk, block_size)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score.float(), topk, False)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
||||
|
||||
#print(f"{out.sum()=}")
|
||||
#print(f"{ref_out.sum()=}")
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
|
||||
Reference in New Issue
Block a user