[Feature]: Remove Chunking From FusedMoE (#34086)

Signed-off-by: SouthWest7 <am1ao@qq.com>
Signed-off-by: Southwest <1403572259@qq.com>
Signed-off-by: southwest <am1ao@qq.com>
Signed-off-by: Xinan Miao <1403572259@qq.com>
Co-authored-by: SouthWest7 <am1ao@qq.com>
This commit is contained in:
Xinan Miao
2026-03-13 02:24:38 +08:00
committed by GitHub
parent c973ecdead
commit 2cdf92228c
28 changed files with 152 additions and 523 deletions

View File

@@ -82,11 +82,6 @@ def make_config_arg_parser(description: str):
"--num-experts", type=int, default=32, help="Global num experts"
)
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
parser.add_argument(
"--fused-moe-chunk-size",
type=int,
help="Fused moe chunk size used for the non-batched fused experts impl.",
)
# Quant args
parser.add_argument(
@@ -158,7 +153,6 @@ def make_config(args: argparse.Namespace) -> Config:
quant_config=quant_config,
prepare_finalize_type=args.pf_type,
fused_experts_type=args.experts_type,
fused_moe_chunk_size=args.fused_moe_chunk_size,
world_size=args.world_size,
torch_trace_dir_path=args.torch_trace_dir_path,
)

View File

@@ -68,7 +68,6 @@ class Config:
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEExperts
fused_moe_chunk_size: int | None
world_size: int
torch_trace_dir_path: str | None = None
@@ -89,7 +88,6 @@ class Config:
s += f" K={self.K}\n"
s += f" topk={self.topks}\n"
s += f" dtype={self.dtype}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
s += " Quant:\n"
if self.quant_config is not None:
s += f" q_dtype={self.quant_dtype}\n"
@@ -152,11 +150,6 @@ class Config:
vllm_config.parallel_config.all2all_backend = self.all2all_backend()
if self.fused_moe_chunk_size is not None:
env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
)
return vllm_config, env_dict
def is_fp8_block_quantized(self):
@@ -189,10 +182,6 @@ class Config:
info = expert_info(self.fused_experts_type)
return info.blocked_quantization_support
def is_fe_supports_chunking(self):
info = expert_info(self.fused_experts_type)
return info.supports_chunking
def supports_expert_map(self):
info = expert_info(self.fused_experts_type)
return info.supports_expert_map
@@ -233,10 +222,6 @@ class Config:
if not self.is_standard_fused_experts():
return False, "Mismatched format."
use_chunking = self.fused_moe_chunk_size is not None
if use_chunking and not self.is_fe_supports_chunking():
return False, "Chunking not supported."
# Check quantization sanity
if (
int(self.is_per_act_token_quant)

View File

@@ -42,12 +42,6 @@ def rank_worker(
):
set_random_seed(pgi.rank)
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device
weights.to_current_device()
@@ -135,7 +129,6 @@ def make_feature_matrix(csv_file_path: str):
fused_experts_type=experts_type,
quant_config=quant_config,
world_size=2,
fused_moe_chunk_size=None,
)
success = None

View File

@@ -64,7 +64,6 @@ class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[torch.dtype | str]
blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
@@ -127,7 +126,6 @@ def register_experts(
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[torch.dtype | str],
blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
@@ -141,7 +139,6 @@ def register_experts(
activation_format,
supported_dtypes,
blocked_quantization_support,
supports_chunking,
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
@@ -176,7 +173,6 @@ register_experts(
batched_format,
common_float_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
)
@@ -186,7 +182,6 @@ register_experts(
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
)
@@ -196,7 +191,6 @@ register_experts(
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=True,
)
@@ -262,7 +256,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
standard_format,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
supports_expert_map=True,
)
@@ -281,7 +274,6 @@ if has_aiter():
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_aiter=True,
)
@@ -294,7 +286,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
batched_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=False,
needs_deep_gemm=True,
@@ -304,7 +295,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
@@ -314,7 +304,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
needs_deep_gemm=True,
@@ -331,7 +320,6 @@ if cutlass_fp8_supported():
standard_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=True,
supports_expert_map=False,
)
register_experts(
@@ -339,7 +327,6 @@ if cutlass_fp8_supported():
batched_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=False,
supports_expert_map=False,
)
else:
@@ -354,7 +341,6 @@ if cutlass_fp4_supported():
standard_format,
nvfp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
else:

View File

@@ -85,12 +85,6 @@ def rank_worker(
):
set_random_seed(pgi.rank)
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device
weights.to_current_device()