[Moe Refactor] Make Inplace Flag for FusedMoEModularKernel part of the constructor (#33375)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -620,6 +620,7 @@ def make_modular_kernel(
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=prepare_finalize,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return modular_kernel
|
||||
|
||||
@@ -74,7 +74,11 @@ def test_batched_deepgemm_vs_triton(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
|
||||
mk_triton = FusedMoEModularKernel(
|
||||
prep_finalize,
|
||||
triton_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
out_triton = mk_triton(
|
||||
hidden_states=a,
|
||||
@@ -82,7 +86,6 @@ def test_batched_deepgemm_vs_triton(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
@@ -93,7 +96,11 @@ def test_batched_deepgemm_vs_triton(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
|
||||
mk_deepgemm = FusedMoEModularKernel(
|
||||
prep_finalize,
|
||||
deepgemm_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
out_deepgemm = mk_deepgemm(
|
||||
hidden_states=a,
|
||||
@@ -101,7 +108,6 @@ def test_batched_deepgemm_vs_triton(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from tests.kernels.moe.utils import (
|
||||
make_dummy_moe_config,
|
||||
make_test_quant_config,
|
||||
make_test_weights,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
@@ -26,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
@@ -261,6 +259,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
|
||||
|
||||
@@ -207,6 +207,7 @@ def run_with_expert_maps(
|
||||
),
|
||||
quant_config=new_quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
out_tensor = out_tensor + kernel(**kwargs)
|
||||
|
||||
@@ -266,6 +267,7 @@ def run_8_bit(
|
||||
),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
return kernel(**kwargs)
|
||||
|
||||
|
||||
@@ -194,8 +194,11 @@ def make_ll_modular_kernel(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
return FusedMoEModularKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
|
||||
def make_ht_modular_kernel(
|
||||
@@ -224,8 +227,11 @@ def make_ht_modular_kernel(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
return mk
|
||||
return FusedMoEModularKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
@@ -318,7 +324,6 @@ def deepep_deepgemm_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
|
||||
@@ -179,7 +179,11 @@ def make_modular_kernel(
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts)
|
||||
mk = FusedMoEModularKernel(
|
||||
prepare_finalize=a2a,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
return mk
|
||||
|
||||
|
||||
@@ -256,7 +260,6 @@ def deep_ep_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights_chunk,
|
||||
topk_ids=topk_chunk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
|
||||
@@ -115,6 +115,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
# triton reference
|
||||
@@ -135,7 +136,6 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
@@ -301,6 +301,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
flashinfer_cutlass_output = kernel(
|
||||
@@ -309,7 +310,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
td.layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
|
||||
@@ -108,6 +108,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||
|
||||
@@ -180,7 +180,11 @@ def oai_triton_moe_impl(
|
||||
else:
|
||||
fused_experts = OAITritonExperts(make_dummy_moe_config(), quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(MoEPrepareAndFinalizeNoEP(), fused_experts)
|
||||
mk = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
fused_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return mk.forward(
|
||||
hidden_states=x,
|
||||
@@ -188,7 +192,6 @@ def oai_triton_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation="swigluoai",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None,
|
||||
|
||||
@@ -18,7 +18,11 @@ from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.moe.utils import fused_moe, make_dummy_moe_config
|
||||
from tests.kernels.moe.utils import (
|
||||
fused_moe,
|
||||
make_dummy_moe_config,
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -36,9 +40,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
batched_fused_marlin_moe,
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_permute_bias,
|
||||
)
|
||||
|
||||
@@ -95,6 +95,7 @@ def test_cutlass_fp4_moe_no_graph(
|
||||
moe_config=make_dummy_moe_config(),
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
cutlass_output = kernel(
|
||||
|
||||
@@ -172,6 +172,7 @@ def pplx_cutlass_moe(
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
|
||||
@@ -592,6 +592,7 @@ def pplx_moe(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
|
||||
@@ -7,7 +7,11 @@ import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
TritonExperts,
|
||||
fused_experts,
|
||||
fused_topk,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -20,6 +24,9 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
from vllm.utils.math_utils import round_up
|
||||
@@ -116,6 +123,7 @@ def batched_moe(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
@@ -157,6 +165,7 @@ def naive_batched_moe(
|
||||
quant_config=quant_config,
|
||||
moe_config=make_dummy_moe_config(),
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
@@ -554,3 +563,16 @@ def make_shared_experts(
|
||||
return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s)
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> FusedMoEModularKernel:
|
||||
return FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(moe_config, quant_config),
|
||||
shared_experts,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
@@ -1083,13 +1083,16 @@ class FusedMoEConfig:
|
||||
router_logits_dtype: torch.dtype | None = None
|
||||
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
|
||||
has_bias: bool = False
|
||||
|
||||
is_act_and_mul: bool = True
|
||||
|
||||
is_lora_enabled: bool = False
|
||||
|
||||
# This flag is used to disable the inplace optimization
|
||||
# in MoE kernels. If this flag is True then the kernel
|
||||
# should not be using inplace. If the flag is false, the
|
||||
# kernel is free to use inplace or not.
|
||||
disable_inplace: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dp_size > 1:
|
||||
logger.debug_once(
|
||||
|
||||
@@ -1165,6 +1165,7 @@ def cutlass_moe_w4a8_fp8(
|
||||
quant_config=quant_config,
|
||||
group_size=group_size,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
return fn(
|
||||
|
||||
@@ -267,6 +267,7 @@ def fused_marlin_moe(
|
||||
|
||||
if inplace:
|
||||
assert output is None, "Conflicting request"
|
||||
assert not disable_inplace()
|
||||
|
||||
quant_type = ScalarType.from_id(quant_type_id)
|
||||
assert quant_type in [
|
||||
@@ -356,10 +357,7 @@ def fused_marlin_moe(
|
||||
).view(-1, topk, K)
|
||||
|
||||
if output is None:
|
||||
if inplace and not disable_inplace():
|
||||
output = hidden_states
|
||||
else:
|
||||
output = torch.empty_like(hidden_states)
|
||||
output = hidden_states if inplace else torch.empty_like(hidden_states)
|
||||
|
||||
if moe_sum is None:
|
||||
return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
|
||||
|
||||
@@ -27,9 +27,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
@@ -1511,7 +1508,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
|
||||
|
||||
|
||||
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
|
||||
if inplace and not disable_inplace():
|
||||
if inplace:
|
||||
return torch_vllm_inplace_fused_experts
|
||||
return torch_vllm_outplace_fused_experts
|
||||
|
||||
@@ -1534,6 +1531,8 @@ def fused_experts(
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
assert not inplace or not disable_inplace()
|
||||
|
||||
return dispatch_fused_experts_func(inplace)(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
@@ -1593,7 +1592,7 @@ def fused_experts_impl(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
inplace: bool,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
@@ -1712,10 +1711,7 @@ def fused_experts_impl(
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
if inplace and not disable_inplace():
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
|
||||
|
||||
if ocp_mx_scheme is not None:
|
||||
# TODO: On platforms for which `current_platform.supports_mx()` is True
|
||||
@@ -2291,15 +2287,3 @@ class TritonWNA16Experts(TritonExperts):
|
||||
|
||||
# separate function is required for MoE + LoRA
|
||||
self.moe_sum(intermediate_cache3, output)
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
moe_config: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
return mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(moe_config, quant_config),
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
@@ -113,10 +113,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def allow_inplace(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def method_name(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
@@ -46,6 +46,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
old_quant_method: FusedMoEMethodBase,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
shared_experts: torch.nn.Module | None,
|
||||
inplace: bool = False,
|
||||
) -> "FusedMoEModularMethod":
|
||||
return FusedMoEModularMethod(
|
||||
old_quant_method,
|
||||
@@ -54,6 +55,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
|
||||
shared_experts,
|
||||
moe_parallel_config=moe_layer.moe_parallel_config,
|
||||
inplace=inplace,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -61,10 +63,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
def supports_eplb(self) -> bool:
|
||||
return self.old_quant_method.supports_eplb
|
||||
|
||||
@property
|
||||
def allow_inplace(self) -> bool:
|
||||
return self.old_quant_method.allow_inplace
|
||||
|
||||
@property
|
||||
def method_name(self) -> str:
|
||||
return self.old_quant_method.method_name
|
||||
@@ -99,7 +97,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=self.allow_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
|
||||
@@ -50,6 +50,9 @@ from vllm.model_executor.layers.fused_moe.router.router_factory import (
|
||||
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
|
||||
UnquantizedFusedMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
)
|
||||
@@ -560,6 +563,8 @@ class FusedMoE(CustomOp):
|
||||
activation=activation,
|
||||
device=vllm_config.device_config.device,
|
||||
routing_method=self.routing_method_type,
|
||||
# TODO: in_dtype == out_dtype?
|
||||
disable_inplace=disable_inplace() or self.shared_experts is not None,
|
||||
)
|
||||
if self.use_mori_kernels:
|
||||
assert self.rocm_aiter_fmoe_enabled, (
|
||||
@@ -650,7 +655,11 @@ class FusedMoE(CustomOp):
|
||||
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
|
||||
)
|
||||
self.quant_method = FusedMoEModularMethod.make(
|
||||
self, self.quant_method, prepare_finalize, self.shared_experts
|
||||
self,
|
||||
self.quant_method,
|
||||
prepare_finalize,
|
||||
self.shared_experts,
|
||||
inplace=not self.moe_config.disable_inplace,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -811,11 +811,13 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
fused_experts: FusedMoEPermuteExpertsUnpermute,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
moe_parallel_config: FusedMoEParallelConfig | None = None,
|
||||
inplace: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
self.shared_experts = shared_experts
|
||||
self.inplace = inplace
|
||||
|
||||
# prefer an explicit FusedMoEParallelConfig when available (from
|
||||
# FusedMoE layers / tests).
|
||||
@@ -1292,7 +1294,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
@@ -1309,8 +1310,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- topk_weights (torch.Tensor): The topk weights applied at the end of
|
||||
the layer.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
@@ -1326,7 +1325,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
if inplace and self.shared_experts is None and not disable_inplace():
|
||||
if self.inplace:
|
||||
assert self.shared_experts is None
|
||||
assert not disable_inplace()
|
||||
output = hidden_states
|
||||
else:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
|
||||
@@ -472,7 +472,7 @@ def make_fp8_moe_kernel(
|
||||
fp8_backend: Fp8MoeBackend,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
shared_experts: torch.nn.Module | None = None,
|
||||
) -> tuple[mk.FusedMoEModularKernel, bool]:
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
# Create Prepare/Finalize.
|
||||
prepare_finalize = maybe_make_prepare_finalize(
|
||||
moe=moe_config,
|
||||
@@ -512,8 +512,10 @@ def make_fp8_moe_kernel(
|
||||
else None
|
||||
),
|
||||
moe_parallel_config=moe_config.moe_parallel_config,
|
||||
inplace=(
|
||||
not moe_config.disable_inplace
|
||||
and fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
),
|
||||
)
|
||||
|
||||
# TODO(rob): update inplace logic to be part of the kernel.
|
||||
inplace = fp8_backend != Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
return kernel, inplace
|
||||
return kernel
|
||||
|
||||
@@ -437,6 +437,7 @@ def make_nvfp4_moe_kernel(
|
||||
else None
|
||||
),
|
||||
moe_parallel_config=moe_config.moe_parallel_config,
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
# TODO(rob): update inplace logic to be part of the kernel.
|
||||
|
||||
@@ -154,11 +154,9 @@ def make_unquantized_moe_kernel(
|
||||
backend: UnquantizedMoeBackend,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
moe_config: FusedMoEConfig,
|
||||
) -> tuple[mk.FusedMoEModularKernel | None, bool]:
|
||||
use_inplace = True
|
||||
|
||||
) -> mk.FusedMoEModularKernel | None:
|
||||
if backend in UNSUPPORTED_BACKEND:
|
||||
return None, use_inplace
|
||||
return None
|
||||
|
||||
if backend == UnquantizedMoeBackend.FLASHINFER_CUTLASS:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
@@ -171,8 +169,9 @@ def make_unquantized_moe_kernel(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=False,
|
||||
)
|
||||
use_inplace = False
|
||||
|
||||
elif backend == UnquantizedMoeBackend.AITER:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
@@ -184,6 +183,7 @@ def make_unquantized_moe_kernel(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.TRITON:
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
@@ -194,6 +194,7 @@ def make_unquantized_moe_kernel(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
elif backend == UnquantizedMoeBackend.XPU:
|
||||
from vllm.model_executor.layers.fused_moe import XPUExperts
|
||||
@@ -204,5 +205,6 @@ def make_unquantized_moe_kernel(
|
||||
moe_config=moe_config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
inplace=not moe_config.disable_inplace,
|
||||
)
|
||||
return kernel, use_inplace
|
||||
return kernel
|
||||
|
||||
@@ -101,10 +101,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def allow_inplace(self) -> bool:
|
||||
return True
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
@@ -225,7 +221,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
self.kernel, self.use_inplace = make_unquantized_moe_kernel(
|
||||
self.kernel = make_unquantized_moe_kernel(
|
||||
backend=self.unquantized_backend,
|
||||
quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
@@ -329,7 +325,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
|
||||
@@ -785,4 +785,5 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
w2_zeros=layer.w2_qzeros,
|
||||
workspace=layer.workspace,
|
||||
input_dtype=self.input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
)
|
||||
|
||||
@@ -515,7 +515,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
|
||||
@@ -357,7 +357,6 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
@@ -669,7 +668,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
@@ -960,7 +958,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
@@ -1073,7 +1071,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
# TODO(rob): investigate the disable_expert_map introduced by:
|
||||
@@ -1212,7 +1209,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
@@ -1739,6 +1736,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
workspace=layer.workspace,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
is_k_full=self.is_k_full,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
)
|
||||
|
||||
|
||||
@@ -1969,7 +1967,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight_packed,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
@@ -2605,6 +2603,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
s_strides1=self.s_strides1,
|
||||
s_strides2=self.s_strides2,
|
||||
group_size=self.group_size,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -149,7 +149,7 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
|
||||
@@ -854,7 +854,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
@@ -958,10 +958,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
def supports_eplb(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def allow_inplace(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
@@ -1032,7 +1028,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
|
||||
@@ -924,4 +924,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
)
|
||||
|
||||
@@ -853,7 +853,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.moe_quant_config:
|
||||
assert self.experts_cls is not None
|
||||
self.moe_mk, self.use_inplace = make_fp8_moe_kernel(
|
||||
self.moe_mk = make_fp8_moe_kernel(
|
||||
moe_quant_config=self.moe_quant_config,
|
||||
moe_config=self.moe,
|
||||
fp8_backend=self.fp8_backend,
|
||||
@@ -967,7 +967,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
@@ -1538,7 +1537,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
|
||||
@@ -378,7 +378,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
|
||||
@@ -881,10 +881,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
|
||||
)
|
||||
|
||||
@property
|
||||
def allow_inplace(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_monolithic(self) -> bool:
|
||||
return (
|
||||
@@ -923,6 +919,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
activation=layer.activation,
|
||||
expert_map=layer.expert_map,
|
||||
input_dtype=self.marlin_input_dtype,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
|
||||
@@ -388,6 +388,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
@@ -398,7 +399,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
@@ -734,10 +735,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def allow_inplace(self) -> bool:
|
||||
return True
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
@@ -769,7 +766,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=not self.moe.disable_inplace,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
|
||||
Reference in New Issue
Block a user