Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -229,6 +229,9 @@ class FusedMoEQuantConfig:
     _w1: FusedMoEQuantDesc
     _w2: FusedMoEQuantDesc
     is_nvfp4_scale_swizzled: bool = True
+    # CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
+    hidden_pad: int = 0
+    intermediate_pad: int = 0
 
     def __post_init__(self):
         assert not self.per_act_token_quant or self.block_shape is None, (
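For context, a minimal standalone sketch of how such padding fields can be derived and stored on a quant config. The 256-element alignment, `round_up`, and `make_padded_config` below are assumptions for illustration, not part of this diff:

```python
from dataclasses import dataclass


def round_up(x: int, multiple: int) -> int:
    """Round x up to the next multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple


@dataclass
class QuantConfigSketch:
    hidden_pad: int = 0          # extra elements appended to hidden_dim
    intermediate_pad: int = 0    # extra elements appended to intermediate_size


def make_padded_config(hidden_dim: int, intermediate_size: int,
                       align: int = 256) -> QuantConfigSketch:
    return QuantConfigSketch(
        hidden_pad=round_up(hidden_dim, align) - hidden_dim,
        intermediate_pad=round_up(intermediate_size, align) - intermediate_size,
    )


print(make_padded_config(2880, 2880))  # hidden_pad=192, intermediate_pad=192
```

Keeping the pads on the quant config lets the kernel call site stay unaware of how the alignment was chosen.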
@@ -257,7 +257,7 @@ def triton_kernel_moe_forward(
         # sparse_logits.indx contains global expert IDs – remap to local.
         topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
     topk_weights = sparse_logits.vals
-    local_num_experts = w1.size(0)
+    local_num_experts = w1.shape[0]
     routing_data, gather_idx, scatter_idx = make_routing_data(
         topk_ids, topk_weights, local_num_experts
     )
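For reference, the global-to-local remap that `expert_map` performs can be sketched as below. The rank ownership and tensor values are made up for illustration:

```python
import torch

global_num_experts = 8
local_expert_ids = torch.tensor([4, 5, 6, 7])  # experts owned by this EP rank

# expert_map[g] = local id of global expert g, or -1 if this rank does not own it.
expert_map = torch.full((global_num_experts,), -1, dtype=torch.long)
expert_map[local_expert_ids] = torch.arange(len(local_expert_ids))

topk_ids = torch.tensor([[1, 5], [7, 2]])    # router output: global expert IDs
print(expert_map[topk_ids.to(torch.long)])   # tensor([[-1, 1], [ 3, -1]])
```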
@@ -604,8 +604,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
         require a specialized implementation, like MarlinExperts, they are free
         to override this function.
         """
-        assert w1.dim() == 3 and w2.dim() == 3
-        E, _, N = w1.size()
+        assert len(w1.shape) == 3 and len(w2.shape) == 3
+        E, _, N = w1.shape
         K = a1.size(-1)
 
         assert a1.dim() == 2
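The switch from `w1.dim()` / `w1.size()` to `len(w1.shape)` / `w1.shape` matters when `w1` is not a plain `torch.Tensor`. A hedged sketch of why a shape-only check is more permissive, using a hypothetical wrapper type:

```python
import torch


class WrapperTensor:
    """Stand-in for a non-torch tensor wrapper that only carries a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def num_local_experts(w) -> int:
    assert len(w.shape) == 3, "expected a 3-D expert weight"
    return w.shape[0]  # number of local experts


print(num_local_experts(torch.zeros(4, 8, 16)))      # 4
print(num_local_experts(WrapperTensor((4, 8, 16))))  # 4; .dim()/.size() would raise here
```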
@@ -683,7 +683,7 @@ class OAITritonExperts(BaseOAITritonExperts):
         if expert_map is not None:
             topk_ids = expert_map[topk_ids]
 
-        local_num_experts = w1.size(0)
+        local_num_experts = w1.shape[0]
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
@@ -781,7 +781,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
         if expert_map is not None:
             topk_ids = expert_map[topk_ids]
 
-        local_num_experts = w1.size(0)
+        local_num_experts = w1.shape[0]
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
@@ -567,6 +567,13 @@ class FusedMoE(CustomOp):
         # for heuristic purposes, so it must be initialized first.
         self.quant_method: FusedMoEMethodBase = _get_quant_method()
 
+        # Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
+        # and intermediate_size in moe_config during __init__. Sync
+        # self.hidden_size so downstream consumers (e.g. LoRA) see the
+        # padded value.
+        if self.moe_config.hidden_dim != self.hidden_size:
+            self.hidden_size = self.moe_config.hidden_dim
+
         if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
             raise NotImplementedError(
                 "is_act_and_mul=False is supported only for CUDA and ROCm for now"
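A small self-contained sketch of the sync above: if a quant method rounds `hidden_dim` up during init, anything sized from `self.hidden_size` afterwards must see the padded value. The class names and the 256 alignment below are illustrative only:

```python
def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple


class MoeConfigSketch:
    def __init__(self, hidden_dim: int):
        self.hidden_dim = hidden_dim


class LayerSketch:
    def __init__(self, hidden_size: int, pad_to: int | None = None):
        self.hidden_size = hidden_size
        self.moe_config = MoeConfigSketch(hidden_dim=hidden_size)
        if pad_to is not None:  # mimics a quant method rounding up hidden_dim
            self.moe_config.hidden_dim = round_up(hidden_size, pad_to)
        # the sync from the hunk above, in miniature
        if self.moe_config.hidden_dim != self.hidden_size:
            self.hidden_size = self.moe_config.hidden_dim


layer = LayerSketch(hidden_size=2880, pad_to=256)
print(layer.hidden_size)  # 3072; downstream consumers now see the padded size
```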
@@ -586,7 +593,7 @@ class FusedMoE(CustomOp):
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
-            "hidden_size": hidden_size,
+            "hidden_size": self.hidden_size,
             "unpadded_hidden_size": unpadded_hidden_size,
             "intermediate_size_per_partition": self.intermediate_size_per_partition,
             "params_dtype": params_dtype,
@@ -768,8 +768,8 @@ class FusedMoEExpertsModular(FusedMoEExperts):
         require a specialized implementation, like MarlinExperts, they are free
         to override this function.
         """
-        assert w1.dim() == 3 and w2.dim() == 3
-        E, N, _ = w1.size()
+        assert len(w1.shape) == 3 and len(w2.shape) == 3
+        E, N, _ = w1.shape
         K = a1.size(-1)
 
         if a1.dim() == 2:
@@ -1349,7 +1349,7 @@ class FusedMoEKernelModularImpl:
         else:
             output = torch.empty_like(hidden_states)
 
-        local_num_experts = w1.size(0)
+        local_num_experts = w1.shape[0]
         if global_num_experts == -1:
             global_num_experts = local_num_experts
 
@@ -212,7 +212,11 @@ def select_mxfp4_moe_backend(
     # LoRA: separate experts backend path
     if config.is_lora_enabled:
         if not current_platform.is_cuda():
-            raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.")
+            # ROCm: Triton mxfp4 LoRA hits GPU memory faults due to
+            # triton_kernels.tensor.Tensor / HIP read-only page issues
+            # during weight swizzle and LoRA forward. Needs work from
+            # the triton_kernels/aiter side.
+            raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.")
         if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
             logger.info_once("Using Triton backend for mxfp4 lora")
             return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
@@ -775,6 +779,8 @@ def make_mxfp4_moe_quant_config(
     w2_scale: Union[torch.Tensor, "PrecisionConfig"],
     w1_bias: torch.Tensor | None = None,
     w2_bias: torch.Tensor | None = None,
+    hidden_pad: int = 0,
+    intermediate_pad: int = 0,
 ) -> FusedMoEQuantConfig | None:
     """Create a FusedMoEQuantConfig for the given MXFP4 backend."""
     if mxfp4_backend in (
@@ -796,12 +802,16 @@ def make_mxfp4_moe_quant_config(
         Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
         Mxfp4MoeBackend.CK,
     ):
-        return mxfp4_w4a16_moe_quant_config(
+        config = mxfp4_w4a16_moe_quant_config(
             w1_bias=w1_bias,
             w2_bias=w2_bias,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
         )
+        if mxfp4_backend == Mxfp4MoeBackend.CK:
+            config.hidden_pad = hidden_pad
+            config.intermediate_pad = intermediate_pad
+        return config
     else:
         return ocp_mx_moe_quant_config(
             quant_dtype="mxfp4",
@@ -292,6 +292,8 @@ def rocm_aiter_fused_experts(
         doweight_stage1=apply_router_weight_on_input,
         num_local_tokens=num_local_tokens,
         output_dtype=output_dtype,
+        hidden_pad=quant_config.hidden_pad,
+        intermediate_pad=quant_config.intermediate_pad,
         bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
         bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
     )
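One plausible way padding metadata like `hidden_pad` gets used around a padded-weight kernel is to zero-pad the activations on the way in and slice the extra columns off on the way out. The snippet below is a generic illustration of that pattern, not the aiter API:

```python
import torch
import torch.nn.functional as F

hidden_size, hidden_pad = 2880, 192
x = torch.randn(4, hidden_size)

# Zero-pad the last (hidden) dimension so it matches the padded weight layout.
x_padded = F.pad(x, (0, hidden_pad))
print(x_padded.shape)  # torch.Size([4, 3072])

# After the padded matmul, the extra columns are sliced away again.
y_padded = torch.randn(4, hidden_size + hidden_pad)
y = y_padded[..., :hidden_size]
print(y.shape)  # torch.Size([4, 2880])
```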
@@ -332,7 +334,15 @@ class AiterExperts(mk.FusedMoEExpertsModular):
             (kFp8StaticChannelSym, kFp8DynamicTokenSym),
             (kMxfp4Static, None),
         ]
-        return (weight_key, activation_key) in SUPPORTED_W_A
+        if (weight_key, activation_key) not in SUPPORTED_W_A:
+            return False
+        # CK MXFP4 MoE kernels are only supported on gfx950.
+        if weight_key == kMxfp4Static:
+            from vllm.platforms.rocm import on_gfx950
+
+            if not on_gfx950():
+                return False
+        return True
 
     @staticmethod
     def _supports_activation(activation: MoEActivation) -> bool:
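The support check above is two-staged: the (weight, activation) pair must be listed, and MXFP4 additionally requires gfx950. A standalone sketch with plain strings standing in for the vLLM scheme keys and a boolean standing in for `on_gfx950()`:

```python
SUPPORTED_W_A = [
    ("fp8_static_channel_sym", "fp8_dynamic_token_sym"),
    ("mxfp4_static", None),
]


def supports(weight_key: str, activation_key: str | None, is_gfx950: bool) -> bool:
    # Stage 1: the (weight, activation) combination must be in the support list.
    if (weight_key, activation_key) not in SUPPORTED_W_A:
        return False
    # Stage 2: CK MXFP4 MoE kernels are only available on gfx950.
    if weight_key == "mxfp4_static" and not is_gfx950:
        return False
    return True


print(supports("mxfp4_static", None, is_gfx950=True))   # True
print(supports("mxfp4_static", None, is_gfx950=False))  # False
```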
@@ -158,6 +158,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad
         )
 
+        # CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
+        self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
+        self.intermediate_pad = (
+            intermediate_size_per_partition_after_pad - intermediate_size_per_partition
+        )
+
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.zeros(
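Note the asymmetry in where the two pads come from: `hidden_pad` is handed down by the layer via `extra_weight_attrs`, while `intermediate_pad` is derived locally from the rounded-up partition size. With assumed example numbers:

```python
# The concrete values below are assumptions for illustration only.
extra_weight_attrs = {"hidden_pad": 192}
intermediate_size_per_partition = 2880
intermediate_size_per_partition_after_pad = 3072  # e.g. rounded up to a 256 multiple

hidden_pad = extra_weight_attrs.get("hidden_pad", 0)  # 0 if the layer passed nothing
intermediate_pad = (
    intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
print(hidden_pad, intermediate_pad)  # 192 192
```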
@@ -362,6 +368,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             w2_scale=w2_scale,
             w1_bias=w1_bias,
             w2_bias=w2_bias,
+            hidden_pad=self.hidden_pad,
+            intermediate_pad=self.intermediate_pad,
         )
 
     def select_gemm_impl(