[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
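
The sketch below is illustrative only and is not part of the diff: assuming the
config, prepare/finalize, and experts objects are already built, it shows how a
caller would now pass the in-place decision to the FusedMoEModularKernel
constructor instead of to each forward() call. The helper name and module alias
are assumptions, not code from this change.

    # Hypothetical helper for illustration; names other than
    # FusedMoEModularKernel, moe_parallel_config and disable_inplace
    # are not taken from the diff.
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    def build_modular_kernel(moe_config, prepare_finalize, experts, shared_experts=None):
        # Decide in-place vs. out-of-place once, at construction time.
        # In-place is only safe when there are no shared experts.
        return mk.FusedMoEModularKernel(
            prepare_finalize,
            experts,
            shared_experts,
            moe_parallel_config=moe_config.moe_parallel_config,
            inplace=not moe_config.disable_inplace and shared_experts is None,
        )
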
@@ -1083,13 +1083,16 @@ class FusedMoEConfig:
    router_logits_dtype: torch.dtype | None = None

    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

    has_bias: bool = False

    is_act_and_mul: bool = True

    is_lora_enabled: bool = False

    # This flag is used to disable the inplace optimization
    # in MoE kernels. If this flag is True then the kernel
    # should not be using inplace. If the flag is false, the
    # kernel is free to use inplace or not.
    disable_inplace: bool = True

    def __post_init__(self):
        if self.dp_size > 1:
            logger.debug_once(

@@ -1165,6 +1165,7 @@ def cutlass_moe_w4a8_fp8(
            quant_config=quant_config,
            group_size=group_size,
        ),
        inplace=False,
    )

    return fn(

@@ -267,6 +267,7 @@ def fused_marlin_moe(
    if inplace:
        assert output is None, "Conflicting request"
        assert not disable_inplace()

    quant_type = ScalarType.from_id(quant_type_id)
    assert quant_type in [

@@ -356,10 +357,7 @@ def fused_marlin_moe(
    ).view(-1, topk, K)

    if output is None:
        if inplace and not disable_inplace():
            output = hidden_states
        else:
            output = torch.empty_like(hidden_states)
        output = hidden_states if inplace else torch.empty_like(hidden_states)

    if moe_sum is None:
        return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)

@@ -27,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNoOP,
)

@@ -1511,7 +1508,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    if inplace and not disable_inplace():
    if inplace:
        return torch_vllm_inplace_fused_experts
    return torch_vllm_outplace_fused_experts

@@ -1534,6 +1531,8 @@ def fused_experts(
    if quant_config is None:
        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

    assert not inplace or not disable_inplace()

    return dispatch_fused_experts_func(inplace)(
        hidden_states=hidden_states,
        w1=w1,

@@ -1593,7 +1592,7 @@ def fused_experts_impl(
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    inplace: bool,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,

@@ -1712,10 +1711,7 @@ def fused_experts_impl(
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    if inplace and not disable_inplace():
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)
    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # TODO: On platforms for which `current_platform.supports_mx()` is True

@@ -2291,15 +2287,3 @@ class TritonWNA16Experts(TritonExperts):
        # separate function is required for MoE + LoRA
        self.moe_sum(intermediate_cache3, output)


def modular_triton_fused_moe(
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
    shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEModularKernel:
    return mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        TritonExperts(moe_config, quant_config),
        shared_experts,
    )

@@ -113,10 +113,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
    def supports_eplb(self) -> bool:
        return False

    @property
    def allow_inplace(self) -> bool:
        return False

    @property
    def method_name(self) -> str:
        return self.__class__.__name__

@@ -46,6 +46,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
        old_quant_method: FusedMoEMethodBase,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        shared_experts: torch.nn.Module | None,
        inplace: bool = False,
    ) -> "FusedMoEModularMethod":
        return FusedMoEModularMethod(
            old_quant_method,

@@ -54,6 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                shared_experts,
                moe_parallel_config=moe_layer.moe_parallel_config,
                inplace=inplace,
            ),
        )

@@ -61,10 +63,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
    def supports_eplb(self) -> bool:
        return self.old_quant_method.supports_eplb

    @property
    def allow_inplace(self) -> bool:
        return self.old_quant_method.allow_inplace

    @property
    def method_name(self) -> str:
        return self.old_quant_method.method_name

@@ -99,7 +97,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=self.allow_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,

@@ -50,6 +50,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.utils import (
    disable_inplace,
)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
)

@@ -560,6 +563,8 @@ class FusedMoE(CustomOp):
            activation=activation,
            device=vllm_config.device_config.device,
            routing_method=self.routing_method_type,
            # TODO: in_dtype == out_dtype?
            disable_inplace=disable_inplace() or self.shared_experts is not None,
        )
        if self.use_mori_kernels:
            assert self.rocm_aiter_fmoe_enabled, (

@@ -650,7 +655,11 @@ class FusedMoE(CustomOp):
                "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
            )
            self.quant_method = FusedMoEModularMethod.make(
                self, self.quant_method, prepare_finalize, self.shared_experts
                self,
                self.quant_method,
                prepare_finalize,
                self.shared_experts,
                inplace=not self.moe_config.disable_inplace,
            )

    @property

@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
        fused_experts: FusedMoEPermuteExpertsUnpermute,
        shared_experts: torch.nn.Module | None = None,
        moe_parallel_config: FusedMoEParallelConfig | None = None,
        inplace: bool = False,
    ):
        super().__init__()
        self.prepare_finalize = prepare_finalize
        self.fused_experts = fused_experts
        self.shared_experts = shared_experts
        self.inplace = inplace

        # prefer an explicit FusedMoEParallelConfig when available (from
        # FusedMoE layers / tests).

@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        inplace: bool = False,
        activation: str = "silu",
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,

@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
        - topk_weights (torch.Tensor): The topk weights applied at the end of
          the layer.
        - topk_ids (torch.Tensor): A map of row to expert id.
        - inplace (bool): If True, perform the operation in-place.
          Defaults to False.
        - activation (str): The activation function to apply after the first
          MoE layer.
        - global_num_experts (int): The total number of experts in the global

@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):
        - torch.Tensor: The output tensor after applying the MoE layer.
        """

        if inplace and self.shared_experts is None and not disable_inplace():
        if self.inplace:
            assert self.shared_experts is None
            assert not disable_inplace()
            output = hidden_states
        else:
            output = torch.zeros_like(hidden_states)

@@ -472,7 +472,7 @@ def make_fp8_moe_kernel(
    fp8_backend: Fp8MoeBackend,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: torch.nn.Module | None = None,
) -> tuple[mk.FusedMoEModularKernel, bool]:
) -> mk.FusedMoEModularKernel:
    # Create Prepare/Finalize.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,

@@ -512,8 +512,10 @@ def make_fp8_moe_kernel(
            else None
        ),
        moe_parallel_config=moe_config.moe_parallel_config,
        inplace=(
            not moe_config.disable_inplace
            and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
        ),
    )

    # TODO(rob): update inplace logic to be part of the kernel.
    inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
    return kernel, inplace
    return kernel

@@ -437,6 +437,7 @@ def make_nvfp4_moe_kernel(
            else None
        ),
        moe_parallel_config=moe_config.moe_parallel_config,
        inplace=False,
    )

    # TODO(rob): update inplace logic to be part of the kernel.

@@ -154,11 +154,9 @@ def make_unquantized_moe_kernel(
    backend: UnquantizedMoeBackend,
    quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
) -> tuple[mk.FusedMoEModularKernel | None, bool]:
    use_inplace = True

) -> mk.FusedMoEModularKernel | None:
    if backend in UNSUPPORTED_BACKEND:
        return None, use_inplace
        return None

    if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (

@@ -171,8 +169,9 @@ def make_unquantized_moe_kernel(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )
        use_inplace = False

    elif backend == UnquantizedMoeBackend.AITER:
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            AiterExperts,

@@ -184,6 +183,7 @@ def make_unquantized_moe_kernel(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=not moe_config.disable_inplace,
        )
    elif backend == UnquantizedMoeBackend.TRITON:
        from vllm.model_executor.layers.fused_moe import TritonExperts

@@ -194,6 +194,7 @@ def make_unquantized_moe_kernel(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=not moe_config.disable_inplace,
        )
    elif backend == UnquantizedMoeBackend.XPU:
        from vllm.model_executor.layers.fused_moe import XPUExperts

@@ -204,5 +205,6 @@ def make_unquantized_moe_kernel(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=not moe_config.disable_inplace,
        )
    return kernel, use_inplace
    return kernel

@@ -101,10 +101,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,

@@ -225,7 +221,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None

        self.kernel, self.use_inplace = make_unquantized_moe_kernel(
        self.kernel = make_unquantized_moe_kernel(
            backend=self.unquantized_backend,
            quant_config=self.moe_quant_config,
            moe_config=self.moe,

@@ -329,7 +325,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,

@@ -785,4 +785,5 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
            w2_zeros=layer.w2_qzeros,
            workspace=layer.workspace,
            input_dtype=self.input_dtype,
            inplace=not self.moe.disable_inplace,
        )

@@ -515,7 +515,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,

@@ -357,7 +357,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,

@@ -669,7 +668,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,

@@ -960,7 +958,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
            self.moe_mk = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,

@@ -1073,7 +1071,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            # TODO(rob): investigate the disable_expert_map introduced by:

@@ -1212,7 +1209,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,

@@ -1739,6 +1736,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
            workspace=layer.workspace,
            input_dtype=self.marlin_input_dtype,
            is_k_full=self.is_k_full,
            inplace=not self.moe.disable_inplace,
        )

@@ -1969,7 +1967,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            layer.w2_weight_packed,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,

@@ -2605,6 +2603,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
            s_strides1=self.s_strides1,
            s_strides2=self.s_strides2,
            group_size=self.group_size,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
        )

    @property

@@ -149,7 +149,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,

@@ -854,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
            self.moe_mk = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,

@@ -958,10 +958,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
    def supports_eplb(self) -> bool:
        return True

    @property
    def allow_inplace(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

@@ -1032,7 +1028,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,

@@ -924,4 +924,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
            workspace=layer.workspace,
            is_k_full=self.is_k_full,
            input_dtype=self.input_dtype,
            inplace=not self.moe.disable_inplace,
        )

@@ -853,7 +853,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
            self.moe_mk = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,

@@ -967,7 +967,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=self.use_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,

@@ -1538,7 +1537,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
            layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,

@@ -378,7 +378,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
            layer.w2_qweight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            inplace=not self.moe.disable_inplace,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,

@@ -881,10 +881,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
            )

    @property
    def allow_inplace(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        return (

@@ -923,6 +919,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                activation=layer.activation,
                expert_map=layer.expert_map,
                input_dtype=self.marlin_input_dtype,
                inplace=not self.moe.disable_inplace,
            )

        assert _can_support_mxfp4(

@@ -388,6 +388,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                inplace=not self.moe.disable_inplace,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

@@ -398,7 +399,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                inplace=not self.moe.disable_inplace,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,

@@ -734,10 +735,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
            block_shape=None,
        )

    @property
    def allow_inplace(self) -> bool:
        return True

    def apply(
        self,
        layer: FusedMoE,

@@ -769,7 +766,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,