[Kernel] add triton fused moe kernel for gptq/awq (#12185)
This commit is contained in:
@@ -19,6 +19,206 @@ from vllm.utils import direct_register_custom_op
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel_gptq_awq(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
b_scale_ptr,
|
||||
b_zp_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
# Matrix dimensions
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
EM,
|
||||
num_valid_tokens,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
stride_bze,
|
||||
stride_bzk,
|
||||
stride_bzn,
|
||||
block_k_diviable: tl.constexpr,
|
||||
group_size: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
has_zp: tl.constexpr,
|
||||
use_int4_w4a16: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr):
|
||||
"""
|
||||
Implements the fused computation for a Mixture of Experts (MOE) using
|
||||
token and expert matrices.
|
||||
|
||||
Key Parameters:
|
||||
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
||||
be any shape representing batches and K is the feature dimension of
|
||||
each token.
|
||||
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
||||
the number of experts, K is the input feature dimension, and N is
|
||||
the output feature dimension.
|
||||
- C: The output cache tensor with shape (M, topk, N), where M is the
|
||||
total number of tokens post padding, topk is the number of times
|
||||
each token is repeated, and N is the output feature dimension.
|
||||
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
||||
repeated topk times and arranged by the expert index they are
|
||||
assigned to.
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse.
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
|
||||
tl.int64)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N +
|
||||
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
||||
offs_k[None, :] * stride_ak)
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
|
||||
|
||||
if use_int4_w4a16:
|
||||
b_ptrs = b_ptr + off_experts * stride_be + \
|
||||
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
|
||||
b_shifter = (offs_k[:, None] % 2) * 4
|
||||
elif use_int8_w8a16:
|
||||
b_ptrs = b_ptr + off_experts * stride_be + \
|
||||
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
|
||||
|
||||
if not has_zp and use_int4_w4a16:
|
||||
b_zp_num = 8
|
||||
if not has_zp and use_int8_w8a16:
|
||||
b_zp_num = 128
|
||||
elif has_zp and use_int4_w4a16:
|
||||
b_zp_shifter = (offs_bn[None, :] % 2) * 4
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop.
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the
|
||||
# K dimension.
|
||||
|
||||
if not block_k_diviable:
|
||||
k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
|
||||
k_other = 0.0
|
||||
else:
|
||||
k_mask = None
|
||||
k_other = None
|
||||
|
||||
a = tl.load(a_ptrs,
|
||||
mask=token_mask[:, None] &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs)
|
||||
if use_int4_w4a16:
|
||||
b = (b >> b_shifter) & 0xF
|
||||
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
|
||||
offs_bn[None, :] * stride_bsn + \
|
||||
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
|
||||
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
|
||||
b_scale = b_scale.to(tl.float32)
|
||||
|
||||
if has_zp and use_int4_w4a16:
|
||||
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
|
||||
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
|
||||
(offs_bn[None, :] // 2) * stride_bzn + \
|
||||
offs_k_true * stride_bzk
|
||||
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
|
||||
b_zp = ((b_zp >> b_zp_shifter) & 0xF)
|
||||
b_zp = b_zp.to(tl.float32)
|
||||
elif has_zp and use_int8_w8a16:
|
||||
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
|
||||
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
|
||||
offs_bn[None, :] * stride_bzn + \
|
||||
offs_k_true * stride_bzk
|
||||
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
|
||||
b_zp = b_zp.to(tl.float32)
|
||||
|
||||
# We accumulate along the K dimension.
|
||||
if has_zp:
|
||||
b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
|
||||
else:
|
||||
b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
if use_int4_w4a16:
|
||||
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
|
||||
else:
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
|
||||
accumulator = accumulator.to(compute_type)
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
||||
None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel(
|
||||
# Pointers to matrices
|
||||
@@ -266,6 +466,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
@@ -277,6 +478,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
@@ -292,50 +494,108 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a16:
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
||||
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
|
||||
EM = sorted_token_ids.shape[0]
|
||||
if A.shape[0] < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique, so
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.shape[0],
|
||||
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
|
||||
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
B.shape[1], META['BLOCK_SIZE_N']), )
|
||||
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
B.shape[2],
|
||||
sorted_token_ids.shape[0],
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
if (use_int8_w8a16 or use_int4_w4a16) and \
|
||||
block_shape is not None and block_shape[1] > 0:
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
A.shape[1],
|
||||
EM,
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
B_scale.stride(0),
|
||||
B_scale.stride(2),
|
||||
B_scale.stride(1),
|
||||
B_zp.stride(0) if B_zp is not None else 0,
|
||||
B_zp.stride(2) if B_zp is not None else 0,
|
||||
B_zp.stride(1) if B_zp is not None else 0,
|
||||
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
|
||||
group_size=block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
has_zp=B_zp is not None,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
|
||||
else:
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
A.shape[1],
|
||||
EM,
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0)
|
||||
if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1)
|
||||
if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0)
|
||||
if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2)
|
||||
if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1)
|
||||
if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
||||
@@ -432,7 +692,7 @@ def try_get_optimal_moe_config(
|
||||
# NOTE: For block-wise quant,
|
||||
# BLOCK_K must be divisible by block_shape[1]
|
||||
# BLOCK_N and BLOCK_M has no requirements
|
||||
if block_shape is not None:
|
||||
if block_shape is not None and block_shape[0] != 0:
|
||||
config["BLOCK_SIZE_N"] = block_shape[0]
|
||||
config["BLOCK_SIZE_K"] = block_shape[1]
|
||||
return config
|
||||
@@ -531,12 +791,15 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||
|
||||
|
||||
def get_config_dtype_str(dtype: torch.dtype,
|
||||
use_int4_w4a16: Optional[bool] = False,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False):
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif use_int4_w4a16:
|
||||
return "int4_w8a16"
|
||||
elif dtype == torch.float:
|
||||
# avoiding cases where kernel fails when float32 MoE
|
||||
# use fp16/bfloat16 configs
|
||||
@@ -551,14 +814,17 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
@@ -569,8 +835,11 @@ def inplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
@@ -593,14 +862,18 @@ def outplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
|
||||
w2_scale, a1_scale, a2_scale, block_shape)
|
||||
False, use_fp8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@@ -611,8 +884,11 @@ def outplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
@@ -635,8 +911,11 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None):
|
||||
@@ -644,16 +923,15 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
|
||||
topk_weights, topk_ids,
|
||||
use_fp8_w8a8, use_int8_w8a16,
|
||||
w1_scale, w2_scale, a1_scale,
|
||||
use_int4_w4a16, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
return hidden_states
|
||||
else:
|
||||
return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2,
|
||||
topk_weights, topk_ids,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16, w1_scale,
|
||||
w2_scale, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
return torch.ops.vllm.outplace_fused_experts(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
|
||||
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
@@ -664,13 +942,21 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None):
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.shape[1] // 2 == w1.shape[
|
||||
2], "Hidden size mismatch"
|
||||
else:
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
@@ -687,6 +973,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
@@ -755,6 +1042,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
@@ -766,6 +1054,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape)
|
||||
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
@@ -776,6 +1065,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
@@ -787,6 +1077,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape)
|
||||
|
||||
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
@@ -808,8 +1099,11 @@ def fused_moe(
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
@@ -834,8 +1128,12 @@ def fused_moe(
|
||||
note: Deepseekv2 model uses grouped_topk
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
@@ -873,8 +1171,11 @@ def fused_moe(
|
||||
inplace=inplace,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
Reference in New Issue
Block a user