[MoE Refactor] Split invoke_fused_moe_kernel (#31050)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -541,7 +541,263 @@ def fused_moe_kernel(
|
|||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
def invoke_fused_moe_kernel(
|
# NOTE(zyongye): we can remove all the wna16 kernel
|
||||||
|
# once we drop sm75 support
|
||||||
|
def invoke_fused_moe_wna16_cuda_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    block_shape: list[int],
):
    """Run the wna16 fused-MoE GEMM through the ``ops.moe_wna16_gemm`` op.

    The caller's ``config`` is shallow-copied, completed with whatever
    ``get_moe_wna16_block_config`` returns, and the resulting
    ``BLOCK_SIZE_M`` / ``BLOCK_SIZE_N`` / ``BLOCK_SIZE_K`` entries are
    forwarded to the op. The op is always invoked with a bit width of 4.

    Args:
        A: Input activations; ``A.size(0) * top_k`` is used as the number
            of valid tokens and ``A.size(1)`` as ``size_k``.
        B: Quantized expert weights; ``B.size(1)`` is used as ``size_n``.
        C: Output tensor handed to the op.
        B_scale: Weight scales. Must be a 3-D tensor.
        B_zp: Optional weight zero points. Must be 3-D when given.
        topk_weights: Routing weights; only forwarded to the op when
            ``mul_routed_weight`` is true, otherwise ``None`` is passed.
        sorted_token_ids: Forwarded to the op unchanged.
        expert_ids: Forwarded to the op unchanged.
        num_tokens_post_padded: Forwarded to the op unchanged.
        mul_routed_weight: Whether ``topk_weights`` is handed to the op.
        top_k: Multiplier applied to ``A.size(0)`` and forwarded to the op.
        config: Launch configuration; must contain ``BLOCK_SIZE_M``. It is
            not mutated.
        block_shape: Two-element list; element 0 must be 0 and element 1 is
            passed on as ``group_size``.

    Raises:
        AssertionError: If any of the argument checks below fails.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    # block_shape[1] is read unconditionally below, so None cannot be
    # accepted here (the previous check let None through and the failure
    # surfaced later as an opaque TypeError).
    assert block_shape is not None and block_shape[0] == 0

    num_tokens = A.size(0) * top_k
    bit = 4

    # Work on a copy so the caller's dict is left untouched.
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=True,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # NOTE(review): this passes the same dimension as size_n; other
            # code in this file derives the expert count from B.size(0).
            # Confirm which is intended before changing it, since it may
            # influence the block-size heuristic.
            num_experts=B.size(1),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    # Note the argument order: the output tensor C precedes the weights B.
    ops.moe_wna16_gemm(
        A,
        C,
        B,
        B_scale,
        B_zp,
        topk_weights if mul_routed_weight else None,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        top_k,
        config["BLOCK_SIZE_M"],
        config["BLOCK_SIZE_N"],
        config["BLOCK_SIZE_K"],
        bit,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(zyongye): we can remove all the wna16 kernel
|
||||||
|
# once we drop sm75 support
|
||||||
|
def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int],
):
    """Launch the ``fused_moe_kernel_gptq_awq`` Triton kernel for wna16 MoE.

    The caller's ``config`` is shallow-copied, completed with whatever
    ``get_moe_wna16_block_config`` returns, and then splatted into the
    kernel launch as keyword arguments.

    Args:
        A: Input activations; ``A.size(0) * top_k`` is used as the number
            of valid tokens and ``A.size(1)`` is passed to the kernel.
        B: Quantized expert weights; ``B.size(1)`` is passed to the kernel
            and also sizes the launch grid.
        C: Output tensor; its strides along dims 1 and 2 are passed.
        B_scale: Weight scales. Must be a 3-D tensor.
        B_zp: Optional weight zero points. Must be 3-D when given; when
            ``None``, zero strides are passed and ``has_zp`` is false.
        topk_weights: Forwarded to the kernel unchanged.
        sorted_token_ids: Forwarded to the kernel; its length bounds ``EM``.
        expert_ids: Forwarded to the kernel unchanged.
        num_tokens_post_padded: Forwarded to the kernel unchanged.
        mul_routed_weight: Passed as ``MUL_ROUTED_WEIGHT``.
        top_k: Passed as ``top_k`` and used to compute the token count.
        config: Launch configuration; must contain ``BLOCK_SIZE_M`` and,
            after the block-config update, ``BLOCK_SIZE_K``. Not mutated.
        compute_type: Passed to the kernel as ``compute_type``.
        use_int8_w8a16: Passed to the kernel as ``use_int8_w8a16``.
        use_int4_w4a16: Passed to the kernel as ``use_int4_w4a16``.
        block_shape: Two-element list; element 0 must be 0 and element 1 is
            passed on as ``group_size``.

    Raises:
        AssertionError: If any of the argument checks below fails.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    # block_shape[1] is read unconditionally below, so None cannot be
    # accepted here (the previous check let None through and the failure
    # surfaced later as an opaque TypeError).
    assert block_shape is not None and block_shape[0] == 0

    num_tokens = A.size(0) * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])

    # One-dimensional launch grid; evaluated with the launch meta-parameters.
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )

    # Work on a copy so the caller's dict is left untouched.
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # NOTE(review): this passes the same dimension as size_n; other
            # code in this file derives the expert count from B.size(0).
            # Confirm which is intended before changing it.
            num_experts=B.size(1),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    # Zero strides stand in for the zero-point tensor when it is absent.
    has_zp = B_zp is not None
    zp_stride_0 = B_zp.stride(0) if has_zp else 0
    zp_stride_2 = B_zp.stride(2) if has_zp else 0
    zp_stride_1 = B_zp.stride(1) if has_zp else 0

    # For B, B_scale and B_zp the dim-2 stride is passed before the dim-1
    # stride; keep this order in sync with the kernel's parameter list.
    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        zp_stride_0,
        zp_stride_2,
        zp_stride_1,
        # Keyword spelling is kept as-is; it has to match the kernel's
        # parameter name.
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=has_zp,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: torch.Tensor | None,
    B_scale: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: list[int] | None = None,
    B_bias: torch.Tensor | None = None,
):
    """Validate the arguments and launch the ``fused_moe_kernel`` Triton kernel.

    The caller's ``config`` is shallow-copied; ``SPLIT_K`` is forced to 1 and
    ``BLOCK_SIZE_K`` is taken out of the copy (and capped by both entries of
    ``block_shape`` when one is given) so it can be passed explicitly. The
    remaining entries are splatted into the launch as keyword arguments.

    Args:
        A: Input activations; ``A.size(0) * top_k`` is the token count.
        B: Expert weights; ``B.size(1)`` and ``B.size(2)`` are passed to the
            kernel and ``B.size(1)`` also sizes the launch grid.
        C: Output tensor; its strides along dims 1 and 2 are passed.
        A_scale: Optional activation scales; strides are only passed when
            it is 2-D, otherwise zeros are used.
        B_scale: Optional weight scales; required for every quantized mode
            and must be ``None`` when no quantization flag is set.
        topk_weights: Routing weights; required when ``mul_routed_weight``
            is true and must have unit stride along dim 1 when given.
        sorted_token_ids: Must have unit stride; its length bounds ``EM``.
        expert_ids: Forwarded to the kernel unchanged.
        num_tokens_post_padded: Forwarded to the kernel unchanged.
        mul_routed_weight: Passed as ``MUL_ROUTED_WEIGHT``.
        top_k: Passed as ``top_k`` and used to compute the token count.
        config: Launch configuration; must contain ``BLOCK_SIZE_M`` and
            ``BLOCK_SIZE_K``. Not mutated.
        compute_type: Passed to the kernel as ``compute_type``.
        use_fp8_w8a8: Quantization flag, forwarded to the kernel.
        use_int8_w8a8: Quantization flag, forwarded to the kernel.
        use_int8_w8a16: Quantization flag, forwarded to the kernel.
        use_int4_w4a16: Only used for argument validation here.
        per_channel_quant: Forwarded to the kernel.
        block_shape: Optional two-element list; zeros are passed to the
            kernel when it is ``None``.
        B_bias: Optional bias; zero strides are passed when it is ``None``.

    Raises:
        AssertionError: If any of the argument checks below fails.
    """
    # --- argument validation -------------------------------------------
    assert not (mul_routed_weight and topk_weights is None)
    if topk_weights is not None:
        assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        if block_shape is not None:
            # The scale tensor must hold one entry per weight block.
            assert triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)
            assert triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        if block_shape is not None:
            assert block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    # --- problem size and launch grid ----------------------------------
    num_rows = A.size(0)
    total_tokens = num_rows * top_k

    block_m = config["BLOCK_SIZE_M"]
    padded_len = sorted_token_ids.size(0)
    if num_rows < block_m:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        padded_len = min(padded_len, num_rows * top_k * block_m)

    def grid(META):
        m_blocks = triton.cdiv(padded_len, META["BLOCK_SIZE_M"])
        n_blocks = triton.cdiv(B.size(1), META["BLOCK_SIZE_N"])
        return (m_blocks * n_blocks,)

    has_bias = B_bias is not None

    # --- launch configuration (copy keeps the caller's dict intact) -----
    launch_config = config.copy()
    launch_config["SPLIT_K"] = 1
    block_k = launch_config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        block_k = min(block_k, block_shape[0], block_shape[1])

    # --- optional-tensor strides, zero when the tensor does not apply ---
    a_scale_stride_0 = a_scale_stride_1 = 0
    if A_scale is not None and A_scale.ndim == 2:
        a_scale_stride_0 = A_scale.stride(0)
        a_scale_stride_1 = A_scale.stride(1)

    b_scale_stride_0 = b_scale_stride_1 = b_scale_stride_2 = 0
    if B_scale is not None:
        if B_scale.ndim >= 2:
            b_scale_stride_0 = B_scale.stride(0)
            b_scale_stride_1 = B_scale.stride(1)
        if B_scale.ndim == 3:
            b_scale_stride_2 = B_scale.stride(2)

    bias_stride_0 = bias_stride_1 = 0
    if has_bias:
        bias_stride_0 = B_bias.stride(0)
        bias_stride_1 = B_bias.stride(1)

    if block_shape is None:
        block_dim_0, block_dim_1 = 0, 0
    else:
        block_dim_0, block_dim_1 = block_shape[0], block_shape[1]

    # For B and B_scale the dim-2 stride is passed before the dim-1 stride.
    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        B.size(2),
        padded_len,
        total_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        a_scale_stride_0,
        a_scale_stride_1,
        b_scale_stride_0,
        b_scale_stride_2,
        b_scale_stride_1,
        bias_stride_0,
        bias_stride_1,
        block_dim_0,
        block_dim_1,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        HAS_BIAS=has_bias,
        BLOCK_SIZE_K=block_k,
        **launch_config,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_fused_moe_kernel(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
C: torch.Tensor,
|
C: torch.Tensor,
|
||||||
@@ -568,44 +824,13 @@ def invoke_fused_moe_kernel(
|
|||||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
|
|
||||||
if use_fp8_w8a8 or use_int8_w8a8:
|
|
||||||
assert B_scale is not None
|
|
||||||
assert block_shape is None or triton.cdiv(
|
|
||||||
B.size(-2), block_shape[0]
|
|
||||||
) == B_scale.size(-2)
|
|
||||||
assert block_shape is None or triton.cdiv(
|
|
||||||
B.size(-1), block_shape[1]
|
|
||||||
) == B_scale.size(-1)
|
|
||||||
|
|
||||||
elif use_int8_w8a16 or use_int4_w4a16:
|
|
||||||
assert B_scale is not None
|
|
||||||
assert block_shape is None or block_shape[0] == 0
|
|
||||||
else:
|
|
||||||
assert A_scale is None
|
|
||||||
assert B_scale is None
|
|
||||||
|
|
||||||
M = A.size(0)
|
M = A.size(0)
|
||||||
num_tokens = M * top_k
|
num_tokens = M * top_k
|
||||||
|
|
||||||
EM = sorted_token_ids.size(0)
|
if (use_int8_w8a16 or use_int4_w4a16) and (
|
||||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
block_shape is not None and block_shape[1] > 0
|
||||||
# optimize for small batch_size.
|
|
||||||
# We assume that top_ids of each token is unique,
|
|
||||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
|
||||||
# and we can skip some invalid blocks.
|
|
||||||
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
|
|
||||||
grid = lambda META: (
|
|
||||||
triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
|
||||||
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
|
|
||||||
)
|
|
||||||
HAS_BIAS = B_bias is not None
|
|
||||||
if (
|
|
||||||
(use_int8_w8a16 or use_int4_w4a16)
|
|
||||||
and block_shape is not None
|
|
||||||
and block_shape[1] > 0
|
|
||||||
):
|
):
|
||||||
assert B_scale is not None and B_scale.ndim == 3
|
assert B_bias is None
|
||||||
assert B_zp is None or B_zp.ndim == 3
|
|
||||||
|
|
||||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
||||||
num_valid_tokens=num_tokens,
|
num_valid_tokens=num_tokens,
|
||||||
@@ -613,41 +838,25 @@ def invoke_fused_moe_kernel(
|
|||||||
num_experts=B.size(0),
|
num_experts=B.size(0),
|
||||||
bit=4 if use_int4_w4a16 else 8,
|
bit=4 if use_int4_w4a16 else 8,
|
||||||
)
|
)
|
||||||
config = config.copy()
|
|
||||||
config.update(
|
|
||||||
get_moe_wna16_block_config(
|
|
||||||
config=config,
|
|
||||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
|
||||||
num_valid_tokens=num_tokens,
|
|
||||||
size_k=A.size(1),
|
|
||||||
size_n=B.size(1),
|
|
||||||
num_experts=B.size(1),
|
|
||||||
group_size=block_shape[1],
|
|
||||||
real_top_k=top_k,
|
|
||||||
block_size_m=config["BLOCK_SIZE_M"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_moe_wna16_cuda:
|
if use_moe_wna16_cuda:
|
||||||
bit = 4 if use_int4_w4a16 else 8
|
invoke_fused_moe_wna16_cuda_kernel(
|
||||||
ops.moe_wna16_gemm(
|
|
||||||
A,
|
A,
|
||||||
C,
|
|
||||||
B,
|
B,
|
||||||
|
C,
|
||||||
B_scale,
|
B_scale,
|
||||||
B_zp,
|
B_zp,
|
||||||
topk_weights if mul_routed_weight else None,
|
topk_weights,
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
|
mul_routed_weight,
|
||||||
top_k,
|
top_k,
|
||||||
config["BLOCK_SIZE_M"],
|
config,
|
||||||
config["BLOCK_SIZE_N"],
|
block_shape,
|
||||||
config["BLOCK_SIZE_K"],
|
|
||||||
bit,
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
fused_moe_kernel_gptq_awq[grid](
|
invoke_fused_moe_wna16_triton_kernel(
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
@@ -657,80 +866,37 @@ def invoke_fused_moe_kernel(
|
|||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
B.size(1),
|
mul_routed_weight,
|
||||||
A.size(1),
|
top_k,
|
||||||
EM,
|
config,
|
||||||
num_tokens,
|
compute_type,
|
||||||
A.stride(0),
|
use_int8_w8a16,
|
||||||
A.stride(1),
|
use_int4_w4a16,
|
||||||
B.stride(0),
|
block_shape,
|
||||||
B.stride(2),
|
|
||||||
B.stride(1),
|
|
||||||
C.stride(1),
|
|
||||||
C.stride(2),
|
|
||||||
B_scale.stride(0),
|
|
||||||
B_scale.stride(2),
|
|
||||||
B_scale.stride(1),
|
|
||||||
B_zp.stride(0) if B_zp is not None else 0,
|
|
||||||
B_zp.stride(2) if B_zp is not None else 0,
|
|
||||||
B_zp.stride(1) if B_zp is not None else 0,
|
|
||||||
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
|
|
||||||
group_size=block_shape[1],
|
|
||||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
|
||||||
top_k=top_k,
|
|
||||||
compute_type=compute_type,
|
|
||||||
has_zp=B_zp is not None,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
**config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
config = config.copy()
|
invoke_fused_moe_triton_kernel(
|
||||||
config["SPLIT_K"] = 1
|
|
||||||
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
|
|
||||||
if block_shape is not None:
|
|
||||||
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
|
|
||||||
fused_moe_kernel[grid](
|
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
B_bias,
|
|
||||||
A_scale,
|
A_scale,
|
||||||
B_scale,
|
B_scale,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_padded,
|
num_tokens_post_padded,
|
||||||
B.size(1),
|
mul_routed_weight,
|
||||||
B.size(2),
|
top_k,
|
||||||
EM,
|
config,
|
||||||
num_tokens,
|
compute_type,
|
||||||
A.stride(0),
|
use_fp8_w8a8,
|
||||||
A.stride(1),
|
use_int8_w8a8,
|
||||||
B.stride(0),
|
use_int8_w8a16,
|
||||||
B.stride(2),
|
use_int4_w4a16,
|
||||||
B.stride(1),
|
per_channel_quant,
|
||||||
C.stride(1),
|
block_shape,
|
||||||
C.stride(2),
|
B_bias,
|
||||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
|
||||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
|
||||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
|
||||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
|
||||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
|
||||||
B_bias.stride(0) if B_bias is not None else 0,
|
|
||||||
B_bias.stride(1) if B_bias is not None else 0,
|
|
||||||
0 if block_shape is None else block_shape[0],
|
|
||||||
0 if block_shape is None else block_shape[1],
|
|
||||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
|
||||||
top_k=top_k,
|
|
||||||
compute_type=compute_type,
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
per_channel_quant=per_channel_quant,
|
|
||||||
HAS_BIAS=HAS_BIAS,
|
|
||||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
|
||||||
**config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1997,7 +2163,7 @@ def fused_experts_impl(
|
|||||||
ignore_invalid_experts=True,
|
ignore_invalid_experts=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
dispatch_fused_moe_kernel(
|
||||||
qcurr_hidden_states,
|
qcurr_hidden_states,
|
||||||
w1,
|
w1,
|
||||||
intermediate_cache1,
|
intermediate_cache1,
|
||||||
@@ -2056,7 +2222,7 @@ def fused_experts_impl(
|
|||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
intermediate_cache3.zero_()
|
intermediate_cache3.zero_()
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
dispatch_fused_moe_kernel(
|
||||||
qintermediate_cache2,
|
qintermediate_cache2,
|
||||||
w2,
|
w2,
|
||||||
intermediate_cache3,
|
intermediate_cache3,
|
||||||
@@ -2207,13 +2373,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
|
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
|
||||||
)
|
)
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_triton_kernel(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
intermediate_cache1,
|
intermediate_cache1,
|
||||||
a1q_scale,
|
a1q_scale,
|
||||||
self.w1_scale,
|
self.w1_scale,
|
||||||
self.w1_zp,
|
|
||||||
None, # topk_weights
|
None, # topk_weights
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
@@ -2245,13 +2410,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self.block_shape,
|
self.block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_triton_kernel(
|
||||||
qintermediate_cache2,
|
qintermediate_cache2,
|
||||||
w2,
|
w2,
|
||||||
intermediate_cache3,
|
intermediate_cache3,
|
||||||
a2q_scale,
|
a2q_scale,
|
||||||
self.w2_scale,
|
self.w2_scale,
|
||||||
self.w2_zp,
|
|
||||||
topk_weights,
|
topk_weights,
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user