Silu v2 (#25074)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: elvircrn <elvircrn@gmail.com> Signed-off-by: Elvir Crnčević <elvircrn@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from math import log2
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -94,7 +93,7 @@ def _silu_mul_fp8_quant_deep_gemm(
|
||||
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
|
||||
|
||||
|
||||
def silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
def persistent_masked_m_silu_mul_quant(
|
||||
y: torch.Tensor, # (E, T, 2*H)
|
||||
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
|
||||
num_parallel_tokens=16,
|
||||
@@ -103,9 +102,41 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
||||
y has shape (E, T, 2*H). The first half of the last dimension is
|
||||
silu-activated, multiplied by the second half, then quantized into FP8.
|
||||
We launch a fixed grid of threads to accommodate CUDA graphs. Let `P2`
|
||||
be a parallelization factor for persistent_masked_m_silu_mul_quant over the
|
||||
hidden dimension.
|
||||
|
||||
Let `expert_offsets = [0] + [num_tokens.cumsum()]` and
|
||||
`total_tokens = expert_offsets[-1]`.
|
||||
persistent_masked_m_silu_mul_quant launches `total_tokens x P2` number of
|
||||
thread blocks. Each thread block contains `NUM_WARPS` warps.
|
||||
|
||||
Every thread block needs to find its corresponding expert by warp-parallel scanning
|
||||
over the `expert_offsets` array.
|
||||
|
||||
The i-th warp in the first thread block processes
|
||||
`[i * warp_chunk_size, (i + 1) * warp_chunk_size)` groups
|
||||
sequentially, where `warp_chunk_size = ((H / GROUP_SIZE) / P2) / NUM_WARPS`,
|
||||
pipelining loads and computes.
|
||||
|
||||
The shared memory layout for 4 warps with a 2-stage pipeline for SiLU V2
|
||||
can be visualized like so:
|
||||
|
||||
stage0 stage1
|
||||
┌─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┐
|
||||
│gate0│up0│gate1│up1│gate2│up2│gate3│up3│gate0│up0│gate1│up1│gate2│up2│gate3│up3│
|
||||
└─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┘
|
||||
|
||||
with the main difference between V1 and V2 being the global load
|
||||
stride between warps, and between half-warps. Regarding the latter stride,
|
||||
we assign the first half warp of every warp for `gate` loads and the second
|
||||
half-warp to `up` loads.
|
||||
|
||||
Returns `(y_q, y_s)` where
|
||||
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
|
||||
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
|
||||
Let NUM_WARPS be the number of warps in a single thread block and
|
||||
`GROUP_SIZE = 128` be the size of the quantization group.
|
||||
"""
|
||||
assert y.ndim == 3, "y must be (E, T, 2*H)"
|
||||
E, T, H2 = y.shape
|
||||
@@ -133,30 +164,15 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
|
||||
use_ue8m0 = is_deep_gemm_e8m0_used()
|
||||
|
||||
if E <= 16:
|
||||
max_empirical_parallelism = 64
|
||||
elif E <= 32:
|
||||
max_empirical_parallelism = 16
|
||||
else:
|
||||
max_empirical_parallelism = 4
|
||||
|
||||
# We never want to launch more than Tx number of threads
|
||||
# This computes the clip.
|
||||
num_parallel_tokens = max(
|
||||
1, min(max_empirical_parallelism, 2 ** int(log2(min(num_parallel_tokens, T))))
|
||||
)
|
||||
cuda_arch = current_platform.get_device_capability(
|
||||
device_id=y.device.index
|
||||
).to_int()
|
||||
|
||||
if cuda_arch >= 80:
|
||||
torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
y, tokens_per_expert, y_q, y_s, group_size, use_ue8m0, num_parallel_tokens
|
||||
torch.ops._C.persistent_masked_m_silu_mul_quant(
|
||||
y, tokens_per_expert, y_q, y_s, use_ue8m0
|
||||
)
|
||||
else:
|
||||
# Default to triton if not on cuda or if arch is too old
|
||||
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||
|
||||
stride_cnt_e = tokens_per_expert.stride()[0]
|
||||
|
||||
# Static grid over experts and H-groups.
|
||||
@@ -166,16 +182,6 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
||||
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
||||
|
||||
# desired scale strides (elements): (T*G, 1, T)
|
||||
stride_ys_e = T * G
|
||||
stride_ys_t = 1
|
||||
stride_ys_g = T
|
||||
y_s = torch.empty_strided(
|
||||
(E, T, G),
|
||||
(stride_ys_e, stride_ys_t, stride_ys_g),
|
||||
dtype=torch.float32,
|
||||
device=y.device,
|
||||
)
|
||||
f_info = torch.finfo(fp8_dtype)
|
||||
fp8_max = f_info.max
|
||||
fp8_min = f_info.min
|
||||
@@ -313,7 +319,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expected_m,
|
||||
)
|
||||
|
||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
|
||||
workspace1, expert_num_tokens
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user