Expert Parallelism (EP) Support for DeepSeek V2 (#12583)
@@ -20,6 +20,18 @@ from vllm.utils import direct_register_custom_op
 logger = init_logger(__name__)


+@triton.jit
+def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
+                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
+                          compute_type):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 @triton.jit
 def fused_moe_kernel_gptq_awq(
         # Pointers to matrices
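
# Note (illustrative, not part of the diff): how an `expert_map` with the -1
# convention used by this commit might be built per expert-parallel (EP) rank.
# The contiguous sharding and the helper name are assumptions for
# illustration; only the "-1 means not on this rank" convention comes from
# the diff itself.
import torch

def make_expert_map(global_num_experts: int, ep_rank: int,
                    ep_size: int) -> torch.Tensor:
    # Hypothetical contiguous sharding: rank r owns experts
    # [r * per_rank, (r + 1) * per_rank). Every non-local expert maps to -1,
    # which is what lets the kernels detect "not my expert" and write zeros
    # instead of computing.
    per_rank = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    start = ep_rank * per_rank
    expert_map[start:start + per_rank] = torch.arange(per_rank,
                                                      dtype=torch.int32)
    return expert_map

# e.g. 8 experts over 2 ranks: make_expert_map(8, 1, 2)
# -> tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)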
@@ -120,17 +132,26 @@ def fused_moe_kernel_gptq_awq(
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens

+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
+                              offs_token, token_mask, BLOCK_SIZE_M,
+                              BLOCK_SIZE_N, compute_type)
+        return
+
     offs_bn = (pid_n * BLOCK_SIZE_N +
                tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)

-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-
     if use_int4_w4a16:
         b_ptrs = b_ptr + off_experts * stride_be + \
-            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
+            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \
+            stride_bn
         b_shifter = (offs_k[:, None] % 2) * 4
     elif use_int8_w8a16:
         b_ptrs = b_ptr + off_experts * stride_be + \
@@ -170,7 +191,8 @@ def fused_moe_kernel_gptq_awq(

         b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
             offs_bn[None, :] * stride_bsn + \
-            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \
+            stride_bsk
         b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
         b_scale = b_scale.to(tl.float32)
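
# Note (illustrative): why the early-return above writes zeros instead of
# simply skipping the store. Under EP each rank computes only its local
# experts, and the per-rank partial outputs are then summed across ranks
# (typically via an all-reduce). Stale memory in the skipped rows would
# corrupt that sum; explicit zeros keep it exact. A toy stand-in for the
# cross-rank reduction:
import torch

out_rank0 = torch.tensor([[1., 2.], [0., 0.]])  # rank 0 owns token 0's expert
out_rank1 = torch.tensor([[0., 0.], [3., 4.]])  # rank 1 owns token 1's expert
combined = out_rank0 + out_rank1                # stand-in for the all-reduce
assert torch.equal(combined, torch.tensor([[1., 2.], [3., 4.]]))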
@@ -319,13 +341,22 @@ def fused_moe_kernel(
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens

+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
+                              offs_token, token_mask, BLOCK_SIZE_M,
+                              BLOCK_SIZE_N, compute_type)
+        return
+
     offs_bn = (pid_n * BLOCK_SIZE_N +
                tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)

-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
     if use_int8_w8a16:
@@ -349,7 +380,6 @@ def fused_moe_kernel(
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
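
# Note (illustrative): the fp32 `accumulator` above exists because repeated
# low-precision additions lose mantissa bits. A quick demonstration with
# made-up values:
import torch

step = torch.tensor(1e-4, dtype=torch.bfloat16)
acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
for _ in range(4096):
    acc_bf16 += step           # each partial sum rounded to 8 mantissa bits
    acc_fp32 += step.float()   # full-precision running sum
# acc_fp32 lands near 0.41; acc_bf16 drifts visibly once the running sum
# dwarfs the step, which is why the kernel accumulates in fp32 and only
# converts back to fp16/bf16 after the loop.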
@@ -544,8 +574,11 @@ def moe_align_block_size_triton(


 def moe_align_block_size(
-        topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    topk_ids: torch.Tensor,
+    block_size: int,
+    num_experts: int,
+    expert_map: torch.Tensor = None
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -555,6 +588,10 @@ def moe_align_block_size(
         top-k expert indices for each token.
     - block_size: The block size used in block matrix multiplication.
     - num_experts: The total number of experts.
+    - expert_map: A tensor of shape [num_experts] that maps the expert index
+        from the global space to the local index space of the current
+        expert parallel shard. If the expert is not in the current expert
+        parallel shard, the mapping is set to -1.

     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
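
# Note (illustrative call, values made up; assumes a CUDA device since the
# alignment kernel runs on the GPU): with 8 global experts sharded 2-ways,
# rank 0 keeps experts 0-3.
import torch

topk_ids = torch.tensor([[0, 5], [2, 7]], dtype=torch.int32,
                        device="cuda")  # 2 tokens, top-2 routing
expert_map = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1],
                          dtype=torch.int32, device="cuda")
sorted_ids, expert_ids, n_post_pad = moe_align_block_size(
    topk_ids, block_size=16, num_experts=8, expert_map=expert_map)
# Blocks whose global expert is 4..7 come back with expert_ids == -1, so
# the kernels above zero their output rather than computing it here.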
@@ -589,7 +626,9 @@ def moe_align_block_size(
                              device=topk_ids.device)
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
+    # Expert ids must be zeroed out to prevent index out of bounds error while
+    # mapping global expert ids to local expert ids in expert parallelism.
+    expert_ids = torch.zeros((max_num_m_blocks, ),
                              dtype=torch.int32,
                              device=topk_ids.device)
     num_tokens_post_pad = torch.empty((1),
@@ -618,6 +657,9 @@ def moe_align_block_size(
     else:
         ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                  expert_ids, num_tokens_post_pad)
+    if expert_map is not None:
+        expert_ids = expert_map[expert_ids]
+
     return sorted_ids, expert_ids, num_tokens_post_pad
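
# Note (illustrative): the remap above is a plain gather. It is also why
# `expert_ids` is now zero-initialized: padded blocks hold 0, an always-valid
# index into expert_map, whereas uninitialized values from torch.empty could
# index out of bounds. Toy values:
import torch

expert_map = torch.tensor([-1, -1, 0, 1], dtype=torch.int32)  # rank owns 2, 3
expert_ids = torch.tensor([2, 3, 0, 0])  # trailing zeros from padded blocks
print(expert_map[expert_ids])  # tensor([ 0,  1, -1, -1], dtype=torch.int32)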
@@ -1001,6 +1043,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a16: bool = False,
                           use_int4_w4a16: bool = False,
+                          global_num_experts: int = -1,
+                          expert_map: Optional[torch.Tensor] = None,
                           w1_scale: Optional[torch.Tensor] = None,
                           w2_scale: Optional[torch.Tensor] = None,
                           w1_zp: Optional[torch.Tensor] = None,
@@ -1009,8 +1053,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[List[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
-                       w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
+                       global_num_experts, expert_map, w1_scale, w2_scale,
+                       w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


 def inplace_fused_experts_fake(
@@ -1022,6 +1067,8 @@ def inplace_fused_experts_fake(
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
@@ -1049,6 +1096,8 @@ def outplace_fused_experts(
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
@@ -1058,8 +1107,9 @@ def outplace_fused_experts(
         block_shape: Optional[List[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
                               False, use_fp8_w8a8, use_int8_w8a16,
-                              use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
-                              a1_scale, a2_scale, block_shape)
+                              use_int4_w4a16, global_num_experts, expert_map,
+                              w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
+                              a2_scale, block_shape)


 def outplace_fused_experts_fake(
@@ -1071,6 +1121,8 @@ def outplace_fused_experts_fake(
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
@@ -1098,26 +1150,27 @@ def fused_experts(hidden_states: torch.Tensor,
                   use_fp8_w8a8: bool = False,
                   use_int8_w8a16: bool = False,
                   use_int4_w4a16: bool = False,
+                  global_num_experts: int = -1,
+                  expert_map: Optional[torch.Tensor] = None,
                   w1_scale: Optional[torch.Tensor] = None,
                   w2_scale: Optional[torch.Tensor] = None,
                   w1_zp: Optional[torch.Tensor] = None,
                   w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
-                  block_shape: Optional[List[int]] = None):
+                  block_shape: Optional[List[int]] = None) -> torch.Tensor:

     if inplace:
-        torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
-                                             topk_weights, topk_ids,
-                                             use_fp8_w8a8, use_int8_w8a16,
-                                             use_int4_w4a16, w1_scale,
-                                             w2_scale, w1_zp, w2_zp, a1_scale,
-                                             a2_scale, block_shape)
+        torch.ops.vllm.inplace_fused_experts(
+            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
+            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
+            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
         return hidden_states
     else:
         return torch.ops.vllm.outplace_fused_experts(
             hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
-            use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
-            a1_scale, a2_scale, block_shape)
+            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
+            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


 def fused_experts_impl(hidden_states: torch.Tensor,
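
# Note on the dispatch in fused_experts above (the heavy lifting lands in
# fused_experts_impl below): with inplace=True the custom op mutates
# hidden_states in place and fused_experts hands back that same tensor, while
# inplace=False returns a freshly allocated output; the new `-> torch.Tensor`
# annotation makes that contract explicit. Hedged usage sketch (the expert
# count and variable names are illustrative only):
#
#     out = fused_experts(x, w1, w2, topk_weights, topk_ids, inplace=True,
#                         global_num_experts=160, expert_map=expert_map)
#     assert out is x  # same tensor object when inplace=True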
@@ -1129,6 +1182,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        use_fp8_w8a8: bool = False,
                        use_int8_w8a16: bool = False,
                        use_int4_w4a16: bool = False,
+                       global_num_experts: int = -1,
+                       expert_map: Optional[torch.Tensor] = None,
                        w1_scale: Optional[torch.Tensor] = None,
                        w2_scale: Optional[torch.Tensor] = None,
                        w1_zp: Optional[torch.Tensor] = None,
@@ -1153,6 +1208,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,

     num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    if global_num_experts == -1:
+        global_num_experts = E
+    top_k_num = topk_ids.shape[1]
     # We execute the fused_moe kernel in chunks to circumvent this issue:
     # https://github.com/vllm-project/vllm/issues/5938
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
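
# Note (illustrative numbers): E is the count of experts held locally (w1 is
# this rank's shard), while global_num_experts spans all EP ranks; the -1
# default preserves the non-EP behaviour (global == local). Assuming a 2-way
# EP split of a 160-routed-expert model such as DeepSeek-V2:
#
#     w1.shape == (80, N, K)      ->  E == 80 local experts on this rank
#     global_num_experts == 160   ->  topk_ids still index 0..159
#     expert_map: 0..159 -> 0..79 locally, or -1 when not on this rank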
@@ -1166,20 +1224,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         try_get_optimal_moe_config,
         w1.shape,
         w2.shape,
-        topk_ids.shape[1],
+        top_k_num,
         config_dtype,
         block_shape=block_shape,
     )

     config = get_config_func(M)

-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
+    intermediate_cache1 = torch.empty((M, top_k_num, N),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
+    intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
+    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)

@@ -1221,7 +1279,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

         sorted_token_ids, expert_ids, num_tokens_post_padded = (
-            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
+            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
+                                 global_num_experts, expert_map))

         invoke_fused_moe_kernel(curr_hidden_states,
                                 w1,
@@ -1235,7 +1294,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 expert_ids,
                                 num_tokens_post_padded,
                                 False,
-                                topk_ids.shape[1],
+                                top_k_num,
                                 config,
                                 compute_type=compute_type,
                                 use_fp8_w8a8=use_fp8_w8a8,
@@ -1286,6 +1345,8 @@ def fused_moe(
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1320,6 +1381,11 @@ def fused_moe(
     - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
         activation to compute the inner products for w1 and w2.
         Defaults to False.
+    - global_num_experts (int): The total number of experts in the global
+        expert space.
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+        from the global expert space to the local expert space of the expert
+        parallel shard.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1334,8 +1400,6 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
-    # Check constraints.
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
@@ -1358,6 +1422,8 @@ def fused_moe(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
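
# Note: an end-to-end hedged sketch of calling fused_moe under expert
# parallelism. Shapes, dtypes and the sharding are illustrative assumptions,
# not taken from this diff; it also shows why the "Number of experts
# mismatch" assert was removed above: gating_output now carries
# global-expert logits while w1/w2 hold only the local shard.
import torch

num_tokens, hidden, inter = 4, 128, 256
global_E, local_E, topk = 8, 4, 2  # rank 0 of a hypothetical 2-way EP split
device, dtype = "cuda", torch.float16
hidden_states = torch.randn(num_tokens, hidden, device=device, dtype=dtype)
w1 = torch.randn(local_E, 2 * inter, hidden, device=device, dtype=dtype)
w2 = torch.randn(local_E, hidden, inter, device=device, dtype=dtype)
gating_output = torch.randn(num_tokens, global_E, device=device, dtype=dtype)
expert_map = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1],
                          dtype=torch.int32, device=device)

out = fused_moe(hidden_states, w1, w2, gating_output, topk,
                renormalize=True, global_num_experts=global_E,
                expert_map=expert_map)
# Contributions from experts 4..7 are zero on this rank and get filled in
# when the EP group reduces the partial outputs across ranks.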