[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -4,10 +4,7 @@
|
||||
|
||||
Run `pytest tests/kernels/test_pplx_moe.py`.
|
||||
"""
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -21,10 +18,7 @@ try:
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
@@ -36,6 +30,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Marker that skips a test when the optional PPLX kernels could not be
# imported (``has_pplx`` is set to False by the ImportError handler above).
requires_pplx = pytest.mark.skipif(not has_pplx,
                                   reason="Requires PPLX kernels")
|
||||
|
||||
# Integer 3-tuples used as test parameter combinations.
# NOTE(review): the meaning of each position is not visible in this chunk;
# confirm against the parametrized tests that consume this list.
PPLX_PREPARE_COMBOS = [
    (4, 128, 128),
    (32, 1024, 512),
    (64, 1024, 512),
    (222, 2048, 1024),
]
|
||||
|
||||
@@ -57,122 +56,6 @@ vllm_config = VllmConfig()
|
||||
# Scheduler limits applied to the module-level ``vllm_config`` instance.
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

# Parameter specification used to type the extra arguments that the launch
# helpers forward to the worker callable.
P = ParamSpec("P")

# Marker that skips a test when the optional PPLX kernels are unavailable.
requires_pplx = pytest.mark.skipif(not has_pplx,
                                   reason="Requires PPLX kernels")
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ProcessGroupInfo:
    """Rank and device details handed to a worker by the launch helpers.

    The launcher fills these in as follows:
    ``rank`` is ``node_rank * world_local_size + local_rank`` and ``device``
    is the CUDA device whose index equals ``local_rank``.
    """

    # Total number of processes in the process group.
    world_size: int
    # Number of processes spawned on this node.
    world_local_size: int
    # Global rank of this process.
    rank: int
    # Index of the node this process runs on.
    node_rank: int
    # Index of this process within its node.
    local_rank: int
    # Torch device assigned to this process.
    device: torch.device
|
||||
|
||||
|
||||
def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Per-process entry point: join the process group, then run ``worker``.

    Args:
        local_rank: Index of this process on its node; also used as the CUDA
            device index.
        world_size: Total number of processes in the group.
        world_local_size: Number of processes per node.
        node_rank: Index of the node this process belongs to.
        init_method: Rendezvous URL passed to ``init_process_group``.
        worker: Callable invoked with a ``ProcessGroupInfo`` followed by
            ``*args`` and ``**kwargs``.

    Raises:
        Exception: Anything raised by ``worker`` is printed (message and
            traceback) and then re-raised.
    """
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    rank = node_rank * world_local_size + local_rank

    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )

    # A one-element all-reduce across the group before the worker starts
    # (the original code names this tensor "barrier").
    sync_token = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(sync_token)

    try:
        pg_info = ProcessGroupInfo(
            world_size=world_size,
            world_local_size=world_local_size,
            rank=rank,
            node_rank=node_rank,
            local_rank=local_rank,
            device=device,
        )
        worker(pg_info, *args, **kwargs)
    except Exception as ex:
        # Print the failure in this process before propagating it.
        print(ex)
        traceback.print_exc()
        raise
    finally:
        # Tear the group down whether or not the worker succeeded.
        torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """Spawn ``world_size`` processes that each run ``worker``.

    Every process starts in ``_worker_parallel_launch`` with a node rank of 0
    and a local world size equal to ``world_size``; the rendezvous address is
    fixed at ``tcp://localhost:29500``. The call blocks until the spawned
    processes finish (``join=True``).

    Args:
        world_size: Number of processes to spawn.
        worker: Callable run in each process; it receives a
            ``ProcessGroupInfo`` followed by ``*args``.
        *args: Extra positional arguments forwarded to ``worker``.
        **kwargs: Must be empty; keyword arguments are not forwarded.
    """
    assert not kwargs

    # Arguments for _worker_parallel_launch after the process index that
    # spawn supplies as the first argument.
    launch_args = (
        world_size,  # world_size
        world_size,  # world_local_size
        0,  # node_rank
        "tcp://localhost:29500",  # init_method
        worker,
        *args,
    )
    spawn(
        _worker_parallel_launch,
        args=launch_args,
        nprocs=world_size,
        join=True,
    )
|
||||
|
||||
|
||||
def parallel_launch_from_env(
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    """
    Launches a worker function in parallel across all processes in the current
    environment. The environment must have the following variables set:
    - WORLD_SIZE: The total number of processes.
    - WORLD_LOCAL_SIZE: The number of processes on the current node.
    - NODE_RANK: The rank of the current node.
    - MASTER_ADDR: The address of the master process.
    - MASTER_PORT: The port of the master process.

    Args:
        worker: Callable run in each spawned process; it receives a
            ``ProcessGroupInfo`` followed by ``*args``.
        *args: Extra positional arguments forwarded to ``worker``.
        **kwargs: Must be empty; keyword arguments are not forwarded.

    Raises:
        AssertionError: If keyword arguments are given, or MASTER_ADDR /
            MASTER_PORT is missing from the environment.
        KeyError: If WORLD_SIZE, WORLD_LOCAL_SIZE or NODE_RANK is missing.
        ValueError: If one of those three variables is not an integer.
    """
    assert not kwargs, "keyword arguments are not supported"

    world_size = int(os.environ["WORLD_SIZE"])
    world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
    node_rank = int(os.environ["NODE_RANK"])

    # These are not read here, but they are required to be present because
    # the process group is initialised with the "env://" method.
    assert "MASTER_ADDR" in os.environ, "MASTER_ADDR must be set"
    assert "MASTER_PORT" in os.environ, "MASTER_PORT must be set"

    # Only this node's processes are spawned (nprocs=world_local_size);
    # the call blocks until they finish (join=True).
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_local_size,
            node_rank,
            "env://",
            worker,
        ) + args,
        nprocs=world_local_size,
        join=True,
    )
|
||||
|
||||
|
||||
def torch_prepare(
|
||||
a: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user