[MoE Refactor][Test] FusedMoE layer test (#24675)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Author: bnellnm
Date: 2026-04-06 13:17:23 -04:00
Committed by: GitHub
Parent: bfdc0a3a99
Commit: f01482408c
6 changed files with 1858 additions and 55 deletions


@@ -11,7 +11,11 @@ from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUs
 from typing_extensions import ParamSpec
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.distributed import (
+    cleanup_dist_env_and_memory,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.utils.network_utils import get_open_port
 ## Parallel Processes Utils
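For context on the new import: a minimal sketch of how cleanup_dist_env_and_memory is typically used around a test body (the worker shown is a placeholder, not part of this change), assuming the helper tears down vLLM's distributed groups and frees cached device memory:

from vllm.distributed import cleanup_dist_env_and_memory

def run_isolated_test() -> None:
    try:
        ...  # placeholder: set up model-parallel groups and run the test body
    finally:
        # Broader teardown than a bare torch.distributed.destroy_process_group():
        # it also drops vLLM's process-group wrappers and cached allocator memory.
        cleanup_dist_env_and_memory()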
@@ -36,10 +40,17 @@ def _set_vllm_config(
     temp_file = tempfile.mkstemp()[1]
+    # When DP is enabled, processes are organized as:
+    #   rank = dp_rank * tp_pp_world_size + tp_pp_rank
+    tp_pp_world_size = vllm_config.parallel_config.world_size
+    vllm_config.parallel_config.data_parallel_rank = rank // tp_pp_world_size
+    tp_pp_rank = rank % tp_pp_world_size
+    vllm_config.parallel_config.rank = tp_pp_rank
     with set_current_vllm_config(vllm_config):
         init_distributed_environment(
-            world_size=world_size,
-            rank=rank,
+            world_size=tp_pp_world_size,
+            rank=tp_pp_rank,
             distributed_init_method=f"file://{temp_file}",
             local_rank=local_rank,
             backend="nccl",
@@ -59,11 +70,11 @@ def _worker_parallel_launch(
     world_local_size: int,
     node_rank: int,
     init_method: str,
-    worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None],
+    worker: Callable[..., None],
     vllm_config: VllmConfig | None,
     env_dict: dict | None,
-    *args: P.args,
-    **kwargs: P.kwargs,
+    worker_kwargs: dict[str, Any],
+    *args: Any,
 ) -> None:
     rank = node_rank * world_local_size + local_rank
     torch.accelerator.set_device_index(local_rank)
@@ -98,14 +109,17 @@ def _worker_parallel_launch(
             vllm_config,
             cpu_group,
             *args,
-            **kwargs,
+            **worker_kwargs,
         )
     except Exception as ex:
         print(ex)
         traceback.print_exc()
         raise
     finally:
-        torch.distributed.destroy_process_group()
+        if vllm_config is not None:
+            cleanup_dist_env_and_memory()
+        else:
+            torch.distributed.destroy_process_group()
 def parallel_launch_with_config(
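The move from **kwargs to an explicit worker_kwargs dict follows from how torch.multiprocessing.spawn works: it only forwards a positional args tuple to the child entry point, so keyword arguments have to be packed into a dict and unpacked on the other side. A standalone sketch of that pattern (names here are illustrative, not vLLM's):

import torch.multiprocessing as mp

def _entry(local_rank: int, worker, worker_kwargs: dict, *args) -> None:
    # spawn() prepends the process index; everything else must arrive positionally,
    # so keyword arguments ride along as a plain dict and are unpacked here.
    worker(local_rank, *args, **worker_kwargs)

def launch(worker, world_size: int, *args, **kwargs) -> None:
    # Pack kwargs into the positional args tuple, mirroring the packing shown below.
    mp.spawn(_entry, args=(worker, kwargs) + args, nprocs=world_size, join=True)

def _demo_worker(local_rank: int, base: int, scale: int = 1) -> None:
    print(f"rank {local_rank}: value={(base + local_rank) * scale}")

if __name__ == "__main__":
    launch(_demo_worker, 2, 10, scale=3)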
@@ -116,7 +130,6 @@ def parallel_launch_with_config(
     *args: P.args,
     **kwargs: P.kwargs,
 ) -> None:
-    assert not kwargs
     spawn(
         _worker_parallel_launch,
         args=(
@@ -127,6 +140,7 @@ def parallel_launch_with_config(
             worker,
             vllm_config,
             env_dict,
+            kwargs,
         )
         + args,
         nprocs=world_size,
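A hedged usage sketch of the net effect: keyword arguments can now be passed through parallel_launch_with_config to the worker instead of being rejected by the old `assert not kwargs`. The module path, the leading parameters (world size, worker, config, env dict), and the worker name are assumptions for illustration only:

from vllm.config import VllmConfig
# Module path assumed; the changed file's name is not shown in this view.
from tests.kernels.moe.parallel_utils import parallel_launch_with_config

def my_moe_worker(pgi, vllm_config, cpu_group, num_experts: int, use_fp8: bool = False) -> None:
    ...  # hypothetical test body; pgi is the ProcessGroupInfo built by the launcher

if __name__ == "__main__":
    parallel_launch_with_config(
        2,             # world size (assumed parameter order)
        my_moe_worker,
        VllmConfig(),  # default config, for illustration only
        {},            # extra environment variables
        32,            # num_experts, forwarded positionally via "+ args"
        use_fp8=True,  # forwarded through the packed kwargs dict (the new behavior)
    )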