Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Fused MoE kernel."""
|
||||
import functools
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
from math import prod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -15,7 +17,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils import direct_register_custom_op, round_up
|
||||
|
||||
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
||||
rocm_aiter_fused_experts,
|
||||
@@ -23,6 +25,8 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
|
||||
@@ -581,7 +585,8 @@ def moe_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor = None
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
pad_sorted_ids: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns the token distribution across experts to be compatible with block
|
||||
@@ -596,6 +601,8 @@ def moe_align_block_size(
|
||||
from the global space to the local index space of the current
|
||||
expert parallel shard. If the expert is not in the current expert
|
||||
parallel shard, the mapping is set to -1.
|
||||
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
|
||||
should be padded to a multiple of block_size,
|
||||
|
||||
Returns:
|
||||
- sorted_token_ids: A tensor containing the sorted token indices according
|
||||
@@ -625,6 +632,8 @@ def moe_align_block_size(
|
||||
by block_size for proper block matrix operations.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
@@ -667,6 +676,59 @@ def moe_align_block_size(
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
|
||||
def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
                     w2: torch.Tensor,
                     expert_map: Optional[torch.Tensor]) -> bool:
    """
    Check whether the given problem size can be handled by the DeepGemm
    contiguous grouped gemm kernel: M, N, K and the quantization
    block_shape must all be aligned to
    `dg.get_m_alignment_for_contiguous_layout()`.
    """
    if not has_deep_gemm:
        return False

    # Imported lazily so that merely loading this module never triggers
    # CUDA initialization.
    import deep_gemm as dg

    # Expert maps are not supported by the DeepGemm path yet.
    if expert_map is not None:
        return False

    alignment = dg.get_m_alignment_for_contiguous_layout()
    M = hidden_states.shape[0]
    _, K, N = w2.shape

    # Small-N problems are disabled for now: until better
    # permute/unpermute ops are available they are not profitable.
    if N <= 512:
        return False

    if M < alignment or N % alignment or K % alignment:
        return False

    return all(t.is_contiguous() for t in (hidden_states, w1, w2))
|
||||
|
||||
|
||||
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to fp8 and return it with its scales.
    When `block_shape` is given, quantization is done per token group of
    `block_shape[1]` elements; otherwise a per-tensor scale is used.
    """
    if block_shape is not None:
        assert len(block_shape) == 2
        group_k = block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, group_k)
        # One scale column per group of `group_k` input features.
        assert triton.cdiv(A.shape[-1], group_k) == A_scale.shape[-1]
    else:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
    return A, A_scale
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
@@ -691,15 +753,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
|
||||
if use_fp8_w8a8:
|
||||
assert B_scale is not None
|
||||
if block_shape is None:
|
||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||
else:
|
||||
assert len(block_shape) == 2
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0])
|
||||
== B_scale.shape[-2])
|
||||
assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1])
|
||||
== B_scale.shape[-1])
|
||||
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
@@ -1066,7 +1124,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
|
||||
@@ -1098,14 +1156,16 @@ def fused_topk(
|
||||
|
||||
# This is used by the Deepseek-V2 and Deepseek-V3 model
|
||||
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
|
||||
def grouped_topk(hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None):
|
||||
def grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], (
|
||||
"Number of tokens mismatch")
|
||||
@@ -1154,10 +1214,11 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def get_config_dtype_str(dtype: torch.dtype,
|
||||
use_int4_w4a16: Optional[bool] = False,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False):
|
||||
def get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_int4_w4a16: Optional[bool] = False,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False) -> Optional[str]:
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
@@ -1318,26 +1379,123 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
block_shape: Optional[List[int]] = None,
|
||||
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||
if (allow_deep_gemm and use_fp8_w8a8
|
||||
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
|
||||
return deep_gemm_moe_fp8(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
else:
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
A permutation routine that works on fp8 types.
|
||||
"""
|
||||
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
|
||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||
else:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: Optional[torch.Tensor],
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    top_k_num: int,
    block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           torch.Tensor]:
    """
    Compute sorted_token_ids and expert_ids for the given problem size
    and reorder the hidden states (and their scales, when present) into
    `sorted_token_ids` order.
    """
    tokens_in_chunk = curr_hidden_states.shape[0]
    num_tokens = top_k_num * tokens_in_chunk

    sorted_token_ids, expert_ids, num_tokens_post_padded = (
        moe_align_block_size(curr_topk_ids,
                             block_m,
                             global_num_experts,
                             expert_map,
                             pad_sorted_ids=True))

    # Padding entries point past the end of the flattened token range;
    # clamp them to the last valid index.
    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
    # Expand the per-block expert ids to one entry per row.
    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
    # Inverse permutation, used later to restore the original token order.
    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]

    # Permute according to sorted token ids.
    row_idx = sorted_token_ids // top_k_num
    permuted_hidden = _fp8_perm(curr_hidden_states, row_idx)
    permuted_scale = a1q_scale[row_idx] if a1q_scale is not None else None

    return (permuted_hidden, permuted_scale, sorted_token_ids, expert_ids,
            inv_perm)
|
||||
|
||||
|
||||
def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: Optional[torch.Tensor],
    topk: int,
    K: int,
    topk_weight: torch.Tensor,
) -> None:
    """
    Restore the original token ordering of `curr_hidden`, scale each
    expert output by its routing weight, then reduce over the top-k
    experts, writing the result into `out`.
    """
    num_rows = topk_weight.shape[0]
    # Advanced indexing copies, so the in-place multiply below does not
    # touch the caller's tensor.
    unpermuted = curr_hidden[inv_perm, ...].view(-1, topk, K)
    unpermuted.mul_(topk_weight.view(num_rows, -1, 1))
    ops.moe_sum(unpermuted, out)
|
||||
|
||||
|
||||
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
|
||||
"""
|
||||
Shrink the given tensor and apply the given view to it. This is
|
||||
used to resize the intermediate fused_moe caches.
|
||||
"""
|
||||
assert prod(v) <= x.numel()
|
||||
return x.flatten()[:prod(v)].view(*v)
|
||||
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
@@ -1376,6 +1534,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
E, N, _ = w1.shape
|
||||
K = w2.shape[1]
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
top_k_num = topk_ids.shape[1]
|
||||
@@ -1401,13 +1560,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
|
||||
cache13 = torch.empty(M * top_k_num * max(N, K),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
intermediate_cache1 = cache13[:M * top_k_num * N].view(
|
||||
(M, topk_ids.shape[1], N))
|
||||
intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
|
||||
(M, topk_ids.shape[1], w2.shape[1]))
|
||||
intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
|
||||
intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K)
|
||||
|
||||
# This needs separate memory since it's used concurrently with cache1
|
||||
intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
|
||||
@@ -1452,14 +1609,23 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
a1q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
qcurr_hidden_states, a1q_scale = _fp8_quantize(
|
||||
curr_hidden_states, a1_scale, block_shape)
|
||||
else:
|
||||
qcurr_hidden_states = curr_hidden_states
|
||||
a1q_scale = a1_scale
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
|
||||
global_num_experts, expert_map))
|
||||
|
||||
invoke_fused_moe_kernel(curr_hidden_states,
|
||||
invoke_fused_moe_kernel(qcurr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
@@ -1485,10 +1651,19 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
invoke_fused_moe_kernel(intermediate_cache2,
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
if use_fp8_w8a8:
|
||||
qintermediate_cache2, a2q_scale = _fp8_quantize(
|
||||
intermediate_cache2, a2_scale, block_shape)
|
||||
else:
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
a2q_scale = a2_scale
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
@@ -1617,6 +1792,193 @@ def fused_moe(
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def deep_gemm_moe_fp8(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with DeepGemm
    grouped gemm.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1 (torch.Tensor): The first set of fp8 quantized expert weights.
        Shape: [num_experts, K, 2N] (the weights are passed transposed)
    - w2 (torch.Tensor): The second set of fp8 quantized expert weights.
        Shape: [num_experts, N, K] (the weights are passed transposed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts] or [num_experts, 2N]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts] or [num_experts, K]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - activation (str): The activation function to apply after the first
        MoE layer.
    - global_num_experts (int): The total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
    """
    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    assert expert_map is None, "Expert maps not supported yet"

    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    assert w1.dtype == torch.float8_e4m3fn
    assert w2.dtype == torch.float8_e4m3fn
    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert a1_scale is None or a1_scale.dim(
    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
        0] == hidden_states.shape[0], "Input scale shape mismatch"
    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501

    num_tokens, _ = hidden_states.shape
    # NOTE: w1 is [E, N, K'] here per this unpacking; N is w1's output
    # dim (2x the intermediate size) and K is the hidden size from w2.
    E, N, _ = w1.shape
    K = w2.shape[1]
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.shape[1]
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE

    # Callers are expected to have routed here via _valid_deep_gemm;
    # re-check so misuse fails loudly rather than inside the kernel.
    assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    # DeepGemm's contiguous layout fixes both the block size and the
    # quantization group size to the same alignment value.
    block_m = dg.get_m_alignment_for_contiguous_layout()
    block_shape = [block_m, block_m]

    assert w1_scale is not None
    assert w2_scale is not None

    # We attempt to transpose and align offline in Fp8MoEMethod, in which
    # case these calls will be nops.  Otherwise, they'll be performed every
    # time the layer is executed.
    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()

    # Upper bound on the number of padded rows produced by
    # moe_align_block_size (each expert can add up to block_m - 1 pad
    # rows), rounded up to a whole number of blocks.
    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
    M_sum = round_up(M_sum, block_m)

    num_chunks = (num_tokens // CHUNK_SIZE) + 1

    # We can reuse the memory between cache1 and cache3 because by the time
    # we need cache3, we're done with cache1
    cache13 = torch.empty(M_sum * max(N, K),
                          device=hidden_states.device,
                          dtype=hidden_states.dtype)

    intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N)
    intermediate_cache2 = torch.empty((M_sum, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K)

    for chunk in range(num_chunks):
        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                          min((chunk + 1) * CHUNK_SIZE,
                                              num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        # Quantize the chunk's activations to fp8 with per-group scales.
        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
                                                       a1_scale, block_shape)

        # Sort tokens by expert so each expert's rows are contiguous, as
        # required by the contiguous grouped gemm.
        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
                                  curr_topk_ids, global_num_experts,
                                  expert_map, top_k_num, block_m)

        # Adjust the intermediate cache size and config for the last chunk.
        # Note that in most cases we only have one chunk so the cache size
        # and config are already set correctly and do not need to be adjusted.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            curr_M = sorted_token_ids.numel()
            intermediate_cache1 = _resize_cache(intermediate_cache1,
                                                (curr_M, N))
            intermediate_cache2 = _resize_cache(intermediate_cache2,
                                                (curr_M, N // 2))
            intermediate_cache3 = _resize_cache(intermediate_cache3,
                                                (curr_M, K))

        # First grouped gemm: hidden -> 2N intermediate.
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qcurr_hidden_states, a1q_scale), (w1, w1_scale),
            intermediate_cache1, expert_ids)

        # Gated activation halves the width from N to N // 2.
        if activation == "silu":
            torch.ops._C.silu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        a2q_scale: Optional[torch.Tensor] = None

        # Re-quantize the activation output for the second gemm.
        qintermediate_cache2, a2q_scale = _fp8_quantize(
            intermediate_cache2, a2_scale, block_shape)

        # Second grouped gemm: intermediate -> hidden.
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qintermediate_cache2, a2q_scale), (w2, w2_scale),
            intermediate_cache3, expert_ids)

        # Undo the expert sort and reduce over top-k experts directly
        # into the output slice for this chunk.
        _moe_unpermute_and_reduce(
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
            intermediate_cache3.view(*intermediate_cache3.shape), inv_perm,
            top_k_num, K, curr_topk_weights)

    return out_hidden_states
|
||||
|
||||
|
||||
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
|
||||
def cutlass_moe_fp8(
|
||||
a: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user