[Perf] Deepgemm fused layout kernel for activations, 4.3% throughput improvement, 10.7% TTFT improvement. (#29546)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -23,9 +23,11 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
+    per_token_group_quant_fp8_packed_for_deepgemm,
     silu_mul_per_token_group_quant_fp8_colmajor,
 )
 from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
     get_mk_alignment_for_contiguous_layout,
     m_grouped_fp8_gemm_nt_contiguous,
 )
@@ -157,23 +159,40 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def _act_mul_quant(
         self, input: torch.Tensor, output: torch.Tensor, activation: str
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if activation == "silu":
-            return silu_mul_per_token_group_quant_fp8_colmajor(
-                input=input, output=output
-            )
-        else:
-            # This is a fallback path. If we find ourselves using any activation other
-            # than silu, we should add that activation to the
-            # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster.
+        assert self.block_shape is not None
+        block_k = self.block_shape[1]
+        scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
+
+        # 1. DeepGemm UE8M0: use the packed per-token-group quant.
+        if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
             M_sum, N = input.size()
             act_out = torch.empty(
                 (M_sum, N // 2), dtype=input.dtype, device=input.device
             )
             self.activation(activation, act_out, input)
-            assert self.block_shape is not None
-            return per_token_group_quant_fp8(
-                act_out, self.block_shape[1], column_major_scales=True, out_q=output
+            a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+                act_out,
+                block_k,
+                out_q=output,
             )
+            return a2q, a2q_scale
+
+        # 2. Hopper / non-E8M0: prefer the fused SiLU + mul + quant kernel.
+        if activation == "silu":
+            use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
+            return silu_mul_per_token_group_quant_fp8_colmajor(
+                input=input,
+                output=output,
+                use_ue8m0=use_ue8m0,
+            )
+
+        # 3. Fallback path for non-SiLU activations in non-UE8M0 cases.
+        M_sum, N = input.size()
+        act_out = torch.empty((M_sum, N // 2), dtype=input.dtype, device=input.device)
+        self.activation(activation, act_out, input)
+        return per_token_group_quant_fp8(
+            act_out, block_k, column_major_scales=True, out_q=output
+        )
 
     def apply(
         self,
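As a sanity reference for the dispatch above: a minimal pure-PyTorch sketch of what the fused SiLU-mul + per-token-group FP8 quant path computes. This is an illustration, not code from this PR; the helper name silu_mul_quant_reference is hypothetical, it uses plain fp32 scales rather than UE8M0, and the gate/up split assumes vLLM's SiluAndMul convention.

import torch
import torch.nn.functional as F

def silu_mul_quant_reference(
    x: torch.Tensor, group_size: int, eps: float = 1e-10
) -> tuple[torch.Tensor, torch.Tensor]:
    # x is [M, 2N]: gate in the first half, up in the second half, matching
    # the (M_sum, N // 2) act_out split in _act_mul_quant above.
    gate, up = x.chunk(2, dim=-1)
    act = F.silu(gate.float()) * up.float()  # [M, N]

    finfo = torch.finfo(torch.float8_e4m3fn)
    groups = act.view(act.shape[0], -1, group_size)  # [M, N / g, g]
    amax = groups.abs().amax(dim=-1).clamp(min=eps)  # one amax per group
    scale = amax / finfo.max
    q = (groups / scale.unsqueeze(-1)).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn).view_as(act), scale

The fused kernel does all of this in one pass and writes the scales column-major (and, on the UE8M0 paths, in the power-of-two format), which is what the per-branch dispatch above is choosing between.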
@@ -269,7 +269,11 @@ class W8A8BlockFp8LinearOp:
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
-        assert self.deepgemm_input_quant_op is not None
-        q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
+        q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+            input_2d,
+            group_size=self.act_quant_group_shape.col,
+            use_ue8m0=True,
+        )
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
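A note on use_ue8m0=True above: UE8M0 stores a scale as just an unsigned 8-bit exponent. The sketch below is an assumption based on the OCP MX E8M0 definition (bias 127), not code from this change; it shows why the FLOAT32_CEIL_UE8M0 path rounds scales *up*, so that quantized values never overflow the FP8 range.

import torch

def ue8m0_from_fp32_reference(scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: round each fp32 scale up to the next power of
    # two, then keep only the biased 8-bit exponent (no sign, no mantissa).
    tiny = torch.finfo(torch.float32).tiny
    exp = torch.ceil(torch.log2(scale.clamp(min=tiny)))
    return (exp + 127).clamp(0, 255).to(torch.uint8)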
@@ -791,6 +795,80 @@ def per_token_group_quant_fp8(
     return x_q, x_s
 
 
+def per_token_group_quant_fp8_packed_for_deepgemm(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    use_ue8m0: bool | None = None,
+    out_q: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """FP8 per-token-group quantization for DeepGEMM.
+
+    Returns:
+        (x_q, x_s_packed)
+        x_q: FP8 activations, same shape as `x`.
+        x_s_packed: Int32 tensor with logical shape
+            [mn, ceil(num_groups_per_row / 4)], laid out with a
+            TMA-aligned stride along the packed-K dimension.
+    """
+    if use_ue8m0 is None:
+        use_ue8m0 = is_deep_gemm_e8m0_used()
+    # The DeepGEMM UE8M0-packed layout *requires* UE8M0 scales.
+    assert use_ue8m0, (
+        "per_token_group_quant_fp8_packed_for_deepgemm requires UE8M0 scales."
+    )
+
+    dtype = current_platform.fp8_dtype()
+    assert x.shape[-1] % group_size == 0, (
+        f"the last dimension of `x` {x.shape[-1]} must be divisible "
+        f"by `group_size` {group_size}"
+    )
+    assert x.stride(-1) == 1, "`x` groups must be contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min, fp8_max = finfo.min, finfo.max
+
+    # Compute the DeepGEMM-style packed scale tensor shape.
+    hidden_dim = x.shape[-1]
+    mn = x.numel() // hidden_dim
+    num_groups_per_row = hidden_dim // group_size
+    k_num_packed_sf_k = (num_groups_per_row + 3) // 4
+    tma_aligned_mn = ((mn + 3) // 4) * 4
+
+    x_s_packed = torch.empty_strided(
+        (mn, k_num_packed_sf_k),
+        (1, tma_aligned_mn),
+        device=x.device,
+        dtype=torch.int32,
+    )
+
+    # CUDA kernel path only (DeepGEMM + E8M0 is CUDA-specific).
+    assert current_platform.is_cuda(), (
+        "per_token_group_quant_fp8_packed_for_deepgemm is only valid on CUDA "
+        "platforms using DeepGEMM."
+    )
+
+    x_contiguous = x.contiguous()
+    if out_q is not None:
+        x_q_local = out_q
+    else:
+        x_q_local = torch.empty_like(x_contiguous, device=x.device, dtype=dtype)
+
+    torch.ops._C.per_token_group_fp8_quant_packed(
+        x_contiguous,
+        x_q_local,
+        x_s_packed,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+    )
+
+    # Return a tensor with the original logical shape.
+    x_q = x_q_local.view_as(x)
+    return x_q, x_s_packed
+
+
 @triton.jit
 def _w8a8_triton_block_scaled_mm(
     # Pointers to inputs and output
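To make the packed-scale docstring above concrete, here is a plain-PyTorch sketch of the packing it describes: four consecutive per-group UE8M0 exponent bytes are combined into one int32 along K, mirroring k_num_packed_sf_k = (num_groups_per_row + 3) // 4. This is an illustration only; pack_ue8m0_scales_reference is hypothetical, little-endian byte order is an assumption, and the real kernel additionally writes through the TMA-aligned (1, round_up(mn, 4)) stride set up by empty_strided.

import torch

def pack_ue8m0_scales_reference(exp_bytes: torch.Tensor) -> torch.Tensor:
    # exp_bytes: [mn, num_groups_per_row] uint8 UE8M0 exponents (assumed).
    mn, k = exp_bytes.shape
    k_packed = (k + 3) // 4  # same as k_num_packed_sf_k above
    padded = torch.zeros(
        (mn, k_packed * 4), dtype=torch.uint8, device=exp_bytes.device
    )
    padded[:, :k] = exp_bytes
    # Combine groups of 4 bytes into one int32 (byte 0 in the low bits).
    b = padded.view(mn, k_packed, 4).to(torch.int32)
    return b[..., 0] | (b[..., 1] << 8) | (b[..., 2] << 16) | (b[..., 3] << 24)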