[Perf] Deepgemm fused layout kernel for activations, 4.3% throughput improvement, 10.7% TTFT improvement. (#29546)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye
2025-12-07 07:31:14 -05:00
committed by GitHub
parent b0f4866a77
commit 541a2ef892
5 changed files with 311 additions and 12 deletions

View File

@@ -23,9 +23,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
per_token_group_quant_fp8_packed_for_deepgemm,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
get_mk_alignment_for_contiguous_layout,
m_grouped_fp8_gemm_nt_contiguous,
)
@@ -157,23 +159,40 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: str
) -> tuple[torch.Tensor, torch.Tensor]:
if activation == "silu":
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input, output=output
)
else:
# This is a fallback path. If we find ourselves using any activation other
# than silu, we should add that activation to the
# silu_mul_per_token_group_quant_fp8_colmajor kernel, as it is much faster.
assert self.block_shape is not None
block_k = self.block_shape[1]
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
# 1. DeepGemm UE8M0: use packed per-token-group quant
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
M_sum, N = input.size()
act_out = torch.empty(
(M_sum, N // 2), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
assert self.block_shape is not None
return per_token_group_quant_fp8(
act_out, self.block_shape[1], column_major_scales=True, out_q=output
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
act_out,
block_k,
out_q=output,
)
return a2q, a2q_scale
# 2. Hopper / non-E8M0: prefer the fused SiLU+mul+quant kernel
if activation == "silu":
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input,
output=output,
use_ue8m0=use_ue8m0,
)
# 3. Fallback path for non-SiLU activations in non-UE8M0 cases.
M_sum, N = input.size()
act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
self.activation(activation, act_out, input)
return per_token_group_quant_fp8(
act_out, block_k, column_major_scales=True, out_q=output
)
def apply(
self,

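For context on the branches above: the fused silu_mul_per_token_group_quant_fp8_colmajor kernel collapses what is otherwise a two-step sequence, SiLU-and-mul on the (M_sum, N) intermediate followed by per-token-group FP8 quantization of the (M_sum, N // 2) result. Below is a minimal PyTorch sketch of that unfused sequence. It is illustrative only: the helper name silu_mul_then_group_quant and the group_size default are not from vLLM, and it assumes the gate and up projections occupy the first and second halves of the last dimension.

import torch

def silu_mul_then_group_quant(x: torch.Tensor, group_size: int = 128):
    # Unfused reference: SiLU(gate) * up, then one FP8 scale per `group_size` columns.
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    y = torch.nn.functional.silu(gate) * up  # shape (M_sum, N // 2)

    m, n = y.shape
    groups = y.view(m, n // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    fp8 = torch.finfo(torch.float8_e4m3fn)
    scale = amax / fp8.max  # one float32 scale per token-group
    q = (groups / scale).clamp(fp8.min, fp8.max).to(torch.float8_e4m3fn)
    return q.view(m, n), scale.squeeze(-1)

The fused kernel additionally writes the scales column-major (and, when use_ue8m0 is set, rounds them up to powers of two), which this sketch omits.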
View File

@@ -269,7 +269,11 @@ class W8A8BlockFp8LinearOp:
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
use_ue8m0=True,
)
output = torch.empty(
(q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
@@ -791,6 +795,80 @@ def per_token_group_quant_fp8(
return x_q, x_s
def per_token_group_quant_fp8_packed_for_deepgemm(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
use_ue8m0: bool | None = None,
out_q: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""FP8 per-token-group quantization for DeepGEMM.
Returns:
(x_q, x_s_packed)
x_q: FP8 activations, same shape as `x`.
x_s_packed: Int32 tensor with logical shape
[mn, ceil(num_groups_per_row / 4)], laid out with
TMA-aligned stride along the packed-K dimension.
"""
if use_ue8m0 is None:
use_ue8m0 = is_deep_gemm_e8m0_used()
# For the DeepGEMM UE8M0-packed layout we *require* UE8M0 scales.
assert use_ue8m0, (
"per_token_group_quant_fp8_packed_for_deepgemm requires UE8M0 scales."
)
dtype = current_platform.fp8_dtype()
assert x.shape[-1] % group_size == 0, (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}"
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min, fp8_max = finfo.min, finfo.max
# compute DeepGEMM-style packed scale tensor shape.
hidden_dim = x.shape[-1]
mn = x.numel() // hidden_dim
num_groups_per_row = hidden_dim // group_size
k_num_packed_sf_k = (num_groups_per_row + 3) // 4
tma_aligned_mn = ((mn + 3) // 4) * 4
x_s_packed = torch.empty_strided(
(mn, k_num_packed_sf_k),
(1, tma_aligned_mn),
device=x.device,
dtype=torch.int32,
)
# CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
assert current_platform.is_cuda(), (
"per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
"platforms using DeepGEMM."
)
x_contiguous = x.contiguous()
if out_q is not None:
x_q_local = out_q
else:
x_q_local = torch.empty_like(x_contiguous, device=x.device, dtype=dtype)
torch.ops._C.per_token_group_fp8_quant_packed(
x_contiguous,
x_q_local,
x_s_packed,
group_size,
eps,
fp8_min,
fp8_max,
)
# return a tensor with the original logical shape.
x_q = x_q_local.view_as(x)
return x_q, x_s_packed
@triton.jit
def _w8a8_triton_block_scaled_mm(
# Pointers to inputs and output
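To make the packed scale layout above concrete, here is a small sketch of the shape/stride arithmetic and of one plausible way four UE8M0 scale bytes end up in each int32. It mirrors the logic of per_token_group_quant_fp8_packed_for_deepgemm but is not the CUDA kernel; the helper names and the assumption that num_groups_per_row is a multiple of 4 are illustrative.

import torch

def packed_scale_shape(mn: int, hidden_dim: int, group_size: int = 128):
    # Same arithmetic as the Python wrapper above.
    num_groups_per_row = hidden_dim // group_size
    packed_k = (num_groups_per_row + 3) // 4       # ceil(groups / 4) int32s per row
    tma_aligned_mn = ((mn + 3) // 4) * 4           # rows padded to a multiple of 4
    return (mn, packed_k), (1, tma_aligned_mn)

# Example: 6 tokens, hidden_dim 2048, 128-wide groups
# -> 16 scale groups per row packed into 4 int32s, strides (1, 8).
shape, stride = packed_scale_shape(6, 2048)
assert shape == (6, 4) and stride == (1, 8)

def pack_ue8m0(scales_f32: torch.Tensor) -> torch.Tensor:
    # Assumed packing: biased E8M0 exponent byte of each power-of-two scale,
    # four consecutive groups per int32 (requires num_groups_per_row % 4 == 0).
    exp = torch.ceil(torch.log2(scales_f32)).clamp(-127, 127) + 127
    e = exp.to(torch.int32).view(scales_f32.shape[0], -1, 4)
    return e[..., 0] | (e[..., 1] << 8) | (e[..., 2] << 16) | (e[..., 3] << 24)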