Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -2526,6 +2526,7 @@ steps:
|
||||
- pytest -v -s -x lora/test_llm_with_multi_loras.py
|
||||
- pytest -v -s -x lora/test_olmoe_tp.py
|
||||
- pytest -v -s -x lora/test_gptoss_tp.py
|
||||
- pytest -v -s -x lora/test_qwen35_densemoel_lora.py
|
||||
|
||||
|
||||
- label: Weight Loading Multiple GPU # 7.5m
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
@@ -69,6 +70,16 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
|
||||
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason=(
|
||||
"Mxfp4 LoRA on ROCm is blocked by a spawn compatibility issue. "
|
||||
"The fused_moe_lora Triton kernel crashes in spawned subprocesses, "
|
||||
"and vLLM forces spawn mode when HIP is initialized before "
|
||||
"multiprocessing. Fixing this requires either making the LoRA "
|
||||
"Triton kernel spawn-safe or pre-warming the kernel cache."
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
|
||||
@pytest.mark.parametrize("specialize_active_lora", [True, False])
|
||||
def test_gpt_oss_lora(
|
||||
|
||||
@@ -109,8 +109,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
else: # fall back to the default config
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_lora_config,
|
||||
w1_shape=layer.w13_weight.size(),
|
||||
w2_shape=layer.w2_weight.size(),
|
||||
w1_shape=layer.w13_weight.shape,
|
||||
w2_shape=layer.w2_weight.shape,
|
||||
rank=rank,
|
||||
top_k=top_k,
|
||||
dtype=config_dtype,
|
||||
|
||||
@@ -379,7 +379,11 @@ def _fused_moe_lora_kernel(
|
||||
)
|
||||
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
# Cast operands to matching dtype for tl.dot. On ROCm, Triton's
|
||||
# compiler may infer different types for a and b when merging
|
||||
# if/else branches (TMA desc path returns fp32, tl.load returns
|
||||
# the pointer's element type).
|
||||
accumulator += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
|
||||
|
||||
@@ -229,6 +229,9 @@ class FusedMoEQuantConfig:
|
||||
_w1: FusedMoEQuantDesc
|
||||
_w2: FusedMoEQuantDesc
|
||||
is_nvfp4_scale_swizzled: bool = True
|
||||
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
|
||||
hidden_pad: int = 0
|
||||
intermediate_pad: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.per_act_token_quant or self.block_shape is None, (
|
||||
|
||||
@@ -257,7 +257,7 @@ def triton_kernel_moe_forward(
|
||||
# sparse_logits.indx contains global expert IDs – remap to local.
|
||||
topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
|
||||
topk_weights = sparse_logits.vals
|
||||
local_num_experts = w1.size(0)
|
||||
local_num_experts = w1.shape[0]
|
||||
routing_data, gather_idx, scatter_idx = make_routing_data(
|
||||
topk_ids, topk_weights, local_num_experts
|
||||
)
|
||||
@@ -604,8 +604,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, _, N = w1.size()
|
||||
assert len(w1.shape) == 3 and len(w2.shape) == 3
|
||||
E, _, N = w1.shape
|
||||
K = a1.size(-1)
|
||||
|
||||
assert a1.dim() == 2
|
||||
@@ -683,7 +683,7 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
local_num_experts = w1.shape[0]
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
@@ -781,7 +781,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
local_num_experts = w1.shape[0]
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
|
||||
@@ -567,6 +567,13 @@ class FusedMoE(CustomOp):
|
||||
# for heuristic purposes, so it must be initialized first.
|
||||
self.quant_method: FusedMoEMethodBase = _get_quant_method()
|
||||
|
||||
# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
|
||||
# and intermediate_size in moe_config during __init__. Sync
|
||||
# self.hidden_size so downstream consumers (e.g. LoRA) see the
|
||||
# padded value.
|
||||
if self.moe_config.hidden_dim != self.hidden_size:
|
||||
self.hidden_size = self.moe_config.hidden_dim
|
||||
|
||||
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
|
||||
raise NotImplementedError(
|
||||
"is_act_and_mul=False is supported only for CUDA and ROCm for now"
|
||||
@@ -586,7 +593,7 @@ class FusedMoE(CustomOp):
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"hidden_size": self.hidden_size,
|
||||
"unpadded_hidden_size": unpadded_hidden_size,
|
||||
"intermediate_size_per_partition": self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
|
||||
@@ -768,8 +768,8 @@ class FusedMoEExpertsModular(FusedMoEExperts):
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, N, _ = w1.size()
|
||||
assert len(w1.shape) == 3 and len(w2.shape) == 3
|
||||
E, N, _ = w1.shape
|
||||
K = a1.size(-1)
|
||||
|
||||
if a1.dim() == 2:
|
||||
@@ -1349,7 +1349,7 @@ class FusedMoEKernelModularImpl:
|
||||
else:
|
||||
output = torch.empty_like(hidden_states)
|
||||
|
||||
local_num_experts = w1.size(0)
|
||||
local_num_experts = w1.shape[0]
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = local_num_experts
|
||||
|
||||
|
||||
@@ -212,7 +212,11 @@ def select_mxfp4_moe_backend(
|
||||
# LoRA: separate experts backend path
|
||||
if config.is_lora_enabled:
|
||||
if not current_platform.is_cuda():
|
||||
raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.")
|
||||
# ROCm: Triton mxfp4 LoRA hits GPU memory faults due to
|
||||
# triton_kernels.tensor.Tensor / HIP read-only page issues
|
||||
# during weight swizzle and LoRA forward. Needs work from
|
||||
# the triton_kernels/aiter side.
|
||||
raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.")
|
||||
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
|
||||
logger.info_once("Using Triton backend for mxfp4 lora")
|
||||
return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
|
||||
@@ -775,6 +779,8 @@ def make_mxfp4_moe_quant_config(
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
hidden_pad: int = 0,
|
||||
intermediate_pad: int = 0,
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
"""Create a FusedMoEQuantConfig for the given MXFP4 backend."""
|
||||
if mxfp4_backend in (
|
||||
@@ -796,12 +802,16 @@ def make_mxfp4_moe_quant_config(
|
||||
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
|
||||
Mxfp4MoeBackend.CK,
|
||||
):
|
||||
return mxfp4_w4a16_moe_quant_config(
|
||||
config = mxfp4_w4a16_moe_quant_config(
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
if mxfp4_backend == Mxfp4MoeBackend.CK:
|
||||
config.hidden_pad = hidden_pad
|
||||
config.intermediate_pad = intermediate_pad
|
||||
return config
|
||||
else:
|
||||
return ocp_mx_moe_quant_config(
|
||||
quant_dtype="mxfp4",
|
||||
|
||||
@@ -292,6 +292,8 @@ def rocm_aiter_fused_experts(
|
||||
doweight_stage1=apply_router_weight_on_input,
|
||||
num_local_tokens=num_local_tokens,
|
||||
output_dtype=output_dtype,
|
||||
hidden_pad=quant_config.hidden_pad,
|
||||
intermediate_pad=quant_config.intermediate_pad,
|
||||
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
|
||||
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
|
||||
)
|
||||
@@ -332,7 +334,15 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
(kFp8StaticChannelSym, kFp8DynamicTokenSym),
|
||||
(kMxfp4Static, None),
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
if (weight_key, activation_key) not in SUPPORTED_W_A:
|
||||
return False
|
||||
# CK MXFP4 MoE kernels are only supported on gfx950.
|
||||
if weight_key == kMxfp4Static:
|
||||
from vllm.platforms.rocm import on_gfx950
|
||||
|
||||
if not on_gfx950():
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _supports_activation(activation: MoEActivation) -> bool:
|
||||
|
||||
@@ -158,6 +158,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size_per_partition_after_pad
|
||||
)
|
||||
|
||||
# CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
|
||||
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
|
||||
self.intermediate_pad = (
|
||||
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
|
||||
)
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
@@ -362,6 +368,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
hidden_pad=self.hidden_pad,
|
||||
intermediate_pad=self.intermediate_pad,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
|
||||
Reference in New Issue
Block a user