[Performance][DP/EP] Add silu_mul_per_token_group_quant_fp8_colmajor kernel (#29470)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Commit 19bee6d12d (parent dd5d1ef780)
Authored by Varun Sundar Rabindranath on 2025-12-03 13:04:59 -05:00, committed by GitHub
4 changed files with 496 additions and 81 deletions


@@ -2,9 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import torch
-from tqdm import tqdm
-import vllm.envs as env
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
@@ -25,12 +23,12 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
+    silu_mul_per_token_group_quant_fp8_colmajor,
 )
 from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
     m_grouped_fp8_gemm_nt_contiguous,
 )
-from vllm.utils.func_utils import run_once
 from vllm.utils.import_utils import has_deep_gemm

 logger = init_logger(__name__)
@@ -108,70 +106,6 @@ def _valid_deep_gemm(
     return True
-
-
-@run_once
-def warmup_deepgemm_gg_contiguous_kernels(
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    num_topk: int,
-):
-    """
-    DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the
-    input tensor shapes. In this function, we construct all possible input
-    tensor shapes so all the kernels are JIT'ed and cached.
-    Note that this warmup is expected to happen during the model profile
-    call and not during actual model inference.
-    """
-    assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"
-
-    block_m = get_mk_alignment_for_contiguous_layout()[0]
-    num_experts = w1.size(0)
-    device = w1.device
-
-    # This is the maximum GroupedGemm M size that we expect to run
-    # the grouped_gemm with.
-    MAX_M = compute_aligned_M(
-        env.VLLM_FUSED_MOE_CHUNK_SIZE,
-        num_topk,
-        num_experts,
-        block_m,
-        expert_tokens_meta=None,
-    )
-    # Distribute expert-ids evenly.
-    MAX_BLOCKS = MAX_M // block_m
-    expert_ids_block = torch.randint(
-        low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32
-    )
-    expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0)
-
-    def _warmup(w: torch.Tensor, w_scale: torch.Tensor):
-        _, n, k = w.size()
-
-        a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn)
-        a1q_scales = torch.empty(
-            (MAX_M, k // block_m), device=device, dtype=torch.float32
-        )
-        out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16)
-
-        pbar = tqdm(
-            total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})"
-        )
-        num_tokens = MAX_M
-        while num_tokens > 0:
-            m_grouped_fp8_gemm_nt_contiguous(
-                (a1q[:num_tokens], a1q_scales[:num_tokens]),
-                (w, w_scale),
-                out[:num_tokens],
-                expert_ids[:num_tokens],
-            )
-            pbar.update(1)
-            num_tokens = num_tokens - block_m
-
-    _warmup(w1, w1_scale)
-    _warmup(w2, w2_scale)
-
-
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
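For context on what the deleted warmup did: DeepGemm JIT-compiles one grouped-GEMM variant per input shape, so the removed loop issued one GEMM per M size that inference could produce. A minimal sketch of that sweep with hypothetical values (block_m = 128, MAX_M = 1024; the real values come from get_mk_alignment_for_contiguous_layout() and compute_aligned_M()):

block_m, MAX_M = 128, 1024  # hypothetical; see lead-in
# Mirrors the removed `while num_tokens > 0: ... num_tokens -= block_m` loop:
# one grouped GEMM per M size, so every shape-keyed JIT variant compiles up front.
m_sizes = list(range(MAX_M, 0, -block_m))
assert m_sizes == [1024, 896, 768, 640, 512, 384, 256, 128]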
@@ -215,11 +149,32 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         )
         assert M_sum % block_m == 0
-        workspace1 = (M_sum, N)
-        workspace2 = (M_sum, max(N // 2, K))
+        workspace1 = (M_sum, max(N // 2, K))
+        workspace2 = (M_sum, max(N, K))
         output = (M, K)
         return (workspace1, workspace2, output)
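The swapped workspace sizes follow from which tensors each buffer now hosts during apply(): workspace1 (materialized as workspace13 below, assuming the usual modular_kernel naming) holds a1q_perm (K columns) and quant_out (N // 2 columns), while workspace2 holds mm1_out (N columns) and mm2_out (K columns). A quick sizing check with hypothetical DeepSeek-V3-like dimensions (the N and K values are illustrative, not from this diff):

M_sum, N, K = 1024, 4096, 7168  # hypothetical shapes
workspace1 = (M_sum, max(N // 2, K))  # widest occupant: a1q_perm (K) vs quant_out (N // 2)
workspace2 = (M_sum, max(N, K))       # widest occupant: mm1_out (N) vs mm2_out (K)
assert workspace1 == (1024, 7168) and workspace2 == (1024, 7168)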
+
+    def _act_mul_quant(
+        self, input: torch.Tensor, output: torch.Tensor, activation: str
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if activation == "silu":
+            return silu_mul_per_token_group_quant_fp8_colmajor(
+                input=input, output=output
+            )
+        else:
+            # This is a fallback path. If any activation other than silu shows
+            # up here, that activation should be added to the
+            # silu_mul_per_token_group_quant_fp8_colmajor kernel, which is much
+            # faster.
+            M_sum, N = input.size()
+            act_out = torch.empty(
+                (M_sum, N // 2), dtype=input.dtype, device=input.device
+            )
+            self.activation(activation, act_out, input)
+            assert self.block_shape is not None
+            return per_token_group_quant_fp8(
+                act_out, self.block_shape[1], column_major_scales=True, out_q=output
+            )
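As a reading aid, here is a hedged reference for what the silu branch computes; act_mul_reference is a name invented for this sketch, and it mirrors the fused kernel shown later in this diff (silu applied in fp32 to the first N // 2 columns, cast back, then multiplied by the last N // 2 columns):

import torch
import torch.nn.functional as F

def act_mul_reference(mm1_out: torch.Tensor) -> torch.Tensor:
    # The first half of the columns gates the second half, matching
    # act_in / mul_in in the Triton kernel.
    gate, up = mm1_out.chunk(2, dim=-1)
    # silu runs in fp32 and is cast back to the input dtype before the multiply.
    return F.silu(gate.float()).to(mm1_out.dtype) * up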
     def apply(
         self,
         output: torch.Tensor,
@@ -261,14 +216,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             expert_tokens_meta=expert_tokens_meta,
         )
-        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K))
-        mm1_out = _resize_cache(workspace13, (M_sum, N))
-        act_out = _resize_cache(workspace2, (M_sum, N // 2))
-        quant_out = _resize_cache(
-            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
+        a1q_perm = _resize_cache(
+            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
         )
-        mm2_out = _resize_cache(workspace2, (M_sum, K))

         a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
             aq=a1q,
             aq_scale=a1q_scale,
@@ -280,17 +230,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         )
         assert a1q.size(0) == M_sum

+        mm1_out = _resize_cache(workspace2, (M_sum, N))
         m_grouped_fp8_gemm_nt_contiguous(
             (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
         )
-        self.activation(activation, act_out, mm1_out.view(-1, N))
-        a2q_scale: torch.Tensor | None = None
-        a2q, a2q_scale = per_token_group_quant_fp8(
-            act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out
+        quant_out = _resize_cache(
+            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)
         )
+        a2q, a2q_scale = self._act_mul_quant(
+            input=mm1_out.view(-1, N), output=quant_out, activation=activation
+        )
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
         m_grouped_fp8_gemm_nt_contiguous(
             (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
         )
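Taken together, the rewritten apply() ping-pongs between the two workspaces so each stage overwrites a region whose previous contents are already dead; a schematic of the buffer reuse (assuming workspace13 is the buffer sized by workspace1 above):

# a1q_perm  <- workspace13 viewed as fp8, (M_sum, K)       # permuted quantized input
# mm1_out   <- workspace2,                (M_sum, N)       # grouped GEMM 1 output
# quant_out <- workspace13 viewed as fp8, (M_sum, N // 2)  # fused silu+mul+quant output
# mm2_out   <- workspace2,                (M_sum, K)       # grouped GEMM 2 output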


@@ -492,6 +492,139 @@ def _per_token_group_quant_fp8(
     tl.store(y_s_ptr, y_s)
+
+
+@triton.jit
+def _silu_mul_per_token_group_quant_fp8_colmajor(
+    y_ptr,  # [M, N]
+    y_q_ptr,  # [M, N // 2]
+    y_s_ptr,  # [M, (N // 2) // GROUP_SIZE]
+    M,  # num tokens
+    N,  # intermediate size
+    # Stride
+    y_s_col_stride: tl.int64,
+    # Information for float8
+    eps,
+    fp8_min,
+    fp8_max,
+    use_ue8m0: tl.constexpr,
+    # Meta-parameters
+    GROUP_SIZE: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    # TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks.
+    """
+    Each thread block computes a [BLOCK_M, GROUP_SIZE] tile of act-mul outputs
+    (BLOCK_N == GROUP_SIZE), then quantizes that tile and writes the quantized
+    values and per-group scales to the output tensors at the right positions.
+    """
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+
+    N_2 = N // 2
+    m_offset = pid_m * BLOCK_M
+    n_offset = pid_n * BLOCK_N
+    if m_offset >= M:
+        return
+
+    offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
+    offs_m = tl.arange(0, BLOCK_M).to(tl.int64)
+
+    base_y_ptr = y_ptr + m_offset * N + n_offset
+    act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :]
+    act_in = tl.load(act_in_ptrs)
+    mul_in = tl.load(act_in_ptrs + N_2)
+
+    # silu & mul
+    act_in = act_in.to(tl.float32)
+    one_f32 = tl.cast(1, tl.float32)
+    silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty)
+    y = (silu_out * mul_in).to(tl.float32)
+
+    # quant
+    _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
+    scale_raw = _absmax / fp8_max
+    # UE8M0 rounds each scale up to the next power of two so it is exactly
+    # representable in an 8-bit, exponent-only format.
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
+    y_s = tl.reshape(y_s, (BLOCK_M, 1))
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    # store y_q
+    base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset
+    y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :]
+    tl.store(y_q_ptrs, y_q)
+
+    # store y_s
+    group_id = n_offset // GROUP_SIZE
+    base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset
+    y_s_ptrs = base_y_s_ptr + offs_m
+    y_s = tl.reshape(y_s, (BLOCK_M,))
+    tl.store(y_s_ptrs, y_s)
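A hedged, unfused reference for the quantization step above (torch-only; group_quant_reference is a name invented for this sketch). It mirrors the per-(token, 128-column-group) absmax scale, the optional UE8M0 power-of-two rounding, and the clamp to the fp8 range:

import torch

def group_quant_reference(
    y: torch.Tensor, group_size: int = 128, use_ue8m0: bool = False, eps: float = 1e-10
) -> tuple[torch.Tensor, torch.Tensor]:
    # y: [M, N_2] fp32 act-mul output; one scale per (token, group) pair.
    M, N_2 = y.shape
    finfo = torch.finfo(torch.float8_e4m3fn)
    g = y.view(M, N_2 // group_size, group_size)
    absmax = g.abs().amax(dim=-1).clamp_min(eps)
    scale = absmax / finfo.max
    if use_ue8m0:
        # Round each scale up to the next power of two, as in the kernel.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    q = (g / scale.unsqueeze(-1)).clamp(finfo.min, finfo.max)
    # The kernel stores scales column-major; this sketch keeps them row-major.
    return q.to(torch.float8_e4m3fn).view(M, N_2), scale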
+
+
+def silu_mul_per_token_group_quant_fp8_colmajor(
+    input: torch.Tensor,  # [M, N]
+    output: torch.Tensor | None = None,  # [M, N // 2]
+    use_ue8m0: bool | None = None,
+    eps: float = 1e-10,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Fused silu-and-mul + per-token-group block-fp8 quant with group size 128.
+    Returns the quantized [M, N // 2] tensor and its column-major scales.
+    """
+    GROUP_SIZE = 128
+    assert input.ndim == 2
+    if output is not None:
+        assert output.ndim == 2
+    assert input.size(0) % GROUP_SIZE == 0
+    assert input.size(1) % (GROUP_SIZE * 2) == 0
+
+    if use_ue8m0 is None:
+        use_ue8m0 = is_deep_gemm_e8m0_used()
+
+    M, N = input.size()
+    N_2 = N // 2
+    if output is None:
+        output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)
+    output_scales = torch.empty(
+        ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
+    ).transpose(0, 1)
+
+    BLOCK_M = 8
+    BLOCK_N = GROUP_SIZE
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    # Force even division so we can avoid edge cases within the kernel.
+    assert M % BLOCK_M == 0
+    assert N_2 % BLOCK_N == 0
+    grid = (M // BLOCK_M, N_2 // BLOCK_N)
+
+    _silu_mul_per_token_group_quant_fp8_colmajor[grid](
+        input,
+        output,
+        output_scales,
+        M,
+        N,
+        output_scales.stride(-1),
+        eps,
+        fp8_min,
+        fp8_max,
+        use_ue8m0,
+        GROUP_SIZE,
+        BLOCK_M,
+        BLOCK_N,
+    )
+    return output, output_scales
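A minimal usage sketch, assuming a CUDA device with fp8 support and the import path shown in the first file of this diff:

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    silu_mul_per_token_group_quant_fp8_colmajor,
)

M, N = 256, 1024  # M % 128 == 0 and N % 256 == 0, per the asserts above
x = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
y_q, y_s = silu_mul_per_token_group_quant_fp8_colmajor(input=x)
assert y_q.shape == (M, N // 2) and y_q.dtype == torch.float8_e4m3fn
assert y_s.shape == (M, (N // 2) // 128)
assert y_s.stride() == (1, M)  # scales are laid out column-major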
 @triton.jit
 def _per_token_group_quant_fp8_colmajor(
     # Pointers to inputs and output