[Feature]: Remove Chunking From FusedMoE (#34086)
Signed-off-by: SouthWest7 <am1ao@qq.com> Signed-off-by: Southwest <1403572259@qq.com> Signed-off-by: southwest <am1ao@qq.com> Signed-off-by: Xinan Miao <1403572259@qq.com> Co-authored-by: SouthWest7 <am1ao@qq.com>
This commit is contained in:
@@ -82,11 +82,6 @@ def make_config_arg_parser(description: str):
|
||||
"--num-experts", type=int, default=32, help="Global num experts"
|
||||
)
|
||||
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
|
||||
parser.add_argument(
|
||||
"--fused-moe-chunk-size",
|
||||
type=int,
|
||||
help="Fused moe chunk size used for the non-batched fused experts impl.",
|
||||
)
|
||||
|
||||
# Quant args
|
||||
parser.add_argument(
|
||||
@@ -158,7 +153,6 @@ def make_config(args: argparse.Namespace) -> Config:
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=args.pf_type,
|
||||
fused_experts_type=args.experts_type,
|
||||
fused_moe_chunk_size=args.fused_moe_chunk_size,
|
||||
world_size=args.world_size,
|
||||
torch_trace_dir_path=args.torch_trace_dir_path,
|
||||
)
|
||||
|
||||
@@ -68,7 +68,6 @@ class Config:
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEExperts
|
||||
|
||||
fused_moe_chunk_size: int | None
|
||||
world_size: int
|
||||
|
||||
torch_trace_dir_path: str | None = None
|
||||
@@ -89,7 +88,6 @@ class Config:
|
||||
s += f" K={self.K}\n"
|
||||
s += f" topk={self.topks}\n"
|
||||
s += f" dtype={self.dtype}\n"
|
||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
|
||||
s += " Quant:\n"
|
||||
if self.quant_config is not None:
|
||||
s += f" q_dtype={self.quant_dtype}\n"
|
||||
@@ -152,11 +150,6 @@ class Config:
|
||||
|
||||
vllm_config.parallel_config.all2all_backend = self.all2all_backend()
|
||||
|
||||
if self.fused_moe_chunk_size is not None:
|
||||
env_dict.update(
|
||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
|
||||
)
|
||||
|
||||
return vllm_config, env_dict
|
||||
|
||||
def is_fp8_block_quantized(self):
|
||||
@@ -189,10 +182,6 @@ class Config:
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.blocked_quantization_support
|
||||
|
||||
def is_fe_supports_chunking(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supports_chunking
|
||||
|
||||
def supports_expert_map(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supports_expert_map
|
||||
@@ -233,10 +222,6 @@ class Config:
|
||||
if not self.is_standard_fused_experts():
|
||||
return False, "Mismatched format."
|
||||
|
||||
use_chunking = self.fused_moe_chunk_size is not None
|
||||
if use_chunking and not self.is_fe_supports_chunking():
|
||||
return False, "Chunking not supported."
|
||||
|
||||
# Check quantization sanity
|
||||
if (
|
||||
int(self.is_per_act_token_quant)
|
||||
|
||||
@@ -42,12 +42,6 @@ def rank_worker(
|
||||
):
|
||||
set_random_seed(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
@@ -135,7 +129,6 @@ def make_feature_matrix(csv_file_path: str):
|
||||
fused_experts_type=experts_type,
|
||||
quant_config=quant_config,
|
||||
world_size=2,
|
||||
fused_moe_chunk_size=None,
|
||||
)
|
||||
|
||||
success = None
|
||||
|
||||
@@ -64,7 +64,6 @@ class ExpertInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
supported_dtypes: list[torch.dtype | str]
|
||||
blocked_quantization_support: bool
|
||||
supports_chunking: bool
|
||||
supports_expert_map: bool
|
||||
needs_matching_quant: bool = False
|
||||
needs_deep_gemm: bool = False
|
||||
@@ -127,7 +126,6 @@ def register_experts(
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
supported_dtypes: list[torch.dtype | str],
|
||||
blocked_quantization_support: bool,
|
||||
supports_chunking: bool,
|
||||
supports_expert_map: bool,
|
||||
needs_matching_quant: bool = False,
|
||||
needs_deep_gemm: bool = False,
|
||||
@@ -141,7 +139,6 @@ def register_experts(
|
||||
activation_format,
|
||||
supported_dtypes,
|
||||
blocked_quantization_support,
|
||||
supports_chunking,
|
||||
supports_expert_map,
|
||||
needs_matching_quant,
|
||||
needs_deep_gemm,
|
||||
@@ -176,7 +173,6 @@ register_experts(
|
||||
batched_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
@@ -186,7 +182,6 @@ register_experts(
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
@@ -196,7 +191,6 @@ register_experts(
|
||||
batched_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=True,
|
||||
)
|
||||
|
||||
@@ -262,7 +256,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
|
||||
standard_format,
|
||||
nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
# Note: this is a hack to get it to run for now
|
||||
supports_expert_map=True,
|
||||
)
|
||||
@@ -281,7 +274,6 @@ if has_aiter():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_aiter=True,
|
||||
)
|
||||
@@ -294,7 +286,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
@@ -304,7 +295,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
@@ -314,7 +304,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
needs_deep_gemm=True,
|
||||
@@ -331,7 +320,6 @@ if cutlass_fp8_supported():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
register_experts(
|
||||
@@ -339,7 +327,6 @@ if cutlass_fp8_supported():
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
else:
|
||||
@@ -354,7 +341,6 @@ if cutlass_fp4_supported():
|
||||
standard_format,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -85,12 +85,6 @@ def rank_worker(
|
||||
):
|
||||
set_random_seed(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
|
||||
@@ -158,8 +158,6 @@ def test_w8a8_block_fp8_fused_moe(
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
@@ -226,11 +224,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||
|
||||
chunk_size = 1024
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
block_size = get_mk_alignment_for_contiguous_layout()
|
||||
dtype = torch.bfloat16
|
||||
|
||||
@@ -252,9 +247,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (
|
||||
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
|
||||
)
|
||||
use_cudagraph = N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
|
||||
@@ -321,7 +321,6 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
ep_size: int | None = None,
|
||||
):
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
@@ -376,7 +375,6 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
dtype = torch.half
|
||||
|
||||
|
||||
@@ -204,7 +204,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip("Test is only supported for sm >= 100")
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, is_trtllm=True, activation=activation
|
||||
@@ -289,7 +288,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, is_trtllm=False, activation=activation
|
||||
|
||||
@@ -84,12 +84,6 @@ def rank_worker(
|
||||
|
||||
set_random_seed(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if base_config.fused_moe_chunk_size is not None:
|
||||
assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
@@ -162,7 +156,6 @@ Ns = [1024]
|
||||
TOPKs = [4, 1]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
FUSED_MOE_CHUNK_SIZES = [None, 16]
|
||||
|
||||
|
||||
def is_nyi_config(config: Config) -> bool:
|
||||
@@ -185,14 +178,13 @@ def generate_valid_test_cases(
|
||||
cases = []
|
||||
total = 0
|
||||
|
||||
for k, n, e, dtype, quant_config, combination, chunk_size in product(
|
||||
for k, n, e, dtype, quant_config, combination in product(
|
||||
Ks,
|
||||
Ns,
|
||||
Es,
|
||||
DTYPEs,
|
||||
MK_QUANT_CONFIGS,
|
||||
product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
|
||||
FUSED_MOE_CHUNK_SIZES,
|
||||
):
|
||||
total = total + 1
|
||||
|
||||
@@ -206,7 +198,6 @@ def generate_valid_test_cases(
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=combination[0],
|
||||
fused_experts_type=combination[1],
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@@ -234,7 +225,6 @@ def generate_valid_test_cases(
|
||||
quant_config,
|
||||
combination[0],
|
||||
combination[1],
|
||||
chunk_size,
|
||||
world_size,
|
||||
)
|
||||
)
|
||||
@@ -245,7 +235,7 @@ def generate_valid_test_cases(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
|
||||
generate_valid_test_cases(
|
||||
world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||
),
|
||||
@@ -259,7 +249,6 @@ def test_modular_kernel_combinations_multigpu(
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEExperts,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
):
|
||||
@@ -280,7 +269,6 @@ def test_modular_kernel_combinations_multigpu(
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=prepare_finalize_type,
|
||||
fused_experts_type=fused_experts_type,
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
verbosity = pytestconfig.getoption("verbose")
|
||||
@@ -288,7 +276,7 @@ def test_modular_kernel_combinations_multigpu(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
|
||||
generate_valid_test_cases(
|
||||
world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
|
||||
),
|
||||
@@ -301,7 +289,6 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEExperts,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
workspace_init,
|
||||
@@ -318,7 +305,6 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=prepare_finalize_type,
|
||||
fused_experts_type=fused_experts_type,
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -287,7 +287,6 @@ def run_moe_test(
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize("chunk_size", [8192])
|
||||
def test_fused_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -297,14 +296,11 @@ def test_fused_moe(
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
padding: bool,
|
||||
chunk_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
|
||||
#
|
||||
# Setup test data
|
||||
#
|
||||
@@ -398,12 +394,12 @@ def test_fused_moe(
|
||||
)
|
||||
|
||||
|
||||
def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
|
||||
def test_fused_moe_int64_overflow(workspace_init):
|
||||
"""Regression test for int32 overflow in stride*offset products.
|
||||
|
||||
When chunking is disabled and M is large, stride_cm * offs_token can
|
||||
exceed int32 max. Verifies the offs_token int64 cast (fix for #34413)
|
||||
prevents overflow and produces correct results.
|
||||
With large M, stride_cm * offs_token can exceed int32 max. Verifies
|
||||
the offs_token int64 cast (fix for #34413) prevents overflow and
|
||||
produces correct results.
|
||||
|
||||
Reproduces the scenario from PR #34279.
|
||||
"""
|
||||
@@ -417,9 +413,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
|
||||
m, n, k, e, topk = 100000, 2048, 1024, 8, 6
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Disable chunking to expose the overflow-prone code path
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "10000000")
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
@@ -452,7 +445,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
|
||||
@pytest.mark.parametrize("topk", TOP_KS_SMALL)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize("chunk_size", [8192])
|
||||
def test_naive_block_assignment_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -461,14 +453,11 @@ def test_naive_block_assignment_moe(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
padding: bool,
|
||||
chunk_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
|
||||
#
|
||||
# Setup test data
|
||||
#
|
||||
|
||||
Reference in New Issue
Block a user