[Misc] Fix "Current vLLM config is not set." warnings; add an assert to avoid issues in the future (#31747)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
@@ -340,38 +341,42 @@ def async_tp_pass_on_test_model(
|
||||
)
|
||||
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
assert (
|
||||
async_tp_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
async_tp_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
# Set the global vllm_config for TestBackend which calls
|
||||
# get_current_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||
assert (
|
||||
async_tp_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
async_tp_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
|
||||
hidden_states = torch.randn(
|
||||
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||
)
|
||||
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||
|
||||
if dynamic:
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
hidden_states = torch.randn(
|
||||
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||
)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
if dynamic:
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
|
||||
assert async_tp_pass.matched_count == 1
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
# In pre-nodes, all gather or reduce scatter should exist,
|
||||
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
assert async_tp_pass.matched_count == 1
|
||||
|
||||
# In post-nodes, fused_matmul_reduce_scatter or \
|
||||
# fused_all_gather_matmul should exist
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
# In pre-nodes, all gather or reduce scatter should exist,
|
||||
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
|
||||
# In post-nodes, fused_matmul_reduce_scatter or \
|
||||
# fused_all_gather_matmul should exist
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
|
||||
@@ -430,7 +430,7 @@ def test_cudagraph_sizes_post_init(
|
||||
)
|
||||
|
||||
|
||||
def test_cached_compilation_config():
|
||||
def test_cached_compilation_config(default_vllm_config):
|
||||
import torch
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
|
||||
|
||||
@@ -189,6 +189,17 @@ def dist_init():
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.fixture
def default_vllm_config():
    """Install a default ``VllmConfig`` as the current config for one test.

    Intended for tests that exercise CustomOps (or other pathways that call
    ``get_current_vllm_config()``) directly, outside of a full engine context.
    The config stays active until the fixture's ``with`` block is exited.
    """
    # Imported lazily so the fixture module does not pull in vllm.config
    # until a test actually requests this fixture.
    from vllm.config import VllmConfig, set_current_vllm_config

    default_config = VllmConfig()
    with set_current_vllm_config(default_config):
        # Hand control to the test while the default config is active.
        yield
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def should_do_global_cleanup_after_test(request) -> bool:
|
||||
"""Allow subdirectories to skip global cleanup by overriding this fixture.
|
||||
|
||||
@@ -458,7 +458,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
)
|
||||
|
||||
|
||||
def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
|
||||
def test_trtllm_attention_rejects_num_kv_heads_1(default_vllm_config) -> None:
|
||||
"""Test that TRTLLM attention correctly rejects num_kv_heads=1.
|
||||
|
||||
When num_kv_heads=1 (MQA), the KV cache strides become degenerate
|
||||
|
||||
@@ -36,7 +36,7 @@ if current_platform.is_rocm():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
def test_mha_attn_platform(device: str):
|
||||
def test_mha_attn_platform(default_vllm_config, device: str):
|
||||
"""
|
||||
Test the attention selector between different platform and device.
|
||||
"""
|
||||
@@ -116,6 +116,7 @@ CUDA_DEVICES = ["cuda"]
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_mha_attn_forward(
|
||||
default_vllm_config,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
@@ -162,6 +163,7 @@ def test_mha_attn_forward(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_mha_attn_varlen_forward(
|
||||
default_vllm_config,
|
||||
var_seq_len: list[int],
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
|
||||
@@ -45,6 +45,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_act_and_mul(
|
||||
default_vllm_config,
|
||||
activation: str,
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
@@ -122,6 +123,7 @@ def test_act_and_mul(
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_activation(
|
||||
default_vllm_config,
|
||||
activation: type[torch.nn.Module],
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
|
||||
@@ -57,6 +57,7 @@ def _apply_qk_norm_rope(
|
||||
@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
|
||||
@torch.inference_mode()
|
||||
def test_fused_qk_norm_rope_matches_reference(
|
||||
default_vllm_config,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
is_neox: bool,
|
||||
|
||||
@@ -147,6 +147,7 @@ def ops_impl(
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rms_norm(
|
||||
default_vllm_config,
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
|
||||
@@ -26,6 +26,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
|
||||
@pytest.mark.parametrize("strided_input", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_rms_norm(
|
||||
default_vllm_config,
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
|
||||
@@ -90,6 +90,7 @@ num_tokens_list = [11, 8192]
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("num_tokens", num_tokens_list)
|
||||
def test_mrope(
|
||||
default_vllm_config,
|
||||
model_name: str,
|
||||
model_info: MRoPETestInfo,
|
||||
tp_size: int,
|
||||
@@ -159,6 +160,7 @@ def test_mrope(
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("num_tokens", num_tokens_list)
|
||||
def test_mrope_torch_compile_tracing(
|
||||
default_vllm_config,
|
||||
model_name: str,
|
||||
model_info: MRoPETestInfo,
|
||||
tp_size: int,
|
||||
|
||||
@@ -62,6 +62,7 @@ TENSORS_SHAPES_FN = [
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
default_vllm_config,
|
||||
is_neox_style: bool,
|
||||
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
|
||||
batch_size: int,
|
||||
@@ -123,7 +124,7 @@ def test_rotary_embedding(
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_rope_module_cache():
|
||||
def test_rope_module_cache(default_vllm_config):
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
ROPE_THETAS = [10000, 1000000]
|
||||
ROPE_PARAMETERS = (
|
||||
|
||||
@@ -36,6 +36,7 @@ def rotary_embedding_opcheck(
|
||||
@pytest.mark.parametrize("use_key", [True, False])
|
||||
@pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
|
||||
def test_rotary_embedding_opcheck(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
device,
|
||||
max_position,
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
|
||||
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
@@ -24,11 +24,6 @@ USE_BIAS = [True, False]
|
||||
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
DTYPE = [torch.bfloat16]
|
||||
|
||||
_CPU_MOE_ACT = {
|
||||
"silu": SiluAndMul(),
|
||||
"swigluoai": SwigluOAIAndMul(),
|
||||
}
|
||||
|
||||
|
||||
def ref_fused_moe(
|
||||
input: torch.Tensor,
|
||||
@@ -106,6 +101,7 @@ def ref_fused_moe(
|
||||
@pytest.mark.parametrize("act", ACT)
|
||||
@pytest.mark.parametrize("isa", ISA)
|
||||
def test_cpu_fused_moe(
|
||||
default_vllm_config,
|
||||
batch_size: int,
|
||||
expert_num: int,
|
||||
hidden_size: int,
|
||||
|
||||
@@ -468,7 +468,12 @@ def test_fused_moe_wn16(
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_mixtral_moe(
|
||||
dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
dtype: torch.dtype,
|
||||
padding: bool,
|
||||
use_rocm_aiter: bool,
|
||||
monkeypatch,
|
||||
):
|
||||
"""Make sure our Mixtral MoE implementation agrees with the one from
|
||||
huggingface."""
|
||||
|
||||
@@ -23,7 +23,12 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_functionality(
|
||||
batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
|
||||
default_vllm_config,
|
||||
batch_size: int,
|
||||
hidden_dim: int,
|
||||
group_size: int,
|
||||
seed: int,
|
||||
use_ue8m0: bool,
|
||||
) -> None:
|
||||
"""Test QuantFP8 group quantization with various configurations.
|
||||
|
||||
@@ -82,7 +87,9 @@ def test_quantfp8_group_functionality(
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
|
||||
def test_quantfp8_group_multidimensional(
|
||||
default_vllm_config, seed: int, use_ue8m0: bool
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
|
||||
group_size = 64
|
||||
@@ -135,7 +142,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
|
||||
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_edge_cases(seed: int) -> None:
|
||||
def test_quantfp8_group_edge_cases(default_vllm_config, seed: int) -> None:
|
||||
set_random_seed(seed)
|
||||
|
||||
batch_size = 16
|
||||
|
||||
@@ -102,7 +102,7 @@ SEEDS = [0]
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
def test_w8a8_fp8_fused_moe(default_vllm_config, M, N, K, E, topk, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
# Initialize int8 quantization parameters
|
||||
factor_for_scale = 1e-2
|
||||
|
||||
@@ -31,6 +31,7 @@ BLOCK_SIZE = 16
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_mul_nvfp4_quant(
|
||||
default_vllm_config,
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int],
|
||||
) -> None:
|
||||
|
||||
@@ -39,6 +39,7 @@ def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_and_mul(
|
||||
default_vllm_config,
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
|
||||
@@ -82,7 +82,7 @@ class DummyLoRAModel(nn.Sequential, SupportsLoRA):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_model() -> nn.Module:
|
||||
def dummy_model(default_vllm_config) -> nn.Module:
|
||||
model = DummyLoRAModel(
|
||||
OrderedDict(
|
||||
[
|
||||
@@ -115,7 +115,7 @@ def dummy_model() -> nn.Module:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_model_gate_up() -> nn.Module:
|
||||
def dummy_model_gate_up(default_vllm_config) -> nn.Module:
|
||||
model = DummyLoRAModel(
|
||||
OrderedDict(
|
||||
[
|
||||
|
||||
@@ -252,7 +252,9 @@ def check_punica_wrapper(punica_wrapper) -> bool:
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
def test_embeddings(
|
||||
default_vllm_config, dist_init, num_loras, device, vocab_size, stage
|
||||
) -> None:
|
||||
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
|
||||
# device, see: https://github.com/triton-lang/triton/issues/2925
|
||||
# Same below.
|
||||
@@ -353,7 +355,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_lm_head_logits_processor(
|
||||
dist_init, num_loras, device, vocab_size, stage
|
||||
default_vllm_config, dist_init, num_loras, device, vocab_size, stage
|
||||
) -> None:
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
@@ -470,6 +472,7 @@ def test_lm_head_logits_processor(
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_linear_replicated(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
num_loras,
|
||||
device,
|
||||
@@ -580,7 +583,7 @@ def test_linear_replicated(
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_linear_parallel(
|
||||
dist_init, num_loras, orientation, fully_shard, device, stage
|
||||
default_vllm_config, dist_init, num_loras, orientation, fully_shard, device, stage
|
||||
) -> None:
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
@@ -705,7 +708,7 @@ def test_linear_parallel(
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_column_parallel_packed(
|
||||
dist_init, num_loras, repeats, fully_shard, device, stage
|
||||
default_vllm_config, dist_init, num_loras, repeats, fully_shard, device, stage
|
||||
) -> None:
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
@@ -851,7 +854,7 @@ def test_column_parallel_packed(
|
||||
@pytest.mark.parametrize(
|
||||
"seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
|
||||
)
|
||||
def test_vocab_parallel_embedding_indices(tp_size, seed):
|
||||
def test_vocab_parallel_embedding_indices(tp_size, seed, default_vllm_config):
|
||||
random.seed(seed)
|
||||
vocab_size = random.randint(4000, 64000)
|
||||
added_vocab_size = random.randint(0, 1024)
|
||||
|
||||
@@ -111,7 +111,7 @@ def create_packed_lora(
|
||||
return LoRAModel(lora_id, 8, loras)
|
||||
|
||||
|
||||
def test_replace_submodules(dist_init, dummy_model):
|
||||
def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
|
||||
model = dummy_model
|
||||
manager = LoRAModelManager(
|
||||
model,
|
||||
@@ -133,7 +133,7 @@ def test_replace_submodules(dist_init, dummy_model):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lora_model_manager(dist_init, dummy_model, device):
|
||||
def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
|
||||
model = dummy_model
|
||||
model_lora1 = create_lora(
|
||||
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
|
||||
@@ -199,7 +199,9 @@ def test_lora_model_manager(dist_init, dummy_model, device):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
||||
def test_lora_lru_cache_model_manager(
|
||||
default_vllm_config, dist_init, dummy_model, device
|
||||
):
|
||||
model = dummy_model
|
||||
model_lora1 = create_lora(
|
||||
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
|
||||
@@ -289,7 +291,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
|
||||
# This tests just the LRU cache functionality, everything else is
|
||||
# tested in test_lora_model_manager
|
||||
model = dummy_model
|
||||
@@ -415,7 +417,9 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_path):
|
||||
def test_lru_cache_worker_adapter_manager(
|
||||
default_vllm_config, dist_init, dummy_model, device, tmp_path
|
||||
):
|
||||
lora_config = LoRAConfig(
|
||||
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
|
||||
)
|
||||
@@ -529,7 +533,9 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path):
|
||||
def test_worker_adapter_manager(
|
||||
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
|
||||
):
|
||||
# Should remove every LoRA not specified in the request.
|
||||
lora_config = LoRAConfig(
|
||||
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
|
||||
@@ -636,7 +642,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_packed_loras(dist_init, dummy_model_gate_up, device):
|
||||
def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, device):
|
||||
model = dummy_model_gate_up
|
||||
model_lora = create_packed_lora(
|
||||
1,
|
||||
|
||||
@@ -55,7 +55,7 @@ def test_get_draft_quant_config_without_draft_model():
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
def test_fc_layer_quant_config_usage(dist_init, device) -> None:
|
||||
def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) -> None:
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
@@ -73,7 +73,9 @@ def run_intern_vit_test(
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
|
||||
def test_models(
|
||||
default_vllm_config, dist_init, image_assets, model_id, dtype: str
|
||||
) -> None:
|
||||
run_intern_vit_test(
|
||||
image_assets,
|
||||
model_id,
|
||||
|
||||
@@ -92,7 +92,9 @@ def run_radio_test(
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
|
||||
def test_radio(dist_init, image_assets, model_id, dtype: str) -> None:
|
||||
def test_radio(
|
||||
default_vllm_config, dist_init, image_assets, model_id, dtype: str
|
||||
) -> None:
|
||||
run_radio_test(
|
||||
image_assets,
|
||||
model_id,
|
||||
|
||||
@@ -145,18 +145,18 @@ def initialize_dummy_model(
|
||||
model_config: ModelConfig,
|
||||
):
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
|
||||
current_device = torch.get_default_device()
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
with set_current_vllm_config(vllm_config=vllm_config):
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
torch.set_default_device(current_platform.device_type)
|
||||
model = model_cls(vllm_config=vllm_config)
|
||||
|
||||
@@ -31,7 +31,7 @@ def test_platform_plugins():
|
||||
)
|
||||
|
||||
|
||||
def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
|
||||
def test_oot_custom_op(default_vllm_config, monkeypatch: pytest.MonkeyPatch):
|
||||
# simulate workload by running an example
|
||||
load_general_plugins()
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
@@ -277,6 +277,7 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
# this is the case for marlin as well as per-tensor Fp8MoEMethod
|
||||
@pytest.mark.parametrize("use_marlin", [False]) # skip True
|
||||
def test_fp8_reloading(
|
||||
default_vllm_config,
|
||||
method_cls,
|
||||
is_checkpoint_fp8_serialized,
|
||||
weight_block_size,
|
||||
|
||||
@@ -721,13 +721,34 @@ def init_test_distributed_environment(
|
||||
distributed_init_port: str,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
init_distributed_environment(
|
||||
world_size=pp_size * tp_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank,
|
||||
# Note: This function is often called from Ray worker processes, so we
|
||||
# can't rely on pytest fixtures to set the config. We check if the config
|
||||
# is already set and only create a default one if needed.
|
||||
from vllm.config import (
|
||||
VllmConfig,
|
||||
get_current_vllm_config_or_none,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
|
||||
if get_current_vllm_config_or_none() is not None:
|
||||
# Config already set, use it directly
|
||||
init_distributed_environment(
|
||||
world_size=pp_size * tp_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank,
|
||||
)
|
||||
else:
|
||||
# No config set, create a default one for the test
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
init_distributed_environment(
|
||||
world_size=pp_size * tp_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank,
|
||||
)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
|
||||
|
||||
|
||||
@@ -556,7 +556,7 @@ def _test_backend_correctness(
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
def test_causal_backend_correctness(
|
||||
batch_spec_name: str, model: str, tensor_parallel_size: int
|
||||
default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int
|
||||
):
|
||||
"""Test backend's correctness with causal attention."""
|
||||
|
||||
|
||||
@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
|
||||
],
|
||||
)
|
||||
def test_mamba_layers_get_attn_backend(
|
||||
dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
layer_class,
|
||||
init_kwargs,
|
||||
expected_backend,
|
||||
expected_mamba_type,
|
||||
):
|
||||
"""Test that Mamba-like layers return the correct attention backend."""
|
||||
layer = layer_class(**init_kwargs)
|
||||
|
||||
@@ -394,7 +394,11 @@ def run_attention_backend(
|
||||
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
|
||||
def test_backend_correctness(
|
||||
dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
batch_spec_name: str,
|
||||
model: str,
|
||||
tensor_parallel_size: int,
|
||||
):
|
||||
"""
|
||||
Test that all backends produce similar outputs to a reference implementation
|
||||
|
||||
@@ -124,7 +124,12 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_sparse_backend_decode_correctness(
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
batch_name,
|
||||
kv_cache_dtype,
|
||||
tensor_parallel_size,
|
||||
workspace_init,
|
||||
):
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
|
||||
|
||||
@@ -21,7 +21,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
|
||||
def test_rms_norm_batch_invariant_vs_standard(
|
||||
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
|
||||
default_vllm_config,
|
||||
batch_size: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
eps: float,
|
||||
):
|
||||
"""
|
||||
Compare batch-invariant Triton RMS norm against standard CUDA implementation.
|
||||
@@ -68,7 +72,9 @@ def test_rms_norm_batch_invariant_vs_standard(
|
||||
@pytest.mark.parametrize("batch_size", [1, 16, 128])
|
||||
@pytest.mark.parametrize("seq_len", [1, 32, 512])
|
||||
@pytest.mark.parametrize("hidden_size", [2048, 4096])
|
||||
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
|
||||
def test_rms_norm_3d_input(
|
||||
default_vllm_config, batch_size: int, seq_len: int, hidden_size: int
|
||||
):
|
||||
"""
|
||||
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
|
||||
|
||||
@@ -107,7 +113,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
def test_rms_norm_numerical_stability():
|
||||
def test_rms_norm_numerical_stability(default_vllm_config):
|
||||
"""
|
||||
Test RMS norm numerical stability with extreme values.
|
||||
|
||||
@@ -167,7 +173,7 @@ def test_rms_norm_numerical_stability():
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
def test_rms_norm_formula():
|
||||
def test_rms_norm_formula(default_vllm_config):
|
||||
"""
|
||||
Test that RMS norm follows the correct mathematical formula.
|
||||
|
||||
@@ -201,7 +207,7 @@ def test_rms_norm_formula():
|
||||
|
||||
@skip_unsupported
|
||||
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
|
||||
def test_rms_norm_different_hidden_sizes(hidden_size: int):
|
||||
def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
|
||||
"""
|
||||
Test RMS norm with various hidden sizes to ensure block size handling.
|
||||
|
||||
@@ -238,7 +244,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
def test_rms_norm_determinism():
|
||||
def test_rms_norm_determinism(default_vllm_config):
|
||||
"""
|
||||
Test that batch-invariant RMS norm produces deterministic results.
|
||||
|
||||
|
||||
@@ -299,6 +299,7 @@ def test_prompt_less_than_block_size():
|
||||
)
|
||||
def test_kv_transfer_handshake(dist_init):
|
||||
"""Unit test for basic NixlConnector interface functionality."""
|
||||
from vllm.config import set_current_vllm_config
|
||||
|
||||
# Test setup: we create a scheduler that contains a NixlConnector
|
||||
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
|
||||
@@ -308,81 +309,82 @@ def test_kv_transfer_handshake(dist_init):
|
||||
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# Create two NixlConnector of role WORKER, one is the worker of
|
||||
# the scheduler (prefill), the other is a worker of decode instance.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Create two NixlConnector of role WORKER, one is the worker of
|
||||
# the scheduler (prefill), the other is a worker of decode instance.
|
||||
|
||||
# Prefill connector will register KV cache to populate proper handshake
|
||||
# metadata.
|
||||
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||
)
|
||||
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
kv_caches = {
|
||||
"layer0": shared_tensor,
|
||||
"layer1": unique_tensor,
|
||||
"layer2": shared_tensor,
|
||||
}
|
||||
prefill_connector.register_kv_caches(kv_caches)
|
||||
|
||||
# Simulate EngineCore initialization that would gather connector
|
||||
# metadata from all workers
|
||||
metadata = prefill_connector.get_handshake_metadata()
|
||||
|
||||
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
|
||||
|
||||
# The scheduler connector expects metadata to be in
|
||||
# dict[int, KVConnectorHandshakeMetadata], where the first key is
|
||||
# the dp_rank, the second key is the tp_rank.
|
||||
scheduler_connector = scheduler.get_kv_connector()
|
||||
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
|
||||
|
||||
# Simulate a request that finishes prefill, which returns
|
||||
# corresponding NixlConnectorMetadata for decode instance.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True,
|
||||
)
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
|
||||
request, [0, 1, 2]
|
||||
)
|
||||
assert delay
|
||||
|
||||
# Decode connector will be able to create handshake with the prefill connector.
|
||||
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
|
||||
# Here we are testing the retrieval of NIXLAgentMetadata.
|
||||
# Knowing the implementation detail, we override the add_remote_agent
|
||||
# to validate the metadata received is the same as the one in prefill_connector.
|
||||
with patch.object(
|
||||
decode_connector.connector_worker, "add_remote_agent"
|
||||
) as mock_add_remote_agent:
|
||||
mock_add_remote_agent.return_type = "remote_agent"
|
||||
|
||||
decode_connector.connector_worker._nixl_handshake(
|
||||
kv_connector_metadata["remote_host"],
|
||||
kv_connector_metadata["remote_port"],
|
||||
kv_connector_metadata["tp_size"],
|
||||
kv_connector_metadata["remote_engine_id"],
|
||||
# Prefill connector will register KV cache to populate proper handshake
|
||||
# metadata.
|
||||
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||
)
|
||||
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
kv_caches = {
|
||||
"layer0": shared_tensor,
|
||||
"layer1": unique_tensor,
|
||||
"layer2": shared_tensor,
|
||||
}
|
||||
prefill_connector.register_kv_caches(kv_caches)
|
||||
|
||||
received_metadata = mock_add_remote_agent.call_args.args
|
||||
assert received_metadata[0] == expected_agent_metadata
|
||||
assert received_metadata[1] == 0 # remote_tp_rank
|
||||
assert received_metadata[2] == 1 # remote_tp_size
|
||||
# Simulate EngineCore initialization that would gather connector
|
||||
# metadata from all workers
|
||||
metadata = prefill_connector.get_handshake_metadata()
|
||||
|
||||
# Need to shutdown the background thread to release NIXL side channel port
|
||||
scheduler_connector.shutdown()
|
||||
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
|
||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
|
||||
|
||||
# The scheduler connector expects metadata to be in
|
||||
# dict[int, KVConnectorHandshakeMetadata], where the first key is
|
||||
# the dp_rank, the second key is the tp_rank.
|
||||
scheduler_connector = scheduler.get_kv_connector()
|
||||
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
|
||||
|
||||
# Simulate a request that finishes prefill, which returns
|
||||
# corresponding NixlConnectorMetadata for decode instance.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True,
|
||||
)
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
|
||||
request, [0, 1, 2]
|
||||
)
|
||||
assert delay
|
||||
|
||||
# Decode connector will be able to create handshake with the prefill connector.
|
||||
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
|
||||
# Here we are testing the retrieval of NIXLAgentMetadata.
|
||||
# Knowing the implementation detail, we override the add_remote_agent
|
||||
# to validate the metadata received is the same as the one in prefill_connector.
|
||||
with patch.object(
|
||||
decode_connector.connector_worker, "add_remote_agent"
|
||||
) as mock_add_remote_agent:
|
||||
mock_add_remote_agent.return_type = "remote_agent"
|
||||
|
||||
decode_connector.connector_worker._nixl_handshake(
|
||||
kv_connector_metadata["remote_host"],
|
||||
kv_connector_metadata["remote_port"],
|
||||
kv_connector_metadata["tp_size"],
|
||||
kv_connector_metadata["remote_engine_id"],
|
||||
)
|
||||
|
||||
received_metadata = mock_add_remote_agent.call_args.args
|
||||
assert received_metadata[0] == expected_agent_metadata
|
||||
assert received_metadata[1] == 0 # remote_tp_rank
|
||||
assert received_metadata[2] == 1 # remote_tp_size
|
||||
|
||||
# Need to shutdown the background thread to release NIXL side channel port
|
||||
scheduler_connector.shutdown()
|
||||
|
||||
|
||||
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
@@ -458,6 +460,7 @@ class TestNixlHandshake:
|
||||
)
|
||||
def test_multi_xfer_one_engine(
|
||||
self,
|
||||
default_vllm_config,
|
||||
# dist_init is a fixture that initializes the distributed environment.
|
||||
dist_init,
|
||||
):
|
||||
@@ -547,6 +550,7 @@ class TestNixlHandshake:
|
||||
)
|
||||
def test_async_load_kv(
|
||||
self,
|
||||
default_vllm_config,
|
||||
# Fixture that initializes the distributed environment.
|
||||
dist_init,
|
||||
# Simulate consumer-producer TP sizes.
|
||||
@@ -605,7 +609,7 @@ class TestNixlHandshake:
|
||||
)
|
||||
@pytest.mark.parametrize("local_tp_size", [1, 2])
|
||||
def test_prefill_tp_size_greater_than_decode_tp_size(
|
||||
self, local_tp_size: int, dist_init
|
||||
self, local_tp_size: int, default_vllm_config, dist_init
|
||||
):
|
||||
"""
|
||||
Verify remote TP > local TP handshake succeeds with different
|
||||
@@ -670,7 +674,7 @@ class TestNixlHandshake:
|
||||
)
|
||||
@pytest.mark.parametrize("local_tp_size", [1, 2])
|
||||
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
|
||||
self, local_tp_size: int, dist_init
|
||||
self, local_tp_size: int, default_vllm_config, dist_init
|
||||
):
|
||||
"""
|
||||
Verify remote TP > local TP handshake succeeds with different
|
||||
@@ -770,6 +774,7 @@ class TestNixlHandshake:
|
||||
)
|
||||
def test_concurrent_load_kv(
|
||||
self,
|
||||
default_vllm_config,
|
||||
# dist_init is a fixture that initializes the distributed environment.
|
||||
dist_init,
|
||||
):
|
||||
@@ -830,7 +835,9 @@ class TestNixlHandshake:
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
|
||||
def test_handshake_fails_on_kv_cache_layout_mismatch(
|
||||
self, default_vllm_config, dist_init
|
||||
):
|
||||
"""
|
||||
Verify that adding a remote agent fails if kv_cache_layout differs.
|
||||
This test is only relevant for heterogeneous TP.
|
||||
@@ -879,7 +886,7 @@ class TestNixlHandshake:
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
|
||||
self, dist_init
|
||||
self, default_vllm_config, dist_init
|
||||
):
|
||||
"""
|
||||
Verify that adding a remote agent fails if kv_cache_layout differs.
|
||||
@@ -934,7 +941,7 @@ class TestNixlHandshake:
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_kv_connector_stats(dist_init):
|
||||
def test_kv_connector_stats(default_vllm_config, dist_init):
|
||||
"""Test that KV transfer stats are properly recorded and retrieved."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
@@ -1357,7 +1364,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
|
||||
"TRITON_ATTN",
|
||||
],
|
||||
)
|
||||
def test_register_kv_caches(dist_init, attn_backend):
|
||||
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
|
||||
"""
|
||||
Test that register_kv_caches() properly calls nixl_wrapper methods with
|
||||
correct data.
|
||||
@@ -1518,7 +1525,9 @@ class FakePlatform(Platform):
|
||||
("oot", "VRAM"),
|
||||
],
|
||||
)
|
||||
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type):
|
||||
def test_kv_buffer_to_nixl_memory_types(
|
||||
default_vllm_config, dist_init, kv_buffer_device, nixl_memory_type
|
||||
):
|
||||
"""
|
||||
Test that register_kv_caches() passes the correct memory types from the
|
||||
config to the nixl_wrapper.
|
||||
@@ -1563,7 +1572,7 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_shutdown_cleans_up_resources(dist_init):
|
||||
def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
|
||||
"""Test that shutdown() properly cleans up all resources."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
@@ -1622,7 +1631,7 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_aborted_request_removed_from_worker_in_batch(dist_init):
|
||||
def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init):
|
||||
"""
|
||||
Create and schedule a request so that P adds it to in-batch tracking via
|
||||
the real scheduler, then simulate an abort (request not in next scheduler
|
||||
@@ -1731,7 +1740,7 @@ class FailingNixlWrapper(FakeNixlWrapper):
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FailingNixlWrapper,
|
||||
)
|
||||
def test_handshake_failure_returns_finished(dist_init):
|
||||
def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
|
||||
"""Test that handshake failures mark blocks invalid and return via get_finished."""
|
||||
vllm_config = create_vllm_config()
|
||||
|
||||
@@ -1780,7 +1789,7 @@ def test_handshake_failure_returns_finished(dist_init):
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FailingNixlWrapper,
|
||||
)
|
||||
def test_transfer_setup_failure_returns_finished(dist_init):
|
||||
def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init):
|
||||
"""Test that transfer setup failures mark blocks invalid
|
||||
and return via get_finished."""
|
||||
vllm_config = create_vllm_config()
|
||||
@@ -1855,6 +1864,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_compatibility_hash_validation(
|
||||
default_vllm_config,
|
||||
dist_init,
|
||||
mismatch_type,
|
||||
config_overrides,
|
||||
@@ -1967,7 +1977,7 @@ def test_compatibility_hash_validation(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_handshake_decode_errors(dist_init, error_scenario):
|
||||
def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario):
|
||||
"""
|
||||
Test that msgspec decode errors are properly handled during handshake.
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ NUM_MAPPINGS = [3]
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_transfer(
|
||||
default_vllm_config,
|
||||
gpu_to_cpu: bool,
|
||||
num_mappings: int,
|
||||
head_size: int,
|
||||
|
||||
@@ -112,15 +112,16 @@ def get_vllm_config():
|
||||
@pytest.fixture
|
||||
def model_runner():
|
||||
vllm_config = get_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
|
||||
num_heads, head_size, 0.1
|
||||
)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
initialize_kv_cache(runner)
|
||||
return runner
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model_config = vllm_config.model_config
|
||||
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
|
||||
num_heads, head_size, 0.1
|
||||
)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
initialize_kv_cache(runner)
|
||||
yield runner
|
||||
|
||||
|
||||
model_runner_2 = model_runner
|
||||
@@ -546,7 +547,7 @@ def test_reload_weights_before_load_model(model_runner):
|
||||
model_runner.reload_weights()
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
|
||||
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
@@ -573,7 +574,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
|
||||
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_config):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
@@ -600,7 +601,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
|
||||
def test_init_kv_cache_with_kv_sharing_target_same_as_current(default_vllm_config):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
@@ -627,7 +628,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
|
||||
assert fwd_context is not None
|
||||
|
||||
|
||||
def test_init_kv_cache_without_kv_sharing():
|
||||
def test_init_kv_cache_without_kv_sharing(default_vllm_config):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
@@ -694,7 +695,7 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_valid():
|
||||
def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
@@ -1047,7 +1048,7 @@ def test_input_batch_with_kernel_block_sizes():
|
||||
assert block_table.block_size == kernel_size
|
||||
|
||||
|
||||
def test_hybrid_cache_integration(model_runner, dist_init):
|
||||
def test_hybrid_cache_integration(default_vllm_config, dist_init):
|
||||
"""Test hybrid cache architecture integration with GPUModelRunner."""
|
||||
# Create a new model runner with hybrid cache configuration
|
||||
vllm_config = get_vllm_config()
|
||||
|
||||
@@ -6,14 +6,14 @@ import torch
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def test_bind_kv_cache():
|
||||
def test_bind_kv_cache(default_vllm_config):
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
ctx = {
|
||||
"layers.0.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.1.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.2.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.3.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.0.self_attn": Attention(32, 128, 0.1, prefix="layers.0.self_attn"),
|
||||
"layers.1.self_attn": Attention(32, 128, 0.1, prefix="layers.1.self_attn"),
|
||||
"layers.2.self_attn": Attention(32, 128, 0.1, prefix="layers.2.self_attn"),
|
||||
"layers.3.self_attn": Attention(32, 128, 0.1, prefix="layers.3.self_attn"),
|
||||
}
|
||||
kv_cache = {
|
||||
"layers.0.self_attn": torch.zeros((1,)),
|
||||
@@ -34,13 +34,13 @@ def test_bind_kv_cache():
|
||||
assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"]
|
||||
|
||||
|
||||
def test_bind_kv_cache_non_attention():
|
||||
def test_bind_kv_cache_non_attention(default_vllm_config):
|
||||
from vllm.attention.layer import Attention
|
||||
|
||||
# example from Jamba PP=2
|
||||
ctx = {
|
||||
"model.layers.20.attn": Attention(32, 128, 0.1),
|
||||
"model.layers.28.attn": Attention(32, 128, 0.1),
|
||||
"model.layers.20.attn": Attention(32, 128, 0.1, prefix="model.layers.20.attn"),
|
||||
"model.layers.28.attn": Attention(32, 128, 0.1, prefix="model.layers.28.attn"),
|
||||
}
|
||||
kv_cache = {
|
||||
"model.layers.20.attn": torch.zeros((1,)),
|
||||
|
||||
@@ -59,10 +59,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
)
|
||||
|
||||
# 2. override if passed by environment or config
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.attention_config.flash_attn_version is not None:
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
vllm_config is not None
|
||||
and vllm_config.attention_config.flash_attn_version is not None
|
||||
):
|
||||
fa_version = vllm_config.attention_config.flash_attn_version
|
||||
|
||||
# 3. fallback for unsupported combinations
|
||||
|
||||
@@ -42,6 +42,7 @@ from vllm.config.vllm import (
|
||||
VllmConfig,
|
||||
get_cached_compilation_config,
|
||||
get_current_vllm_config,
|
||||
get_current_vllm_config_or_none,
|
||||
get_layers_from_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
@@ -105,6 +106,7 @@ __all__ = [
|
||||
"VllmConfig",
|
||||
"get_cached_compilation_config",
|
||||
"get_current_vllm_config",
|
||||
"get_current_vllm_config_or_none",
|
||||
"set_current_vllm_config",
|
||||
"get_layers_from_vllm_config",
|
||||
]
|
||||
|
||||
@@ -1441,13 +1441,18 @@ def get_cached_compilation_config():
|
||||
|
||||
def get_current_vllm_config() -> VllmConfig:
|
||||
if _current_vllm_config is None:
|
||||
# in ci, usually when we test custom ops/modules directly,
|
||||
# we don't set the vllm config. In that case, we set a default
|
||||
# config.
|
||||
# Use stack level 2 so the log contains the line of the caller,
|
||||
# so it's easier to track down the source of the warning.
|
||||
logger.warning("Current vLLM config is not set.", stacklevel=2)
|
||||
return VllmConfig()
|
||||
raise AssertionError(
|
||||
"Current vLLM config is not set. This typically means "
|
||||
"get_current_vllm_config() was called outside of a "
|
||||
"set_current_vllm_config() context, or a CustomOp was instantiated "
|
||||
"at module import time or model forward time when config is not set. "
|
||||
"For tests that directly test custom ops/modules, use the "
|
||||
"'default_vllm_config' pytest fixture from tests/conftest.py."
|
||||
)
|
||||
return _current_vllm_config
|
||||
|
||||
|
||||
def get_current_vllm_config_or_none() -> VllmConfig | None:
|
||||
return _current_vllm_config
|
||||
|
||||
|
||||
|
||||
@@ -117,9 +117,9 @@ class DeviceCommunicatorBase:
|
||||
|
||||
use_ep = False
|
||||
all2all_backend = None
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config()
|
||||
config = get_current_vllm_config_or_none()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
# where all data parallel ranks execute forward together),
|
||||
|
||||
@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -184,7 +184,7 @@ class QuickAllReduce:
|
||||
)
|
||||
return
|
||||
self.qr_quant_level = QuickReduceRegime[regime_str]
|
||||
vllm_config = get_current_vllm_config()
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
vllm_config is not None
|
||||
and hasattr(vllm_config, "model_config")
|
||||
|
||||
@@ -1177,9 +1177,9 @@ def init_distributed_environment(
|
||||
distributed_init_method,
|
||||
backend,
|
||||
)
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config()
|
||||
config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
config is not None
|
||||
and config.parallel_config.distributed_executor_backend != "external_launcher"
|
||||
@@ -1251,7 +1251,7 @@ def init_distributed_environment(
|
||||
if _WORLD is None:
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_WORLD = init_world_group(ranks, local_rank, backend)
|
||||
if config.parallel_config.nnodes > 1:
|
||||
if config is not None and config.parallel_config.nnodes > 1:
|
||||
_NODE_COUNT = config.parallel_config.nnodes
|
||||
else:
|
||||
_NODE_COUNT = _node_count(_WORLD.cpu_group)
|
||||
@@ -1260,7 +1260,7 @@ def init_distributed_environment(
|
||||
assert _WORLD.world_size == torch.distributed.get_world_size(), (
|
||||
"world group already initialized with a different world size"
|
||||
)
|
||||
if config.parallel_config.nnodes_within_dp > 1:
|
||||
if config is not None and config.parallel_config.nnodes_within_dp > 1:
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
world_size_inner_dp = parallel_config.world_size
|
||||
group_ranks = [
|
||||
@@ -1316,9 +1316,9 @@ def initialize_model_parallel(
|
||||
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
|
||||
|
||||
data_parallel_size = 1
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config()
|
||||
config = get_current_vllm_config_or_none()
|
||||
if config is not None:
|
||||
data_parallel_size = config.parallel_config.data_parallel_size
|
||||
|
||||
|
||||
@@ -13,10 +13,28 @@ from vllm.model_executor.layers.quantization.utils.layer_utils import replace_pa
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
_CPU_MOE_LAYER_CACHE = {}
|
||||
_CPU_MOE_ACT = {
|
||||
"silu": SiluAndMul(),
|
||||
"swigluoai": SwigluOAIAndMul(),
|
||||
}
|
||||
|
||||
|
||||
class _LazyActivationDict(dict):
|
||||
"""Lazily instantiate activation functions on first access.
|
||||
|
||||
Avoids triggering CustomOp.__init__() at module import time,
|
||||
which would call get_current_vllm_config() before config is set.
|
||||
"""
|
||||
|
||||
_factories: dict[str, type[SiluAndMul] | type[SwigluOAIAndMul]] = {
|
||||
"silu": SiluAndMul,
|
||||
"swigluoai": SwigluOAIAndMul,
|
||||
}
|
||||
|
||||
def __missing__(self, key: str) -> SiluAndMul | SwigluOAIAndMul:
|
||||
if key not in self._factories:
|
||||
raise KeyError(f"{key} is not a supported activation")
|
||||
self[key] = self._factories[key]()
|
||||
return self[key]
|
||||
|
||||
|
||||
_CPU_MOE_ACT = _LazyActivationDict()
|
||||
|
||||
|
||||
def grouped_topk(
|
||||
@@ -212,7 +230,7 @@ class CPUFusedMOE:
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
|
||||
assert activation in _CPU_MOE_ACT._factories, f"{activation} is not supported."
|
||||
assert not apply_router_weight_on_input
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
|
||||
@@ -540,6 +540,20 @@ class FusedMoE(CustomOp):
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
self.activation = activation
|
||||
|
||||
self._grouped_topk_impl: GroupedTopk | None = None
|
||||
if self.use_grouped_topk:
|
||||
assert self.num_expert_group is not None
|
||||
assert self.topk_group is not None
|
||||
self._grouped_topk_impl = GroupedTopk(
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=self.num_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
)
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError(
|
||||
"Only softmax scoring function is supported for non-grouped topk."
|
||||
@@ -1588,19 +1602,8 @@ class FusedMoE(CustomOp):
|
||||
|
||||
# DeepSeekv2 uses grouped_top_k
|
||||
elif self.use_grouped_topk and valid_grouping():
|
||||
assert self.topk_group is not None
|
||||
assert self.num_expert_group is not None
|
||||
grouped_topk_impl = GroupedTopk(
|
||||
topk=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
num_expert_group=self.num_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = grouped_topk_impl(
|
||||
assert self._grouped_topk_impl is not None
|
||||
topk_weights, topk_ids = self._grouped_topk_impl(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
|
||||
@@ -339,15 +339,11 @@ def apply_rotary_pos_emb_flashatt(
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
apply_rotary_emb: ApplyRotaryEmb,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
|
||||
apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
q_embed = apply_rotary_emb(q, cos, sin)
|
||||
k_embed = apply_rotary_emb(k, cos, sin)
|
||||
|
||||
@@ -410,6 +406,11 @@ class KeyeSiglipAttention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -448,7 +449,7 @@ class KeyeSiglipAttention(nn.Module):
|
||||
self.num_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
|
||||
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin, self.apply_rotary_emb)
|
||||
v = v.view(
|
||||
*v.shape[:-1],
|
||||
self.num_kv_heads,
|
||||
|
||||
@@ -152,16 +152,12 @@ def apply_rotary_pos_emb(
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_flash_attn_backend: bool = False,
|
||||
is_flash_attn_backend: bool,
|
||||
apply_rotary_emb: ApplyRotaryEmb,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
|
||||
apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
if is_flash_attn_backend and current_platform.is_cuda():
|
||||
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
|
||||
elif is_flash_attn_backend and current_platform.is_rocm():
|
||||
@@ -235,6 +231,11 @@ class Siglip2Attention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -260,6 +261,7 @@ class Siglip2Attention(nn.Module):
|
||||
cos,
|
||||
sin,
|
||||
self.attn.is_flash_attn_backend,
|
||||
self.apply_rotary_emb,
|
||||
)
|
||||
queries = queries.squeeze(0)
|
||||
keys = keys.squeeze(0)
|
||||
|
||||
@@ -14,7 +14,7 @@ import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.distributed import (
|
||||
ensure_model_parallel_initialized,
|
||||
@@ -268,7 +268,9 @@ class Worker(WorkerBase):
|
||||
# to hijack tensor allocation.
|
||||
def load_model(self) -> None:
|
||||
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
|
||||
with self._maybe_get_memory_pool_context(tag="weights"):
|
||||
with self._maybe_get_memory_pool_context(
|
||||
tag="weights"
|
||||
) and set_current_vllm_config(self.vllm_config):
|
||||
self.model_runner.load_model(eep_scale_up=eep_scale_up)
|
||||
|
||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||
|
||||
Reference in New Issue
Block a user