[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (#25997)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
committed by GitHub
parent 2ed8b6b3d0
commit fb0571b077
@@ -50,7 +50,31 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    # DeepEP low-latency kernels are compiled only for certain
    # specific hidden sizes.
    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168]
    # NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends
    # on it.
    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168, 8192]

    @staticmethod
    def maybe_roundup_layer_hidden_size(hidden_size: int) -> int:
        # Round up hidden size to the closest supported hidden size.
        _supported_hs = DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES

        # Check sorted
        num_supported_hs = len(_supported_hs)
        assert all(
            [
                _supported_hs[i] < _supported_hs[i + 1]
                for i in range(num_supported_hs - 1)
            ]
        )

        for x in _supported_hs:
            if x >= hidden_size:
                return x

        raise ValueError(
            f"Hidden Size {hidden_size} is greater than the "
            f"maximum supported hidden size {_supported_hs[-1]}"
        )

    def __init__(
        self,
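The new maybe_roundup_layer_hidden_size helper simply walks the sorted list and returns the first supported size that is at least as large as the layer's hidden size. Below is a minimal standalone sketch of that lookup (illustrative only, not part of the diff; the example sizes are assumptions):

SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168, 8192]

def roundup_hidden_size(hidden_size: int) -> int:
    # Return the smallest supported hidden size that is >= hidden_size.
    for x in SUPPORTED_HIDDEN_SIZES:
        if x >= hidden_size:
            return x
    raise ValueError(f"{hidden_size} exceeds {SUPPORTED_HIDDEN_SIZES[-1]}")

assert roundup_hidden_size(2880) == 4096  # e.g. a 2880-wide layer would be padded up
assert roundup_hidden_size(7168) == 7168  # already supported, unchanged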
@@ -3,13 +3,16 @@
"""Fused MoE utilities for GPTQ."""

import torch
from typing_extensions import override

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    batched_moe_align_block_size,
    moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
    TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
@@ -21,6 +24,153 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.scalar_type import ScalarType, scalar_types


def _fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: torch.Tensor | None,
    bias2: torch.Tensor | None,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    num_topk: int,
    quant_type: ScalarType,
    apply_router_weight_on_input: bool,
    activation: str,
    expert_map: torch.Tensor | None,
    block_size_m: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    global_scale1: torch.Tensor | None = None,
    global_scale2: torch.Tensor | None = None,
    g_idx1: torch.Tensor | None = None,
    g_idx2: torch.Tensor | None = None,
    sort_indices1: torch.Tensor | None = None,
    sort_indices2: torch.Tensor | None = None,
    w1_zeros: torch.Tensor | None = None,
    w2_zeros: torch.Tensor | None = None,
    workspace: torch.Tensor | None = None,
    intermediate_cache13: torch.Tensor | None = None,
    intermediate_cache2: torch.Tensor | None = None,
    output: torch.Tensor | None = None,
    is_k_full: bool = True,
) -> torch.Tensor:
    assert hidden_states.ndim == 2
    M, K = hidden_states.size()
    N = marlin_moe_intermediate_size(w1, w2)

    if workspace is None:
        workspace = marlin_make_workspace_new(hidden_states.device, 4)

    if intermediate_cache13 is None:
        intermediate_cache13 = torch.empty(
            (M * num_topk * max(2 * N, K),),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

    if intermediate_cache2 is None:
        intermediate_cache2 = torch.empty(
            (M * num_topk, N),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

    intermediate_cache1 = _resize_cache(intermediate_cache13, (M * num_topk, 2 * N))
    intermediate_cache3 = _resize_cache(intermediate_cache13, (M * num_topk, K))
    intermediate_cache2 = _resize_cache(intermediate_cache2, (M * num_topk, N))

    maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
    use_atomic_add = (
        hidden_states.dtype == torch.half
        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
    )

    intermediate_cache1 = ops.moe_wna16_marlin_gemm(
        hidden_states,
        intermediate_cache1,
        w1,
        bias1,
        w1_scale,
        global_scale1,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=num_topk,
        mul_topk_weights=apply_router_weight_on_input,
        is_ep=expert_map is not None,
        b_q_type=quant_type,
        size_m=M,
        size_n=2 * N,
        size_k=K,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

    if activation == "silu":
        torch.ops._C.silu_and_mul(
            intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
        )
    elif activation == "swigluoai":
        # alpha = 1.702, limit = 7.0
        torch.ops._C.swigluoai_and_mul(
            intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
        )
    else:
        raise ValueError(
            f"Unsupported activation: {activation}. "
            "Only silu and swigluoai activations are supported."
        )

    if output is None:
        output = intermediate_cache3

    if expert_map is not None:
        output.zero_()

    output = ops.moe_wna16_marlin_gemm(
        intermediate_cache2,
        output,
        w2,
        bias2,
        w2_scale,
        global_scale2,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=1,
        mul_topk_weights=not apply_router_weight_on_input,
        is_ep=expert_map is not None,
        b_q_type=quant_type,
        size_m=M * num_topk,
        size_n=K,
        size_k=N,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

    return output


def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
@@ -62,23 +212,27 @@ def fused_marlin_moe(
    - w2 (torch.Tensor): The second set of expert weights.
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - gating_output (Optional[torch.Tensor]): The output of the gating
    - gating_output (torch.Tensor|None): The output of the gating
      operation (before softmax).
    - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
    - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input
    - g_idx1 (torch.Tensor|None): The first set of act_order indices.
    - g_idx2 (torch.Tensor|None): The second set of act_order indices.
    - sort_indices1 (torch.Tensor|None): The first act_order input
      permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input
    - sort_indices2 (torch.Tensor|None): The second act_order input
      permutation.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of topk-k elements.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
    - w1_zeros (torch.Tensor|None): Optional zero points to be used for w1.
    - w2_zeros (torch.Tensor|None): Optional zero points to be used for w2.
    - num_bits (bool): The number of bits in expert weights quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """

    if inplace:
        assert output is None, "Conflicting request"

    quant_type = ScalarType.from_id(quant_type_id)
    assert quant_type in [
        scalar_types.uint4,
@@ -95,15 +249,15 @@ def fused_marlin_moe(
    ]
    num_bits = 4 if quant_type in bit4_scalar_types else 8

    M, K = hidden_states.size()
    E = w1.size(0)
    topk = topk_ids.size(1)

    # Check constraints.
    if gating_output is not None:
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch"
    )
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), (
        "Hidden size mismatch w2"
    )
        assert gating_output.size(0) == M, "Number of tokens mismatch"
    assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
    assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
@@ -111,11 +265,6 @@ def fused_marlin_moe(
    assert num_bits in [4, 8]
    assert topk_weights.dtype == torch.float32

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = marlin_moe_intermediate_size(w1, w2)
    topk = topk_ids.shape[1]

    # M block size selection logic
    # TODO: tune this further for specific models
    for block_size_m in [8, 16, 32, 48, 64]:
@@ -128,107 +277,38 @@ def fused_marlin_moe(
        topk_ids, block_size_m, global_num_experts, expert_map
    )

    if workspace is None:
        workspace = marlin_make_workspace_new(hidden_states.device, 4)

    if intermediate_cache2 is None:
        intermediate_cache2 = torch.empty(
            (M * topk, N),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

    if intermediate_cache13 is None:
        intermediate_cache13 = torch.empty(
            (M * topk * max(2 * N, K),),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

    intermediate_cache1 = _resize_cache(intermediate_cache13, (M * topk, 2 * N))
    intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K))
    intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N))

    maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype)
    use_atomic_add = (
        hidden_states.dtype == torch.half
        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
    )

    intermediate_cache1 = ops.moe_wna16_marlin_gemm(
        hidden_states,
        intermediate_cache1,
        w1,
        bias1,
        w1_scale,
        global_scale1,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=topk,
        mul_topk_weights=apply_router_weight_on_input,
        is_ep=expert_map is not None,
        b_q_type=quant_type,
        size_m=M,
        size_n=2 * N,
        size_k=K,
    assert activation is not None
    moe_output = _fused_marlin_moe(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        bias1=bias1,
        bias2=bias2,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        topk_weights=topk_weights,
        num_topk=topk,
        quant_type=quant_type,
        apply_router_weight_on_input=apply_router_weight_on_input,
        activation=activation,
        expert_map=expert_map,
        block_size_m=block_size_m,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=w1_zeros,
        w2_zeros=w2_zeros,
        workspace=workspace,
        intermediate_cache13=intermediate_cache13,
        intermediate_cache2=intermediate_cache2,
        output=None,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

    if activation == "silu":
        torch.ops._C.silu_and_mul(
            intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
        )
    elif activation == "swigluoai":
        # alpha = 1.702, limit = 7.0
        torch.ops._C.swigluoai_and_mul(
            intermediate_cache2, intermediate_cache1.view(-1, 2 * N)
        )
    else:
        raise ValueError(
            f"Unsupported activation: {activation}. "
            "Only silu and swigluoai activations are supported."
        )

    if expert_map is not None:
        intermediate_cache3.zero_()

    intermediate_cache3 = ops.moe_wna16_marlin_gemm(
        intermediate_cache2,
        intermediate_cache3,
        w2,
        bias2,
        w2_scale,
        global_scale2,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=1,
        mul_topk_weights=not apply_router_weight_on_input,
        is_ep=expert_map is not None,
        b_q_type=quant_type,
        size_m=M * topk,
        size_n=K,
        size_k=N,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    ).view(-1, topk, K)

    if output is None:
@@ -237,16 +317,173 @@ def fused_marlin_moe(
    else:
        output = torch.empty_like(hidden_states)

    return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
    return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)


class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def batched_fused_marlin_moe(
    hidden_states: torch.Tensor,
    expert_num_tokens: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: torch.Tensor | None,
    bias2: torch.Tensor | None,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor | None,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: str | None = "silu",
    expert_map: torch.Tensor | None = None,
    global_scale1: torch.Tensor | None = None,
    global_scale2: torch.Tensor | None = None,
    g_idx1: torch.Tensor | None = None,
    g_idx2: torch.Tensor | None = None,
    sort_indices1: torch.Tensor | None = None,
    sort_indices2: torch.Tensor | None = None,
    w1_zeros: torch.Tensor | None = None,
    w2_zeros: torch.Tensor | None = None,
    workspace: torch.Tensor | None = None,
    intermediate_cache13: torch.Tensor | None = None,
    intermediate_cache2: torch.Tensor | None = None,
    is_k_full: bool = True,
    output: torch.Tensor | None = None,
    inplace: bool = False,
) -> torch.Tensor:
    """
    This function massages the inputs so the batched hidden_states can be
    presented as a 2D contiguous tensor that could be used with
    _fused_marlin_moe.

    Note that both batched_fused_marlin_moe and fused_marlin_moe ultimately
    use `ops.moe_wna16_marlin_gemm` for the gemm operation and
    `ops.moe_wna16_marlin_gemm` supports only 2D contiguous hidden_states.
    Note that the moe_align_block_size function indicates,
    - What rows of the A matrix (hidden_states) to access during the
      matmul, via the sorted_ids output.
    - What expert_id to use for each block matmul, via the expert_ids output.

    In the batched version, the tokens are already grouped/batched by the
    experts they subscribe to. Due to this, we can represent the batched
    hidden_states tensor of shape [B, MAX_TOKENS_PER_BATCH, K] as a 2D tensor
    of shape [B * MAX_TOKENS_PER_BATCH, K]. We may treat this as a 2D
    contiguous tensor with topk=1 as each token (row in the tensor) subscribes
    to exactly one expert_id (which is the batch_id). With the
    expert_num_tokens tensor, which indicates how many tokens are actually
    valid in each batch, the batched_moe_align_block_size function constructs
    the sorted_ids and expert_ids tensors, so only relevant/valid rows of A
    (hidden_states) are accessed and are processed with the correct expert_ids.
    """

    assert hidden_states.ndim == 3, (
        f"hidden states must be batched, e.g. [B, MAX_TOKENS, K]. "
        f"But got {hidden_states.size()}"
    )
    if inplace:
        assert output is None, "Conflicting request."

    quant_type = ScalarType.from_id(quant_type_id)
    assert quant_type in [
        scalar_types.uint4,
        scalar_types.uint8b128,
        scalar_types.uint4b8,
        scalar_types.float8_e4m3fn,
        scalar_types.float4_e2m1f,
    ]

    bit4_scalar_types = [
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.float4_e2m1f,
    ]
    num_bits = 4 if quant_type in bit4_scalar_types else 8

    B, BATCH_TOKENS_MAX, K = hidden_states.size()
    M = hidden_states.view(-1, K).size(0)
    E = w1.size(0)

    # Check constraints.
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert expert_num_tokens.size(0) == E
    assert B == E, (
        "Batch must be as big as number of experts as the tokens "
        "are sorted into the batch/expert they belong to"
    )
    assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
    assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert num_bits in [4, 8]

    # Technically, the tokens are already separated by their expert ids.
    # Hidden-States can just be squeezed to have just 2 dimensions,
    # [B * MAX_TOKENS, K] and top_k can be interpreted as just 1.
    topk = 1

    # TODO(varun) : Choose a decent block size like in fused_marlin_moe
    block_size_m = 64

    sorted_token_ids, expert_ids, num_tokens_post_padded = batched_moe_align_block_size(
        max_tokens_per_batch=BATCH_TOKENS_MAX,
        block_size=block_size_m,
        expert_num_tokens=expert_num_tokens,
    )

    if output is None and inplace:
        output = hidden_states

    # TODO (varun): This can be avoided by plumbing the marlin kernel to
    # ignore topk_weights when topk_weights_ptr is a nullptr.
    topk_weights = torch.ones(
        (M, topk), device=hidden_states.device, dtype=torch.float32
    )

    assert activation is not None
    output = _fused_marlin_moe(
        hidden_states=hidden_states.view(-1, K),
        w1=w1,
        w2=w2,
        bias1=bias1,
        bias2=bias2,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        topk_weights=topk_weights,
        num_topk=topk,
        quant_type=quant_type,
        apply_router_weight_on_input=apply_router_weight_on_input,
        activation=activation,
        expert_map=expert_map,
        block_size_m=block_size_m,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=w1_zeros,
        w2_zeros=w2_zeros,
        workspace=workspace,
        intermediate_cache13=intermediate_cache13,
        intermediate_cache2=intermediate_cache2,
        output=output.view(-1, K) if output is not None else output,
        is_k_full=is_k_full,
    )

    output = output.view(B, BATCH_TOKENS_MAX, K)

    return output
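To make the reshaping described in the docstring above concrete, here is a small illustrative sketch (not part of the diff; all sizes are made up) of how the batched [B, MAX_TOKENS_PER_BATCH, K] layout maps onto flattened 2D rows with an implicit top_k of 1:

B, MAX_TOKENS, K = 4, 8, 16          # one batch per (local) expert
expert_num_tokens = [2, 0, 5, 8]     # valid tokens per batch/expert

# Flattened as [B * MAX_TOKENS, K]: row r belongs to expert r // MAX_TOKENS and
# is valid only if (r % MAX_TOKENS) < expert_num_tokens[r // MAX_TOKENS].
valid_rows = [
    b * MAX_TOKENS + t for b in range(B) for t in range(expert_num_tokens[b])
]
assert valid_rows[:2] == [0, 1]                  # batch/expert 0
assert valid_rows[2:7] == [16, 17, 18, 19, 20]   # batch/expert 2 (expert 1 is empty)
# Every valid row maps to exactly one expert, so the 2D view runs with top_k=1,
# and batched_moe_align_block_size only has to mark which rows are real.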
class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(self, quant_config: FusedMoEQuantConfig):
        # TODO (varun) : Enable activation quantization
        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
        super().__init__(quant_config)

    @override
    def moe_problem_size(
        self,
        a1: torch.Tensor,
@@ -274,6 +511,11 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):

        return E, M, N, K, topk


class MarlinExperts(MarlinExpertsBase):
    def __init__(self, quant_config: FusedMoEQuantConfig):
        super().__init__(quant_config)

    def supports_expert_map(self) -> bool:
        return True
@@ -365,3 +607,90 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
            intermediate_cache13=workspace2,
            intermediate_cache2=workspace13,
        )


class BatchedMarlinExperts(MarlinExpertsBase):
    def __init__(
        self,
        max_num_tokens: int,
        num_dispatchers: int,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(quant_config)
        self.max_num_tokens = max_num_tokens
        self.num_dispatchers = num_dispatchers

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceDelegate()

    @property
    def activation_formats(
        self,
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (
            mk.FusedMoEActivationFormat.BatchedExperts,
            mk.FusedMoEActivationFormat.BatchedExperts,
        )

    def supports_chunking(self) -> bool:
        return False

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        num_dispatchers = self.num_dispatchers
        num_experts = local_num_experts
        max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
        workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2))
        workspace2 = (num_experts * max_num_tokens * num_dispatchers, N)
        output = (num_experts, max_num_tokens * num_dispatchers, K)
        return (workspace13, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert expert_tokens_meta is not None, "Num valid tokens per batch is required"
        return batched_fused_marlin_moe(
            hidden_states=hidden_states,
            expert_num_tokens=expert_tokens_meta.expert_num_tokens,
            w1=w1,
            w2=w2,
            bias1=self.w1_bias,
            bias2=self.w2_bias,
            w1_scale=self.w1_scale,
            w2_scale=self.w2_scale,
            gating_output=None,
            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            activation=activation,
            expert_map=expert_map,
            output=output,
            intermediate_cache13=workspace13,
            intermediate_cache2=workspace2,
        )
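As a quick arithmetic check on workspace_shapes above, a hedged sketch with hypothetical sizes (not part of the diff): the flattened batched layout needs num_experts * max_num_tokens * num_dispatchers rows, wide enough for both the 2N-wide first GEMM output and the K-wide second GEMM output.

num_experts, max_num_tokens, num_dispatchers = 32, 64, 2  # hypothetical
N, K = 2880, 2880                                          # hypothetical sizes

rows = num_experts * max_num_tokens * num_dispatchers
workspace13 = (rows, max(K, N * 2))  # backs both intermediate_cache1 and cache3
workspace2 = (rows, N)               # backs the activation output
output = (num_experts, max_num_tokens * num_dispatchers, K)

assert workspace13 == (4096, 5760)
assert workspace2 == (4096, 2880)
assert output == (32, 128, 2880)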
@@ -994,6 +994,11 @@ def maybe_roundup_hidden_size(
        hidden_size, act_dtype
    )

    if moe_parallel_config.use_deepep_ll_kernels:
        hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size
        )

    # we are padding globally so EP buffer allocation works
    if quant_config and quant_config.get_name() == "mxfp4":
        from vllm.model_executor.layers.quantization.mxfp4 import (
@@ -83,3 +83,92 @@ def moe_align_block_size(
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad


def batched_moe_align_block_size(
    max_tokens_per_batch: int, block_size: int, expert_num_tokens: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Given num_batches, max_tokens_per_batch, block_size and the number of
    valid tokens in each batch, prepare sorted_token_ids, expert_ids and
    num_tokens_post_pad. sorted_token_ids, expert_ids and num_tokens_post_pad
    have the same semantics as in moe_align_block_size.

    This function is intended to be a drop-in replacement for
    moe_align_block_size for the batched case.

    Parameters:
    - max_tokens_per_batch (int): Number of tokens in each batch (both
      valid and invalid).
    - block_size (int): block_size to align the data to.
    - expert_num_tokens (torch.Tensor): expert_num_tokens[i] indicates
      the number of valid tokens in batch i.

    Returns:
    - sorted_token_ids (torch.Tensor): Torch tensor of size
      (num_batches * max_tokens_per_batch) indicating the token indices for
      each block.
    - expert_ids (torch.Tensor): Torch tensor of size
      ceil((num_batches * max_tokens_per_batch) / block_size) indicating
      what expert to use for each block.
    - num_tokens_post_pad (torch.Tensor): Torch tensor of size 1
      indicating the number of valid blocks with actual data to
      process. This is represented in terms of num tokens.

    Example:
    Let num_batches=5, max_tokens_per_batch=8, block_size=4, and
    expert_num_tokens=[2, 3, 0, 6, 8]. This expert_num_tokens tensor
    indicates that,
    - The first 2 tokens in the 0th batch are valid and the rest 6 are
      invalid (i.e. in the 2D hidden_states tensor of shape
      [num_batches * max_tokens_per_batch, K], indices 0, 1 are valid)
    - The first 3 tokens in the 1st batch are valid, i.e. indices 8, 9, 10
    - 0 tokens in the 2nd batch are valid
    - The first 6 tokens in the 3rd batch are valid, i.e. indices
      24, 25, 26, 27, 28, 29
    - and so on ...

    In this case,
    sorted_token_ids will be [0, 1, 40, 40,
                              8, 9, 10, 40,
                              24, 25, 26, 27,
                              28, 29, 40, 40,
                              32, 33, 34, 35,
                              36, 37, 38, 39,
                              40, 40, 40, 40,
                              (rest all 40, 40, 40, 40)
                              ...]
    Here, 40 represents an invalid index, as there is no token index 40.
    The gemm kernel using this sorted_token_ids is expected to skip the
    gemm computation when it encounters this invalid index.

    expert_ids will be [0, 1, 3, 3, 4, 4, -1, -1, (rest all -1) ...]
    Here, -1 represents an invalid expert. The gemm kernel using this
    expert_ids is expected to skip the gemm computation when it encounters
    an expert of id -1.

    num_tokens_post_pad will be 24 as sorted_token_ids has valid entries
    only in its first 24 positions.
    """

    B = expert_num_tokens.size(0)
    device = expert_num_tokens.device

    # Round up so each batch can be split to blocks evenly.
    max_num_tokens_padded = B * round_up(max_tokens_per_batch, block_size)

    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
    assert max_num_tokens_padded % block_size == 0
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)

    ops.batched_moe_align_block_size(
        max_tokens_per_batch,
        block_size,
        expert_num_tokens,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )

    return sorted_ids, expert_ids, num_tokens_post_pad
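A pure-Python sketch (not part of the diff) that reproduces the worked example from the docstring above; it mirrors only what the docstring describes, with the hypothetical ref_batched_align helper standing in for the CUDA op:

def ref_batched_align(max_tokens_per_batch, block_size, expert_num_tokens):
    B = len(expert_num_tokens)
    total = B * (-(-max_tokens_per_batch // block_size) * block_size)
    sorted_ids = [B * max_tokens_per_batch] * total  # fill with the invalid index
    expert_ids = [-1] * (total // block_size)
    pos = 0
    for b, n in enumerate(expert_num_tokens):
        n_blocks = -(-n // block_size)               # ceil-div: blocks for this batch
        for i in range(n):
            sorted_ids[pos + i] = b * max_tokens_per_batch + i
        for blk in range(n_blocks):
            expert_ids[pos // block_size + blk] = b
        pos += n_blocks * block_size
    return sorted_ids, expert_ids, pos

sorted_ids, expert_ids, num_tokens_post_pad = ref_batched_align(8, 4, [2, 3, 0, 6, 8])
assert sorted_ids[:8] == [0, 1, 40, 40, 8, 9, 10, 40]
assert expert_ids == [0, 1, 3, 3, 4, 4, -1, -1, -1, -1]
assert num_tokens_post_pad == 24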
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
    ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    BatchedMarlinExperts,
    MarlinExperts,
    fused_marlin_moe,
)
@@ -797,9 +798,19 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            prepare_finalize.activation_format
            == mk.FusedMoEActivationFormat.BatchedExperts
        ):
            raise NotImplementedError(
                "Mxfp4 does not support batched experts format for EP"
            )
            if self.mxfp4_backend == Mxfp4Backend.MARLIN:
                max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
                assert max_num_tokens_per_rank is not None
                assert self.moe_quant_config is not None
                return BatchedMarlinExperts(
                    max_num_tokens=max_num_tokens_per_rank,
                    num_dispatchers=prepare_finalize.num_dispatchers(),
                    quant_config=self.moe_quant_config,
                )
            else:
                raise NotImplementedError(
                    "Incompatible Mxfp4 backend for EP batched experts format"
                )
        else:
            assert self.moe_quant_config is not None
            if (