[Performance][B200] silu_mul_quant: pack scales in int32 (#28358)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Varun Sundar Rabindranath
Date: 2025-11-13 13:16:55 -05:00
Committed by: GitHub
Parent: fdfd5075aa
Commit: fe1cd7704d
7 changed files with 466 additions and 151 deletions


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -13,14 +14,33 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
)
+from vllm.utils.math_utils import cdiv

logger = init_logger(__name__)

+def scales_shape_stride_dtype(
+    E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
+) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
+    shape = (E, T, G)
+    strides = (T * G, 1, T)
+    if quant_scale_fmt in [
+        DeepGemmQuantScaleFMT.FLOAT32,
+        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
+    ]:
+        return shape, strides, torch.float32
+    assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0
+    shape = (E, T, cdiv(G, 4))
+    strides = (T * cdiv(G, 4), 1, T)
+    return shape, strides, torch.int32
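For intuition: the UE8M0 format stores each group scale as a single power-of-two exponent byte, so four of them fit in one int32 lane and the scales dimension shrinks from G to cdiv(G, 4). Below is a minimal standalone sketch of the same layout math (plain Python; the format strings are stand-ins for the `DeepGemmQuantScaleFMT` enum values):

```python
from math import ceil

def scales_layout(E: int, T: int, G: int, fmt: str):
    """Mirror of scales_shape_stride_dtype, with string stand-ins for the enum."""
    if fmt in ("FLOAT32", "FLOAT32_CEIL_UE8M0"):
        # One fp32 scale per quantization group, groups-major memory layout.
        return (E, T, G), (T * G, 1, T), "float32"
    assert fmt == "UE8M0"
    # Four 8-bit UE8M0 exponents packed into each int32 lane.
    Gp = ceil(G / 4)
    return (E, T, Gp), (T * Gp, 1, T), "int32"

# E.g. E=32 experts, T=256 tokens, H=896 -> G = 896 // 128 = 7 groups:
print(scales_layout(32, 256, 7, "FLOAT32"))  # ((32, 256, 7), (1792, 1, 256), 'float32')
print(scales_layout(32, 256, 7, "UE8M0"))    # ((32, 256, 2), (512, 1, 256), 'int32')
```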
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
# Pointers ------------------------------------------------------------
@@ -49,7 +69,7 @@ def _silu_mul_fp8_quant_deep_gemm(
eps: tl.constexpr,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
-    use_ue8m0: tl.constexpr,
+    ceil_ue8m0: tl.constexpr,
# Meta ---------------------------------------------------------------
BLOCK: tl.constexpr,
NUM_STAGES: tl.constexpr,
@@ -86,7 +106,7 @@ def _silu_mul_fp8_quant_deep_gemm(
y = gate * up
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
-        if use_ue8m0:
+        if ceil_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
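The `ceil_ue8m0` branch rounds each group scale up to the nearest power of two, which makes it exactly representable as a UE8M0 exponent; rounding up rather than to nearest keeps `y / y_s` within the fp8 range. A quick torch sketch of the same rounding outside Triton:

```python
import torch

def ceil_pow2(y_s: torch.Tensor) -> torch.Tensor:
    # Mirrors tl.exp2(tl.ceil(tl.log2(y_s))) in the kernel above.
    return torch.exp2(torch.ceil(torch.log2(y_s)))

s = torch.tensor([0.7, 1.0, 3.2])
print(ceil_pow2(s))  # tensor([1., 1., 4.])
```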
@@ -100,7 +120,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
group_size: int = 128,
-    use_ue8m0: bool | None = None,
+    quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
@@ -137,7 +157,13 @@ def persistent_masked_m_silu_mul_quant(
Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
-    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+    * `y_s` depends on quant_scale_fmt:
+      - quant_scale_fmt == FLOAT32:
+        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+      - quant_scale_fmt == UE8M0:
+        `y_s`: Int32 tensor, shape (E, T, cdiv(H // group_size, 4)),
+        strides (T * cdiv(H // group_size, 4), 1, T)
+      - quant_scale_fmt == FLOAT32_CEIL_UE8M0:
+        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
Let NUM_WARPS be the number of warps in a single thread block and
`GROUP_SIZE = 128` be the size of the quantization group.
"""
@@ -155,17 +181,18 @@ def persistent_masked_m_silu_mul_quant(
fp8_dtype = torch.float8_e4m3fn
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
-    stride_ys_e = T * G
-    stride_ys_t = 1
-    stride_ys_g = T
+    ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
    y_s = torch.empty_strided(
-        (E, T, G),
-        (stride_ys_e, stride_ys_t, stride_ys_g),
-        dtype=torch.float32,
+        ys_shape,
+        ys_strides,
+        dtype=ys_dtype,
device=y.device,
)
-    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
+    ceil_ue8m0 = quant_scale_fmt in [
+        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
+        DeepGemmQuantScaleFMT.UE8M0,
+    ]
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index
@@ -173,7 +200,7 @@ def persistent_masked_m_silu_mul_quant(
if cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant(
-            y, tokens_per_expert, y_q, y_s, use_ue8m0
+            y, tokens_per_expert, y_q, y_s, ceil_ue8m0
)
else:
stride_cnt_e = tokens_per_expert.stride()[0]
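On the CUDA path, the packed int32 scales are produced inside the `_C` kernel itself. For intuition only, an equivalent host-side packing might look like the sketch below; the byte order within each int32 lane is an assumption, not something this diff specifies:

```python
import torch

def pack_ue8m0(scales: torch.Tensor) -> torch.Tensor:
    # scales: fp32 powers of two, shape (..., G). Keep only the biased
    # exponent byte of each fp32 value (bits 23..30; sign bit is 0).
    exp = (scales.contiguous().view(torch.int32) >> 23).to(torch.uint8)
    pad = (-exp.shape[-1]) % 4  # pad G up to a multiple of 4
    if pad:
        exp = torch.nn.functional.pad(exp, (0, pad))
    # Assumed little-endian: 4 consecutive exponents per int32 lane.
    return exp.view(*exp.shape[:-1], -1, 4).contiguous().view(torch.int32).squeeze(-1)
```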
@@ -189,6 +216,10 @@ def persistent_masked_m_silu_mul_quant(
fp8_max = f_info.max
fp8_min = f_info.min
eps: float = 1e-10
+        assert y_s.dtype == torch.float32, (
+            "_silu_mul_fp8_quant_deep_gemm does not support "
+            f"{y_s.dtype} scales. Only torch.float32 is supported."
+        )
_silu_mul_fp8_quant_deep_gemm[grid](
y,
y_q,
@@ -202,14 +233,14 @@ def persistent_masked_m_silu_mul_quant(
stride_yq_e,
stride_yq_t,
stride_yq_h,
-            stride_ys_e,
-            stride_ys_t,
-            stride_ys_g,
+            ys_strides[0],
+            ys_strides[1],
+            ys_strides[2],
stride_cnt_e,
eps,
fp8_min,
fp8_max,
-            is_deep_gemm_e8m0_used(),
+            ceil_ue8m0,
BLOCK=group_size,
NUM_STAGES=4,
num_warps=1,
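Note that this Triton fallback only writes fp32 scales (hence the assert above); int32 packing is exclusive to the CUDA path. For debugging that path, the inverse transform is handy: unpack each int32 lane back into four fp32 power-of-two scales. This hypothetical helper uses the same little-endian assumption as the packing sketch above; since the scales are exact powers of two, a pack/unpack round trip should reproduce them exactly:

```python
import torch

def unpack_ue8m0(packed: torch.Tensor, G: int) -> torch.Tensor:
    # packed: int32, shape (..., cdiv(G, 4)) -> fp32, shape (..., G).
    bytes_ = packed.contiguous().view(torch.uint8)  # (..., 4 * cdiv(G, 4))
    exp = bytes_.to(torch.int32)[..., :G]           # biased exponents
    return torch.exp2((exp - 127).to(torch.float32))  # 2 ** (e - 127)
```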
@@ -255,7 +286,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
    DeepGemm supports the packed UE8M0 activation-scales format only on sm100 devices.
"""
-        return current_platform.is_device_capability(100)
+        return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
@@ -329,10 +360,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m,
)
+        quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
-            workspace1, expert_num_tokens
+            workspace1,
+            expert_num_tokens,
+            quant_scale_fmt=quant_scale_fmt,
)
fp8_m_grouped_gemm_nt_masked(
-            (a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m
+            (a2q, a2q_scale),
+            (w2, self.w2_scale),
+            output,
+            expert_num_tokens,
+            expected_m,
)