[MoE Refactoring][Bugfix] Wrap WNA16 Triton kernel into mk and change compressed tensor kernel selection (#31752)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Yongye Zhu
2026-01-08 16:01:30 -08:00
committed by GitHub
parent 6cdf015c3c
commit d62cfe546d
3 changed files with 148 additions and 4 deletions

View File

@@ -85,6 +85,7 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
TritonExperts,
TritonWNA16Experts,
fused_experts,
fused_topk,
get_config_file_name,
@@ -103,6 +104,7 @@ if HAS_TRITON:
"CutlassBatchedExpertsFp8",
"CutlassExpertsW4A8Fp8",
"TritonExperts",
"TritonWNA16Experts",
"BatchedTritonExperts",
"DeepGemmExperts",
"BatchedDeepGemmExperts",

View File

@@ -624,11 +624,11 @@ def invoke_fused_moe_wna16_triton_kernel(
compute_type: tl.dtype,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_shape: list[int],
block_shape: list[int] | None,
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
assert block_shape is None or block_shape[0] == 0
assert block_shape is not None and block_shape[0] == 0
M = A.size(0)
num_tokens = M * top_k
@@ -2447,6 +2447,148 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
ops.moe_sum(input, output)
class TritonWNA16Experts(TritonExperts):
    """Modular-kernel (mk) experts implementation backed by the fused-MoE
    WNA16 Triton kernel.

    "WNA16" here covers the two quantization modes this class dispatches on:
    int8 weights / 16-bit activations (``use_int8_w8a16``) and int4 weights /
    16-bit activations (``use_int4_w4a16``).  It reuses all of
    ``TritonExperts``' permute/unpermute plumbing and only overrides
    ``apply`` to launch ``invoke_fused_moe_wna16_triton_kernel`` instead of
    the plain fused-MoE kernel.
    """

    def __init__(
        self,
        quant_config: FusedMoEQuantConfig,
    ):
        # No extra state beyond the base class; the quant config alone
        # (w1/w2 scales, zero points, int4-vs-int8 flags) selects the
        # WNA16 kernel behavior.
        super().__init__(quant_config)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ) -> None:
        """Run the two-stage (w1 -> activation -> w2) WNA16 MoE computation,
        writing the combined result into ``output``.

        Args:
            output: destination tensor for the weighted sum over experts.
            hidden_states: 2-D, contiguous input activations; for int4
                weights the hidden dim is packed 2 values per byte.
            w1 / w2: packed expert weight tensors (last stride must be 1).
            topk_weights / topk_ids: router outputs per token.
            activation: name of the gated activation applied between the
                two matmuls (dispatched via ``self.activation``).
            global_num_experts: total expert count across ranks; ``-1``
                means "use the local count E".
            expert_map: optional local->global expert id remapping.
            a1q_scale / a2_scale: activation quantization scales
                (``a1q_scale`` is unused here; stage-1 input is unquantized).
            workspace13 / workspace2: caller-provided scratch buffers that
                are viewed (not allocated) into the intermediate caches.
            expert_tokens_meta: unused by this implementation.
            apply_router_weight_on_input: if True, router weights were
                already applied upstream, so the second kernel launch must
                not multiply them in again.
        """
        # Check constraints.
        if self.quant_config.use_int4_w4a16:
            # int4 weights are packed two-per-byte along the hidden dim,
            # hence the // 2.
            assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
        else:
            assert hidden_states.size(-1) == w1.size(2), (
                f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
            )
        assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
        assert hidden_states.dim() == 2
        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
        assert hidden_states.dtype in [
            torch.float32,
            torch.float16,
            torch.bfloat16,
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
        ]

        E, num_tokens, N, K, top_k_num = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )

        if global_num_experts == -1:
            global_num_experts = E

        # Tile-size / launch configuration is looked up from the tuned
        # config tables keyed by problem shape and quant-config name.
        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            self.quant_config.config_name(hidden_states.dtype),
            num_tokens,
            block_shape=self.block_shape,
        )

        # Map the activation dtype to the Triton compute dtype; fp8 inputs
        # accumulate in bf16.
        if hidden_states.dtype == torch.bfloat16:
            compute_type = tl.bfloat16
        elif hidden_states.dtype == torch.float16:
            compute_type = tl.float16
        elif hidden_states.dtype == torch.float32:
            compute_type = tl.float32
        elif (
            hidden_states.dtype == torch.float8_e4m3fn
            or hidden_states.dtype == torch.float8_e4m3fnuz
        ):
            compute_type = tl.bfloat16
        else:
            raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

        # Note that the output tensor might be in workspace1
        # Caches 1 and 3 both alias workspace2; this is safe because
        # cache1 is fully consumed by the activation (into cache2, which
        # aliases workspace13) before cache3 is written by the second
        # kernel launch.
        # NOTE(review): cache placement is swapped relative to
        # TritonExperts.apply (which puts caches 1/3 in workspace13) —
        # presumably deliberate for workspace sizing; confirm against the
        # workspace_shapes the base class advertises.
        intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
        intermediate_cache2 = _resize_cache(
            workspace13, (num_tokens * top_k_num, N // 2)
        )
        intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

        # Sort token ids by expert so each kernel block works on a single
        # expert's rows, padded to BLOCK_SIZE_M.
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
        )

        # Stage 1: hidden_states @ w1 -> cache1.  Router weights are never
        # applied here (topk_weights=None, mul_routed_weights=False).
        invoke_fused_moe_wna16_triton_kernel(
            hidden_states,
            w1,
            intermediate_cache1,
            self.w1_scale,
            self.quant_config.w1_zp,
            None,  # topk_weights
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,  # mul_routed_weights
            top_k_num,
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        # Gated activation halves the hidden dim: (*, N) -> (*, N // 2).
        self.activation(
            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        # Re-quantize the activation output if the config requires a
        # quantized stage-2 input (no-op passthrough otherwise).
        a2q_scale: torch.Tensor | None = None
        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            intermediate_cache2,
            a2_scale,
            self.quant_dtype,
            self.per_act_token_quant,
            self.block_shape,
        )

        # Stage 2: cache2 @ w2 -> cache3, fusing the router-weight multiply
        # unless it was already applied on the input.  top_k is 1 here
        # because tokens were expanded to (num_tokens * top_k) rows in
        # stage 1.
        invoke_fused_moe_wna16_triton_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            self.w2_scale,
            self.quant_config.w2_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        # separate function is required for MoE + LoRA
        self.moe_sum(intermediate_cache3, output)
def modular_triton_fused_moe(
quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:

View File

@@ -1693,11 +1693,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe import TritonWNA16Experts
layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed
return TritonExperts(quant_config=self.moe_quant_config)
return TritonWNA16Experts(quant_config=self.moe_quant_config)
else:
raise NotImplementedError(
"TritonExperts requires Triton. "