[Performance][B200] silu_mul_quant: pack scales in int32 (#28358)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent fdfd5075aa
commit fe1cd7704d
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -13,14 +14,33 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
     fp8_m_grouped_gemm_nt_masked,
     get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
 )
+from vllm.utils.math_utils import cdiv
 
 logger = init_logger(__name__)
 
 
+def scales_shape_stride_dtype(
+    E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
+) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
+    shape = (E, T, G)
+    strides = (T * G, 1, T)
+    if quant_scale_fmt in [
+        DeepGemmQuantScaleFMT.FLOAT32,
+        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
+    ]:
+        return shape, strides, torch.float32
+
+    assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0
+    shape = (E, T, cdiv(G, 4))
+    strides = (T * cdiv(G, 4), 1, T)
+    return shape, strides, torch.int32
+
+
 @triton.jit
 def _silu_mul_fp8_quant_deep_gemm(
     # Pointers ------------------------------------------------------------
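An aside on why the UE8M0 branch shrinks the group dimension from G to cdiv(G, 4): a power-of-two scale is fully described by its 8-bit FP32 exponent, and four such bytes fit in one int32. A minimal sketch of that packing in plain PyTorch; the helper name pack_ue8m0 and the little-endian byte order are assumptions for illustration, not the kernel's actual layout:

import torch

def pack_ue8m0(scales: torch.Tensor) -> torch.Tensor:
    # scales: (..., G) float32 power-of-two scales; keep only the exponent byte.
    assert scales.shape[-1] % 4 == 0, "pad G to a multiple of 4 first"
    exponents = (scales.view(torch.int32) >> 23) & 0xFF   # biased FP32 exponent
    b = exponents.reshape(*scales.shape[:-1], -1, 4)      # four bytes per word
    packed = b[..., 0] | (b[..., 1] << 8) | (b[..., 2] << 16) | (b[..., 3] << 24)
    return packed.to(torch.int32)                         # shape (..., G // 4)

s = torch.tensor([[0.5, 2.0, 1.0, 4.0]])                  # powers of two, G = 4
print(hex(pack_ue8m0(s)[0, 0].item() & 0xFFFFFFFF))       # 0x817f807e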
@@ -49,7 +69,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     eps: tl.constexpr,
     fp8_min: tl.constexpr,
     fp8_max: tl.constexpr,
-    use_ue8m0: tl.constexpr,
+    ceil_ue8m0: tl.constexpr,
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
     NUM_STAGES: tl.constexpr,
@@ -86,7 +106,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     y = gate * up
 
     y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
-    if use_ue8m0:
+    if ceil_ue8m0:
         y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
 
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
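Renaming use_ue8m0 to ceil_ue8m0 makes the intent of this branch explicit: the scale is rounded up to the next power of two, so the UE8M0-representable scale still satisfies |y / y_s| <= fp8_max. A tiny worked example in plain Python, mirroring the tl.exp2(tl.ceil(tl.log2(...))) line above:

import math

def ceil_to_pow2(y_s: float) -> float:
    # Round a scale up to the nearest power of two; rounding up (rather than
    # to nearest) guarantees the quantized values never clip harder.
    return 2.0 ** math.ceil(math.log2(y_s))

print(ceil_to_pow2(0.3))  # 0.5
print(ceil_to_pow2(0.5))  # 0.5 (already a power of two)
print(ceil_to_pow2(1.7))  # 2.0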
@@ -100,7 +120,7 @@ def persistent_masked_m_silu_mul_quant(
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     num_parallel_tokens=16,
     group_size: int = 128,
-    use_ue8m0: bool | None = None,
+    quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
     y has shape (E, T, 2*H). The first half of the last dimension is
@@ -137,7 +157,13 @@ def persistent_masked_m_silu_mul_quant(
 
     Returns `(y_q, y_s)` where
     * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
-    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+    * `y_s` depends on quant_scale_fmt:
+      - quant_scale_fmt == FLOAT32:
+        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+      - quant_scale_fmt == UE8M0:
+        `y_s`: Int32 tensor, shape (E, T, cdiv(H // group_size, 4)), strides (T * cdiv(G, 4), 1, T)
+      - quant_scale_fmt == FLOAT32_CEIL_UE8M0:
+        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
 
     Let NUM_WARPS be the number of warps in a single thread block and
     `GROUP_SIZE = 128` be the size of the quantization group.
     """
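A hypothetical call illustrating the shapes and dtypes the updated docstring promises, assuming the two symbols from the patched file are in scope, a CUDA device is available, and the default group_size of 128; all sizes are illustrative:

import torch

# assumes: persistent_masked_m_silu_mul_quant and DeepGemmQuantScaleFMT
# imported from the patched module
E, T, H = 8, 64, 1024  # experts, tokens per expert, hidden size
y = torch.randn(E, T, 2 * H, device="cuda", dtype=torch.bfloat16)
counts = torch.full((E,), T, device="cuda", dtype=torch.int32)

y_q, y_s = persistent_masked_m_silu_mul_quant(
    y, counts, quant_scale_fmt=DeepGemmQuantScaleFMT.UE8M0
)
assert y_q.shape == (E, T, H)                    # FP8 activations
assert y_s.dtype == torch.int32                  # packed UE8M0 scales
assert y_s.shape == (E, T, (H // 128 + 3) // 4)  # cdiv(H // group_size, 4)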
@@ -155,17 +181,18 @@ def persistent_masked_m_silu_mul_quant(
     fp8_dtype = torch.float8_e4m3fn
     y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
 
-    stride_ys_e = T * G
-    stride_ys_t = 1
-    stride_ys_g = T
+    ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
     y_s = torch.empty_strided(
-        (E, T, G),
-        (stride_ys_e, stride_ys_t, stride_ys_g),
-        dtype=torch.float32,
+        ys_shape,
+        ys_strides,
+        dtype=ys_dtype,
         device=y.device,
     )
 
-    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
+    ceil_ue8m0 = quant_scale_fmt in [
+        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
+        DeepGemmQuantScaleFMT.UE8M0,
+    ]
 
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index
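The strides returned by scales_shape_stride_dtype make the token index the fastest-moving dimension, so the (E, T, G) view is backed by an (E, G, T)-contiguous buffer. A small standalone illustration of that layout (not part of the patch):

import torch

E, T, G = 2, 4, 3
buf = torch.arange(E * G * T, dtype=torch.float32)  # flat backing storage
# Element (e, t, g) sits at flat offset e*T*G + t + g*T, so for a fixed
# expert and group, consecutive tokens are adjacent in memory.
y_s = buf.as_strided((E, T, G), (T * G, 1, T))
assert torch.equal(y_s, buf.view(E, G, T).transpose(1, 2))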
@@ -173,7 +200,7 @@ def persistent_masked_m_silu_mul_quant(
 
     if cuda_arch >= 80:
         torch.ops._C.persistent_masked_m_silu_mul_quant(
-            y, tokens_per_expert, y_q, y_s, use_ue8m0
+            y, tokens_per_expert, y_q, y_s, ceil_ue8m0
         )
     else:
         stride_cnt_e = tokens_per_expert.stride()[0]
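The dispatch above uses the fused CUDA op on Ampere (sm80) and newer, falling back to the Triton kernel elsewhere. A sketch of the same gating with plain torch.cuda, since the patched code goes through current_platform instead:

import torch

major, minor = torch.cuda.get_device_capability()
cuda_arch = major * 10 + minor  # e.g. (9, 0) -> 90
if cuda_arch >= 80:
    pass  # fused path: torch.ops._C.persistent_masked_m_silu_mul_quant(...)
else:
    pass  # portable path: launch _silu_mul_fp8_quant_deep_gemm via Triton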
@@ -189,6 +216,10 @@ def persistent_masked_m_silu_mul_quant(
         fp8_max = f_info.max
         fp8_min = f_info.min
         eps: float = 1e-10
+        assert y_s.dtype == torch.float32, (
+            "_silu_mul_fp8_quant_deep_gemm does "
+            f"not support {y_s.dtype} scales. Only torch.float32 supported."
+        )
         _silu_mul_fp8_quant_deep_gemm[grid](
             y,
             y_q,
@@ -202,14 +233,14 @@ def persistent_masked_m_silu_mul_quant(
             stride_yq_e,
             stride_yq_t,
             stride_yq_h,
-            stride_ys_e,
-            stride_ys_t,
-            stride_ys_g,
+            ys_strides[0],
+            ys_strides[1],
+            ys_strides[2],
             stride_cnt_e,
             eps,
             fp8_min,
             fp8_max,
-            is_deep_gemm_e8m0_used(),
+            ceil_ue8m0,
             BLOCK=group_size,
             NUM_STAGES=4,
             num_warps=1,
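The launch now forwards ys_strides directly. These are element strides, the same unit Tensor.stride() reports and the unit Triton pointer arithmetic expects, not byte strides. For instance:

import torch

y_s = torch.empty_strided((2, 4, 3), (12, 1, 4), dtype=torch.float32)
print(y_s.stride())  # (12, 1, 4): element counts, not bytes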
@@ -255,7 +286,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         DeepGemm supports packed ue8m0 activation scales format in devices == sm100
         """
-        return current_platform.is_device_capability(100)
+        return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Let PrepareAndFinalize::finalize() decide the impl.
@@ -329,10 +360,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             expected_m,
         )
 
+        quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
         a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
-            workspace1, expert_num_tokens
+            workspace1,
+            expert_num_tokens,
+            quant_scale_fmt=quant_scale_fmt,
         )
 
         fp8_m_grouped_gemm_nt_masked(
-            (a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m
+            (a2q, a2q_scale),
+            (w2, self.w2_scale),
+            output,
+            expert_num_tokens,
+            expected_m,
         )