[Feature]: Remove Chunking From FusedMoE (#34086)
Signed-off-by: SouthWest7 <am1ao@qq.com> Signed-off-by: Southwest <1403572259@qq.com> Signed-off-by: southwest <am1ao@qq.com> Signed-off-by: Xinan Miao <1403572259@qq.com> Co-authored-by: SouthWest7 <am1ao@qq.com>
This commit is contained in:
@@ -167,9 +167,6 @@ FusedMoEExpertsModular performs the core of the FusedMoE operations. The various
|
||||
|
||||
`FusedMoEExpertsModular::activation_formats()`: Return the supported input and output activation formats, i.e. Contiguous / Batched format.
|
||||
|
||||
`FusedMoEExpertsModular::supports_chunking()`: Return True if the implementation supports chunking. Typically
|
||||
implementations that input `FusedMoEActivationFormat.Standard` support chunking and `FusedMoEActivationFormat.BatchedExperts` do not.
|
||||
|
||||
`FusedMoEExpertsModular::supports_expert_map()`: Return True if the implementation supports expert map.
|
||||
|
||||
`FusedMoEExpertsModular::workspace_shapes()` /
|
||||
@@ -220,8 +217,8 @@ If you are adding some `FusedMoEPrepareAndFinalizeModular` / `FusedMoEExpertsMod
|
||||
|
||||
1. Add the implementation type to `MK_ALL_PREPARE_FINALIZE_TYPES` and `MK_FUSED_EXPERT_TYPES` in [mk_objects.py](../../tests/kernels/moe/modular_kernel_tools/mk_objects.py) respectively.
|
||||
2. Update `Config::is_batched_prepare_finalize()`, `Config::is_batched_fused_experts()`, `Config::is_standard_fused_experts()`,
|
||||
`Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`,
|
||||
`Config::is_fe_supports_chunking()` methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py)
|
||||
`Config::is_fe_16bit_supported()`, `Config::is_fe_fp8_supported()`, `Config::is_fe_block_fp8_supported()`
|
||||
methods in [/tests/kernels/moe/modular_kernel_tools/common.py](../../tests/kernels/moe/modular_kernel_tools/common.py)
|
||||
|
||||
Doing this will add the new implementation to the test suite.
|
||||
|
||||
|
||||
@@ -82,11 +82,6 @@ def make_config_arg_parser(description: str):
|
||||
"--num-experts", type=int, default=32, help="Global num experts"
|
||||
)
|
||||
parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk")
|
||||
parser.add_argument(
|
||||
"--fused-moe-chunk-size",
|
||||
type=int,
|
||||
help="Fused moe chunk size used for the non-batched fused experts impl.",
|
||||
)
|
||||
|
||||
# Quant args
|
||||
parser.add_argument(
|
||||
@@ -158,7 +153,6 @@ def make_config(args: argparse.Namespace) -> Config:
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=args.pf_type,
|
||||
fused_experts_type=args.experts_type,
|
||||
fused_moe_chunk_size=args.fused_moe_chunk_size,
|
||||
world_size=args.world_size,
|
||||
torch_trace_dir_path=args.torch_trace_dir_path,
|
||||
)
|
||||
|
||||
@@ -68,7 +68,6 @@ class Config:
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEExperts
|
||||
|
||||
fused_moe_chunk_size: int | None
|
||||
world_size: int
|
||||
|
||||
torch_trace_dir_path: str | None = None
|
||||
@@ -89,7 +88,6 @@ class Config:
|
||||
s += f" K={self.K}\n"
|
||||
s += f" topk={self.topks}\n"
|
||||
s += f" dtype={self.dtype}\n"
|
||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
|
||||
s += " Quant:\n"
|
||||
if self.quant_config is not None:
|
||||
s += f" q_dtype={self.quant_dtype}\n"
|
||||
@@ -152,11 +150,6 @@ class Config:
|
||||
|
||||
vllm_config.parallel_config.all2all_backend = self.all2all_backend()
|
||||
|
||||
if self.fused_moe_chunk_size is not None:
|
||||
env_dict.update(
|
||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}
|
||||
)
|
||||
|
||||
return vllm_config, env_dict
|
||||
|
||||
def is_fp8_block_quantized(self):
|
||||
@@ -189,10 +182,6 @@ class Config:
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.blocked_quantization_support
|
||||
|
||||
def is_fe_supports_chunking(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supports_chunking
|
||||
|
||||
def supports_expert_map(self):
|
||||
info = expert_info(self.fused_experts_type)
|
||||
return info.supports_expert_map
|
||||
@@ -233,10 +222,6 @@ class Config:
|
||||
if not self.is_standard_fused_experts():
|
||||
return False, "Mismatched format."
|
||||
|
||||
use_chunking = self.fused_moe_chunk_size is not None
|
||||
if use_chunking and not self.is_fe_supports_chunking():
|
||||
return False, "Chunking not supported."
|
||||
|
||||
# Check quantization sanity
|
||||
if (
|
||||
int(self.is_per_act_token_quant)
|
||||
|
||||
@@ -42,12 +42,6 @@ def rank_worker(
|
||||
):
|
||||
set_random_seed(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
@@ -135,7 +129,6 @@ def make_feature_matrix(csv_file_path: str):
|
||||
fused_experts_type=experts_type,
|
||||
quant_config=quant_config,
|
||||
world_size=2,
|
||||
fused_moe_chunk_size=None,
|
||||
)
|
||||
|
||||
success = None
|
||||
|
||||
@@ -64,7 +64,6 @@ class ExpertInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
supported_dtypes: list[torch.dtype | str]
|
||||
blocked_quantization_support: bool
|
||||
supports_chunking: bool
|
||||
supports_expert_map: bool
|
||||
needs_matching_quant: bool = False
|
||||
needs_deep_gemm: bool = False
|
||||
@@ -127,7 +126,6 @@ def register_experts(
|
||||
activation_format: mk.FusedMoEActivationFormat,
|
||||
supported_dtypes: list[torch.dtype | str],
|
||||
blocked_quantization_support: bool,
|
||||
supports_chunking: bool,
|
||||
supports_expert_map: bool,
|
||||
needs_matching_quant: bool = False,
|
||||
needs_deep_gemm: bool = False,
|
||||
@@ -141,7 +139,6 @@ def register_experts(
|
||||
activation_format,
|
||||
supported_dtypes,
|
||||
blocked_quantization_support,
|
||||
supports_chunking,
|
||||
supports_expert_map,
|
||||
needs_matching_quant,
|
||||
needs_deep_gemm,
|
||||
@@ -176,7 +173,6 @@ register_experts(
|
||||
batched_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
@@ -186,7 +182,6 @@ register_experts(
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
)
|
||||
@@ -196,7 +191,6 @@ register_experts(
|
||||
batched_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=True,
|
||||
)
|
||||
|
||||
@@ -262,7 +256,6 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
|
||||
standard_format,
|
||||
nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
# Note: this is a hack to get it to run for now
|
||||
supports_expert_map=True,
|
||||
)
|
||||
@@ -281,7 +274,6 @@ if has_aiter():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_aiter=True,
|
||||
)
|
||||
@@ -294,7 +286,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
@@ -304,7 +295,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
@@ -314,7 +304,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
standard_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=True,
|
||||
needs_matching_quant=True,
|
||||
needs_deep_gemm=True,
|
||||
@@ -331,7 +320,6 @@ if cutlass_fp8_supported():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
register_experts(
|
||||
@@ -339,7 +327,6 @@ if cutlass_fp8_supported():
|
||||
batched_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=False,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
else:
|
||||
@@ -354,7 +341,6 @@ if cutlass_fp4_supported():
|
||||
standard_format,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -85,12 +85,6 @@ def rank_worker(
|
||||
):
|
||||
set_random_seed(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
|
||||
@@ -158,8 +158,6 @@ def test_w8a8_block_fp8_fused_moe(
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
@@ -226,11 +224,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||
|
||||
chunk_size = 1024
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
block_size = get_mk_alignment_for_contiguous_layout()
|
||||
dtype = torch.bfloat16
|
||||
|
||||
@@ -252,9 +247,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (
|
||||
chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
|
||||
)
|
||||
use_cudagraph = N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
|
||||
@@ -321,7 +321,6 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
ep_size: int | None = None,
|
||||
):
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)
|
||||
|
||||
@@ -376,7 +375,6 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
dtype = torch.half
|
||||
|
||||
|
||||
@@ -204,7 +204,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip("Test is only supported for sm >= 100")
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, is_trtllm=True, activation=activation
|
||||
@@ -289,7 +288,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, is_trtllm=False, activation=activation
|
||||
|
||||
@@ -84,12 +84,6 @@ def rank_worker(
|
||||
|
||||
set_random_seed(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
|
||||
if base_config.fused_moe_chunk_size is not None:
|
||||
assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
@@ -162,7 +156,6 @@ Ns = [1024]
|
||||
TOPKs = [4, 1]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
FUSED_MOE_CHUNK_SIZES = [None, 16]
|
||||
|
||||
|
||||
def is_nyi_config(config: Config) -> bool:
|
||||
@@ -185,14 +178,13 @@ def generate_valid_test_cases(
|
||||
cases = []
|
||||
total = 0
|
||||
|
||||
for k, n, e, dtype, quant_config, combination, chunk_size in product(
|
||||
for k, n, e, dtype, quant_config, combination in product(
|
||||
Ks,
|
||||
Ns,
|
||||
Es,
|
||||
DTYPEs,
|
||||
MK_QUANT_CONFIGS,
|
||||
product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES),
|
||||
FUSED_MOE_CHUNK_SIZES,
|
||||
):
|
||||
total = total + 1
|
||||
|
||||
@@ -206,7 +198,6 @@ def generate_valid_test_cases(
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=combination[0],
|
||||
fused_experts_type=combination[1],
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@@ -234,7 +225,6 @@ def generate_valid_test_cases(
|
||||
quant_config,
|
||||
combination[0],
|
||||
combination[1],
|
||||
chunk_size,
|
||||
world_size,
|
||||
)
|
||||
)
|
||||
@@ -245,7 +235,7 @@ def generate_valid_test_cases(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
|
||||
generate_valid_test_cases(
|
||||
world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||
),
|
||||
@@ -259,7 +249,6 @@ def test_modular_kernel_combinations_multigpu(
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEExperts,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
):
|
||||
@@ -280,7 +269,6 @@ def test_modular_kernel_combinations_multigpu(
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=prepare_finalize_type,
|
||||
fused_experts_type=fused_experts_type,
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
verbosity = pytestconfig.getoption("verbose")
|
||||
@@ -288,7 +276,7 @@ def test_modular_kernel_combinations_multigpu(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size",
|
||||
"k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,world_size",
|
||||
generate_valid_test_cases(
|
||||
world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
|
||||
),
|
||||
@@ -301,7 +289,6 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
quant_config: TestMoEQuantConfig | None,
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
fused_experts_type: mk.FusedMoEExperts,
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
workspace_init,
|
||||
@@ -318,7 +305,6 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=prepare_finalize_type,
|
||||
fused_experts_type=fused_experts_type,
|
||||
fused_moe_chunk_size=chunk_size,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -287,7 +287,6 @@ def run_moe_test(
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize("chunk_size", [8192])
|
||||
def test_fused_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -297,14 +296,11 @@ def test_fused_moe(
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
padding: bool,
|
||||
chunk_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
|
||||
#
|
||||
# Setup test data
|
||||
#
|
||||
@@ -398,12 +394,12 @@ def test_fused_moe(
|
||||
)
|
||||
|
||||
|
||||
def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
|
||||
def test_fused_moe_int64_overflow(workspace_init):
|
||||
"""Regression test for int32 overflow in stride*offset products.
|
||||
|
||||
When chunking is disabled and M is large, stride_cm * offs_token can
|
||||
exceed int32 max. Verifies the offs_token int64 cast (fix for #34413)
|
||||
prevents overflow and produces correct results.
|
||||
With large M, stride_cm * offs_token can exceed int32 max. Verifies
|
||||
the offs_token int64 cast (fix for #34413) prevents overflow and
|
||||
produces correct results.
|
||||
|
||||
Reproduces the scenario from PR #34279.
|
||||
"""
|
||||
@@ -417,9 +413,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
|
||||
m, n, k, e, topk = 100000, 2048, 1024, 8, 6
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Disable chunking to expose the overflow-prone code path
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "10000000")
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
@@ -452,7 +445,6 @@ def test_fused_moe_int64_overflow(monkeypatch, workspace_init):
|
||||
@pytest.mark.parametrize("topk", TOP_KS_SMALL)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize("chunk_size", [8192])
|
||||
def test_naive_block_assignment_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -461,14 +453,11 @@ def test_naive_block_assignment_moe(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
padding: bool,
|
||||
chunk_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
|
||||
#
|
||||
# Setup test data
|
||||
#
|
||||
|
||||
11
vllm/envs.py
11
vllm/envs.py
@@ -53,8 +53,6 @@ if TYPE_CHECKING:
|
||||
VLLM_CPU_SGL_KERNEL: bool = False
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
||||
VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024
|
||||
VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True
|
||||
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
|
||||
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
|
||||
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
|
||||
@@ -822,15 +820,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
),
|
||||
# Enable SPMD mode for TPU backend.
|
||||
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
|
||||
"VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(
|
||||
os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024))
|
||||
),
|
||||
# Control whether to use fused MoE activation chunking. Current chunking
|
||||
# logic is incompatible with torch.compile and causes IMA. See issue
|
||||
# https://github.com/vllm-project/vllm/issues/19631.
|
||||
"VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": lambda: bool(
|
||||
int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1"))
|
||||
),
|
||||
# If set, the OpenAI API server will stay alive even after the underlying
|
||||
# AsyncLLMEngine errors and stops serving requests
|
||||
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
|
||||
|
||||
@@ -190,9 +190,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
M = num_tokens
|
||||
max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w13",
|
||||
@@ -281,9 +280,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
M = num_tokens
|
||||
max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w2",
|
||||
|
||||
@@ -311,9 +311,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -400,9 +400,6 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
or moe_parallel_config.use_deepep_ht_kernels
|
||||
)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -445,9 +442,6 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -713,9 +707,6 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
@@ -998,9 +989,6 @@ class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
|
||||
"This method should not be called."
|
||||
)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -154,9 +154,6 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
# NOTE(rob): discovered an IMA with this combination. Needs investigation.
|
||||
return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -92,16 +92,6 @@ class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
|
||||
moe_parallel_config
|
||||
) and fallback_cls._supports_parallel_config(moe_parallel_config)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
assert (
|
||||
self.experts.supports_chunking()
|
||||
== self.fallback_experts.supports_chunking()
|
||||
)
|
||||
return (
|
||||
self.experts.supports_chunking()
|
||||
and self.fallback_experts.supports_chunking()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
assert (
|
||||
self.experts.supports_expert_map()
|
||||
|
||||
@@ -83,12 +83,6 @@ class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
# This refers to TP chunking; DP chunking is handled separately.
|
||||
# TODO(shuw@nvidia.com): Set to False to be consistent with
|
||||
# batched_deep_gemm_moe
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
@@ -195,10 +195,6 @@ class FlashInferExperts(mk.FusedMoEExpertsModular):
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
# This refers to TP chunking; DP chunking is handled separately.
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -712,9 +712,6 @@ class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
|
||||
"This method should not be called."
|
||||
)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -957,9 +954,6 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return True
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@@ -658,9 +658,6 @@ class MarlinExperts(MarlinExpertsBase):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
@@ -786,9 +783,6 @@ class BatchedMarlinExperts(MarlinExpertsBase):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
|
||||
@@ -1693,10 +1693,8 @@ def fused_experts_impl(
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
top_k_num = topk_ids.size(1)
|
||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
M = num_tokens
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
@@ -1787,139 +1785,114 @@ def fused_experts_impl(
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
|
||||
|
||||
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
|
||||
begin_chunk_idx, end_chunk_idx = (
|
||||
chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE, num_tokens),
|
||||
qhidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=hidden_states,
|
||||
A_scale=a1_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
# SPARSITY_FACTOR is a heuristic margin ensuring num_tokens * top_k
|
||||
# activates only a small fraction of total experts
|
||||
SPARSITY_FACTOR = 4
|
||||
# block quantized code path is not implemented yet.
|
||||
naive_block_assignment = (
|
||||
expert_map is None
|
||||
and num_tokens * top_k_num * SPARSITY_FACTOR <= global_num_experts
|
||||
and not (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
)
|
||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
tokens_in_chunk, _ = curr_hidden_states.size()
|
||||
)
|
||||
|
||||
if tokens_in_chunk == 0:
|
||||
break
|
||||
|
||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||
# Adjust the intermediate cache size and config for the last
|
||||
# chunk. Note that in most cases we only have one chunk
|
||||
# so the cache size and config are already set correctly and
|
||||
# do not need to be adjusted.
|
||||
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
|
||||
intermediate_cache2 = intermediate_cache2[
|
||||
: tokens_in_chunk * topk_ids.size(1)
|
||||
]
|
||||
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
|
||||
config = get_config_func(tokens_in_chunk)
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
A_scale=a1_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
if not naive_block_assignment:
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
|
||||
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
|
||||
# activates only a small fraction of total experts
|
||||
SPARSITY_FACTOR = 4
|
||||
# block quantized code path is not implemented yet.
|
||||
naive_block_assignment = (
|
||||
expert_map is None
|
||||
and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
|
||||
and not (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
)
|
||||
else:
|
||||
max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
|
||||
expert_ids = topk_ids.view(-1)
|
||||
num_tokens_post_padded = torch.empty(
|
||||
(1), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_padded.fill_(max_num_tokens_padded)
|
||||
sorted_token_ids = None
|
||||
|
||||
if not naive_block_assignment:
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
curr_topk_ids,
|
||||
config["BLOCK_SIZE_M"],
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
else:
|
||||
max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
|
||||
expert_ids = curr_topk_ids.view(-1)
|
||||
num_tokens_post_padded = torch.empty(
|
||||
(1), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_padded.fill_(max_num_tokens_padded)
|
||||
sorted_token_ids = None
|
||||
dispatch_fused_moe_kernel(
|
||||
qhidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
apply_router_weight_on_input,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w1_bias,
|
||||
)
|
||||
|
||||
dispatch_fused_moe_kernel(
|
||||
qcurr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
apply_router_weight_on_input,
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w1_bias,
|
||||
)
|
||||
apply_moe_activation(
|
||||
activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
|
||||
apply_moe_activation(
|
||||
activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
A_scale=a2_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
A_scale=a2_scale,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
ocp_mx_scheme=ocp_mx_scheme,
|
||||
)
|
||||
if expert_map is not None:
|
||||
intermediate_cache3.zero_()
|
||||
|
||||
if expert_map is not None:
|
||||
intermediate_cache3.zero_()
|
||||
dispatch_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w2_bias,
|
||||
)
|
||||
|
||||
dispatch_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
not apply_router_weight_on_input,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
B_bias=w2_bias,
|
||||
)
|
||||
|
||||
ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.size()),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.size()),
|
||||
out_hidden_states,
|
||||
)
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
@@ -1994,9 +1967,6 @@ class TritonExperts(mk.FusedMoEExpertsModular):
|
||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||
return not moe_parallel_config.use_fi_all2allv_kernels
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -609,9 +609,6 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
@@ -696,9 +693,6 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
def activation_format() -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
|
||||
@@ -9,8 +9,6 @@ from typing import final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.activation import (
|
||||
MoEActivation,
|
||||
@@ -24,14 +22,12 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache,
|
||||
count_expert_num_tokens,
|
||||
disable_inplace,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
@@ -719,15 +715,6 @@ class FusedMoEExperts(ABC):
|
||||
def g2_alphas(self) -> torch.Tensor | None:
|
||||
return self.quant_config.g2_alphas
|
||||
|
||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||
@abstractmethod
|
||||
def supports_chunking(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports activation
|
||||
chunking.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
@@ -742,11 +729,6 @@ class FusedMoEExperts(ABC):
|
||||
"""
|
||||
return False
|
||||
|
||||
def enable_chunking(self):
|
||||
return (
|
||||
envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
|
||||
)
|
||||
|
||||
|
||||
class FusedMoEExpertsModular(FusedMoEExperts):
|
||||
"""
|
||||
@@ -995,17 +977,6 @@ class FusedMoEExpertsMonolithic(FusedMoEExperts):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _slice_scales(
|
||||
scales: torch.Tensor | None, start: int, end: int
|
||||
) -> torch.Tensor | None:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
else:
|
||||
return scales[start:end]
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Kernel
|
||||
################################################################################
|
||||
@@ -1032,26 +1003,6 @@ class FusedMoEKernelModularImpl:
|
||||
and moe_parallel_config.use_ep
|
||||
)
|
||||
|
||||
def _chunk_info(self, M: int) -> tuple[int, int]:
|
||||
"""
|
||||
Compute number of chunks and chunk size for given M.
|
||||
If chunking is not supported, set the CHUNK_SIZE to M so we
|
||||
get num_chunks == 1. Take max(M, 1) to avoid divide by zero.
|
||||
If there are no tokens to process, the number of chunks will be zero.
|
||||
"""
|
||||
CHUNK_SIZE = max(
|
||||
1,
|
||||
(
|
||||
M
|
||||
if not self.fused_experts.enable_chunking()
|
||||
else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
),
|
||||
)
|
||||
num_chunks = cdiv(M, CHUNK_SIZE)
|
||||
# If there are no tokens, then there should be no loop iterations.
|
||||
assert M > 0 or num_chunks == 0
|
||||
return num_chunks, CHUNK_SIZE
|
||||
|
||||
def _allocate_buffers(
|
||||
self,
|
||||
out_dtype: torch.dtype,
|
||||
@@ -1076,40 +1027,8 @@ class FusedMoEKernelModularImpl:
|
||||
"""
|
||||
assert M_full > 0 and M_chunk > 0
|
||||
|
||||
num_chunks, _ = self._chunk_info(M_full)
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
# "mk.FusedMoEKernel.Standard" formats where this is only bounded
|
||||
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
|
||||
# DP+EP due to the random token routing.
|
||||
is_profile_run = (
|
||||
is_forward_context_available()
|
||||
and get_forward_context().attn_metadata is None
|
||||
)
|
||||
if is_profile_run and self.fused_experts.enable_chunking() and self.is_dp_ep:
|
||||
max_workspace_13, max_workspace_2, max_fused_out_shape = (
|
||||
self.fused_experts.workspace_shapes(
|
||||
envs.VLLM_FUSED_MOE_CHUNK_SIZE,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
# expert_tokens_meta help in allocating optimal/minimal
|
||||
# amount of workspace. Mark it None, so we allocate for
|
||||
# the worst-case scenario.
|
||||
expert_tokens_meta=None,
|
||||
activation=activation,
|
||||
)
|
||||
)
|
||||
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(max_workspace_13, workspace_dtype),
|
||||
(max_workspace_2, workspace_dtype),
|
||||
(max_fused_out_shape, out_dtype),
|
||||
)
|
||||
|
||||
# Get intermediate workspace shapes based off the chunked M size.
|
||||
workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes(
|
||||
M_chunk,
|
||||
@@ -1136,80 +1055,17 @@ class FusedMoEKernelModularImpl:
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
# Reuse workspace13 for the output in the non-chunked case.
|
||||
# This will not always be the case for standard
|
||||
# format experts and with experts that have empty workspaces.
|
||||
if num_chunks == 1:
|
||||
max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
|
||||
common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
|
||||
((max_shape_size,), workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
)
|
||||
workspace13 = _resize_cache(common_workspace, workspace13_shape)
|
||||
fused_out = _resize_cache(common_workspace, fused_out_shape)
|
||||
else:
|
||||
workspace13, workspace2, fused_out = (
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(workspace13_shape, workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
(fused_out_shape, out_dtype),
|
||||
)
|
||||
)
|
||||
# Reuse workspace13 for the output since there is only one chunk.
|
||||
max_shape_size = max(prod(workspace13_shape), prod(fused_out_shape))
|
||||
common_workspace, workspace2 = current_workspace_manager().get_simultaneous(
|
||||
((max_shape_size,), workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
)
|
||||
workspace13 = _resize_cache(common_workspace, workspace13_shape)
|
||||
fused_out = _resize_cache(common_workspace, fused_out_shape)
|
||||
|
||||
return workspace13, workspace2, fused_out
|
||||
|
||||
@staticmethod
|
||||
def _slice_output_tensor(
|
||||
fused_out: torch.Tensor,
|
||||
chunk_idx: int,
|
||||
num_chunks: int,
|
||||
CHUNK_SIZE: int,
|
||||
M: int,
|
||||
) -> torch.Tensor:
|
||||
if num_chunks == 1:
|
||||
return fused_out
|
||||
|
||||
assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}"
|
||||
factor = fused_out.size(0) // M
|
||||
out_chunk_size = CHUNK_SIZE * factor
|
||||
s = chunk_idx * out_chunk_size
|
||||
e = min(s + out_chunk_size, fused_out.size(0))
|
||||
return fused_out[s:e]
|
||||
|
||||
@staticmethod
|
||||
def _slice_expert_tokens_metadata(
|
||||
num_chunks: int,
|
||||
full_expert_tokens_meta: ExpertTokensMetadata | None,
|
||||
chunk_topk_ids: torch.Tensor,
|
||||
local_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
) -> ExpertTokensMetadata | None:
|
||||
if num_chunks == 1 or full_expert_tokens_meta is None:
|
||||
return full_expert_tokens_meta
|
||||
|
||||
# The existing expert_num_tokens is for the entire a1q
|
||||
# input. Chunking forces recomputation of the number
|
||||
# of tokens assigned to each expert.
|
||||
c_expert_num_tokens = count_expert_num_tokens(
|
||||
chunk_topk_ids, local_num_experts, expert_map
|
||||
)
|
||||
|
||||
c_expert_num_tokens_cpu = None
|
||||
need_expert_num_tokens_cpu = (
|
||||
full_expert_tokens_meta.expert_num_tokens_cpu is not None
|
||||
)
|
||||
if need_expert_num_tokens_cpu:
|
||||
# This is blocking as some implementations need the count
|
||||
# on the CPU to determine appropriate input/out fused-moe
|
||||
# buffers
|
||||
c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False)
|
||||
|
||||
return ExpertTokensMetadata(
|
||||
expert_num_tokens=c_expert_num_tokens,
|
||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
|
||||
)
|
||||
|
||||
def _prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1318,18 +1174,6 @@ class FusedMoEKernelModularImpl:
|
||||
a1q, w1, w2, topk_ids
|
||||
)
|
||||
|
||||
num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
|
||||
|
||||
def input_chunk_range(chunk_idx: int) -> tuple[int, int]:
|
||||
if num_chunks == 1:
|
||||
# Use a1q.size(0) here since batched format does not
|
||||
# keep M in the first dimension.
|
||||
return 0, a1q.size(0)
|
||||
else:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M_full)
|
||||
return s, e
|
||||
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
# incompatible all2all kernels like the DeepEP high-throughput
|
||||
@@ -1337,58 +1181,39 @@ class FusedMoEKernelModularImpl:
|
||||
# low-latency kernels are always batched and can never run into
|
||||
# the tensor.numel() == 0 case.
|
||||
if M_full == 0:
|
||||
assert num_chunks == 0
|
||||
workspace13 = None
|
||||
workspace2 = None
|
||||
fused_out = torch.empty_like(a1q, dtype=in_dtype)
|
||||
else:
|
||||
assert num_chunks > 0
|
||||
workspace13, workspace2, fused_out = self._allocate_buffers(
|
||||
in_dtype,
|
||||
a1q.device,
|
||||
CHUNK_SIZE,
|
||||
M_full,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
return torch.empty_like(a1q, dtype=in_dtype)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
s, e = input_chunk_range(chunk_idx)
|
||||
workspace13, workspace2, fused_out = self._allocate_buffers(
|
||||
in_dtype,
|
||||
a1q.device,
|
||||
M_full,
|
||||
M_full,
|
||||
N,
|
||||
K,
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
activation,
|
||||
)
|
||||
|
||||
c_expert_tokens_meta = self._slice_expert_tokens_metadata(
|
||||
num_chunks,
|
||||
expert_tokens_meta,
|
||||
topk_ids[s:e],
|
||||
local_num_experts,
|
||||
expert_map,
|
||||
)
|
||||
|
||||
c_fused_out = self._slice_output_tensor(
|
||||
fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full
|
||||
)
|
||||
|
||||
self.fused_experts.apply(
|
||||
output=c_fused_out,
|
||||
hidden_states=a1q[s:e],
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights[s:e],
|
||||
topk_ids=topk_ids[s:e],
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=_slice_scales(a1q_scale, s, e),
|
||||
a2_scale=_slice_scales(self.fused_experts.a2_scale, s, e),
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
self.fused_experts.apply(
|
||||
output=fused_out,
|
||||
hidden_states=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=self.fused_experts.a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
return fused_out
|
||||
|
||||
|
||||
@@ -337,9 +337,6 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
def supports_expert_map(self):
|
||||
return True
|
||||
|
||||
def supports_chunking(self):
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
|
||||
@@ -83,9 +83,6 @@ class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
|
||||
"This method should not be called."
|
||||
)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -79,9 +79,6 @@ class XPUExperts(mk.FusedMoEExpertsModular):
|
||||
]
|
||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@@ -244,8 +244,7 @@ def _get_grouped_gemm_params(
|
||||
device = w1.device
|
||||
|
||||
# Assumes all ranks have the same max_num_batched_tokens
|
||||
max_tokens_across_dp = get_dp_group().world_size * max_tokens
|
||||
max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
max_tokens = get_dp_group().world_size * max_tokens
|
||||
|
||||
# This is the maximum GroupedGemm M size that we expect to run
|
||||
# the grouped_gemm with.
|
||||
|
||||
Reference in New Issue
Block a user