[1/N] Elastic EP Milestone 2 (#34861)
Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
@@ -895,6 +895,36 @@ def compare_all_settings(
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
def ensure_current_vllm_config():
    """Guarantee that a current vllm config exists while the context is open.

    When a config is already current, the context body runs against it and
    nothing is changed. When none is set, a default-constructed ``VllmConfig``
    is installed via ``set_current_vllm_config`` for the duration of the
    context.

    Intended for tests that call functions which require a vllm config to be
    set but do not care about its contents.

    Example:
        with ensure_current_vllm_config():
            init_distributed_environment(...)
            ensure_model_parallel_initialized(...)
    """
    # Imported inside the function, exactly as the original does.
    from vllm.config import (
        VllmConfig,
        get_current_vllm_config_or_none,
        set_current_vllm_config,
    )

    if get_current_vllm_config_or_none() is None:
        # Nothing is current: install a default config around the body.
        with set_current_vllm_config(VllmConfig()):
            yield
        return

    # A config is already current: run the body without touching it.
    yield
|
||||
|
||||
|
||||
def init_test_distributed_environment(
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
@@ -921,6 +951,7 @@ def init_test_distributed_environment(
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank,
|
||||
)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
else:
|
||||
# No config set, create a default one for the test
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
@@ -930,7 +961,7 @@ def init_test_distributed_environment(
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank,
|
||||
)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
|
||||
|
||||
def multi_process_parallel(
|
||||
|
||||
Reference in New Issue
Block a user