[MoE Refactor][Test] FusedMoE layer test (#24675)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,11 @@ from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUs
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
from vllm.distributed import (
|
||||
cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
@@ -36,10 +40,17 @@ def _set_vllm_config(
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
|
||||
# When DP is enabled, processes are organized as:
|
||||
# rank = dp_rank * tp_pp_world_size + tp_pp_rank
|
||||
tp_pp_world_size = vllm_config.parallel_config.world_size
|
||||
vllm_config.parallel_config.data_parallel_rank = rank // tp_pp_world_size
|
||||
tp_pp_rank = rank % tp_pp_world_size
|
||||
vllm_config.parallel_config.rank = tp_pp_rank
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
world_size=tp_pp_world_size,
|
||||
rank=tp_pp_rank,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=local_rank,
|
||||
backend="nccl",
|
||||
@@ -59,11 +70,11 @@ def _worker_parallel_launch(
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig | None, Any, P], None],
|
||||
worker: Callable[..., None],
|
||||
vllm_config: VllmConfig | None,
|
||||
env_dict: dict | None,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
worker_kwargs: dict[str, Any],
|
||||
*args: Any,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.accelerator.set_device_index(local_rank)
|
||||
@@ -98,14 +109,17 @@ def _worker_parallel_launch(
|
||||
vllm_config,
|
||||
cpu_group,
|
||||
*args,
|
||||
**kwargs,
|
||||
**worker_kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
if vllm_config is not None:
|
||||
cleanup_dist_env_and_memory()
|
||||
else:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch_with_config(
|
||||
@@ -116,7 +130,6 @@ def parallel_launch_with_config(
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
@@ -127,6 +140,7 @@ def parallel_launch_with_config(
|
||||
worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
kwargs,
|
||||
)
|
||||
+ args,
|
||||
nprocs=world_size,
|
||||
|
||||
Reference in New Issue
Block a user