[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
ElizaWszola
2025-06-07 03:26:11 +02:00
committed by GitHub
parent 6e0cd10f72
commit 84166fee97
26 changed files with 918 additions and 409 deletions

View File

@@ -4,10 +4,7 @@
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import dataclasses
import os
import traceback
from typing import Callable, Optional
from typing import Optional
import pytest
import torch
@@ -21,10 +18,7 @@ try:
except ImportError:
has_pplx = False
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
@@ -36,6 +30,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)]
@@ -57,122 +56,6 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
P = ParamSpec("P")
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
def torch_prepare(
a: torch.Tensor,