[Misc] Fix "Current vLLM config is not set." warnings; assert to avoid issues in the future (#31747)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Lucas Wilkinson
2026-01-08 18:20:49 -05:00
committed by GitHub
parent 5d3b6097ad
commit 6cdf015c3c
48 changed files with 380 additions and 240 deletions

View File

@@ -15,6 +15,7 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.distributed import (
tensor_model_parallel_all_gather,
@@ -340,38 +341,42 @@ def async_tp_pass_on_test_model(
)
async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)
assert (
async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
# Set the global vllm_config for TestBackend which calls
# get_current_vllm_config()
with set_current_vllm_config(vllm_config):
backend = TestBackend(async_tp_pass)
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
assert (
async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
hidden_states = torch.randn(
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
)
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
hidden_states = torch.randn(
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
assert async_tp_pass.matched_count == 1
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
assert async_tp_pass.matched_count == 1
# In post-nodes, fused_matmul_reduce_scatter or \
# fused_all_gather_matmul should exist
backend.check_after_ops(model.ops_in_model_after())
# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
# In post-nodes, fused_matmul_reduce_scatter or \
# fused_all_gather_matmul should exist
backend.check_after_ops(model.ops_in_model_after())
@create_new_process_for_each_test()

View File

@@ -430,7 +430,7 @@ def test_cudagraph_sizes_post_init(
)
def test_cached_compilation_config():
def test_cached_compilation_config(default_vllm_config):
import torch
from torch._inductor.utils import run_and_get_code

View File

@@ -189,6 +189,17 @@ def dist_init():
cleanup_dist_env_and_memory()
@pytest.fixture
def default_vllm_config():
    """Install a default VllmConfig as the current global config.

    Intended for tests that exercise CustomOps or other code paths that
    call get_current_vllm_config() outside of a full engine context, so
    they do not trip the "Current vLLM config is not set." assertion.
    """
    from vllm.config import VllmConfig, set_current_vllm_config

    default_config = VllmConfig()
    with set_current_vllm_config(default_config):
        yield
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.

View File

@@ -458,7 +458,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
)
def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
def test_trtllm_attention_rejects_num_kv_heads_1(default_vllm_config) -> None:
"""Test that TRTLLM attention correctly rejects num_kv_heads=1.
When num_kv_heads=1 (MQA), the KV cache strides become degenerate

View File

@@ -36,7 +36,7 @@ if current_platform.is_rocm():
@pytest.mark.parametrize("device", devices)
def test_mha_attn_platform(device: str):
def test_mha_attn_platform(default_vllm_config, device: str):
"""
Test the attention selector between different platform and device.
"""
@@ -116,6 +116,7 @@ CUDA_DEVICES = ["cuda"]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
default_vllm_config,
batch_size: int,
seq_len: int,
num_heads: int,
@@ -162,6 +163,7 @@ def test_mha_attn_forward(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_varlen_forward(
default_vllm_config,
var_seq_len: list[int],
num_heads: int,
num_kv_heads: int,

View File

@@ -45,6 +45,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
default_vllm_config,
activation: str,
num_tokens: int,
d: int,
@@ -122,6 +123,7 @@ def test_act_and_mul(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
default_vllm_config,
activation: type[torch.nn.Module],
num_tokens: int,
d: int,

View File

@@ -57,6 +57,7 @@ def _apply_qk_norm_rope(
@pytest.mark.parametrize("rotary_ratio", [1.0, 0.5, 0.25])
@torch.inference_mode()
def test_fused_qk_norm_rope_matches_reference(
default_vllm_config,
device: str,
dtype: torch.dtype,
is_neox: bool,

View File

@@ -147,6 +147,7 @@ def ops_impl(
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rms_norm(
default_vllm_config,
num_tokens: int,
hidden_size: int,
add_residual: bool,

View File

@@ -26,6 +26,7 @@ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 e
@pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode()
def test_rms_norm(
default_vllm_config,
num_tokens: int,
hidden_size: int,
add_residual: bool,

View File

@@ -90,6 +90,7 @@ num_tokens_list = [11, 8192]
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(
default_vllm_config,
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,
@@ -159,6 +160,7 @@ def test_mrope(
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope_torch_compile_tracing(
default_vllm_config,
model_name: str,
model_info: MRoPETestInfo,
tp_size: int,

View File

@@ -62,6 +62,7 @@ TENSORS_SHAPES_FN = [
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_rotary_embedding(
default_vllm_config,
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
batch_size: int,
@@ -123,7 +124,7 @@ def test_rotary_embedding(
@torch.inference_mode()
def test_rope_module_cache():
def test_rope_module_cache(default_vllm_config):
MAX_POSITIONS = [123, 1234]
ROPE_THETAS = [10000, 1000000]
ROPE_PARAMETERS = (

View File

@@ -36,6 +36,7 @@ def rotary_embedding_opcheck(
@pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
def test_rotary_embedding_opcheck(
default_vllm_config,
dist_init,
device,
max_position,

View File

@@ -6,7 +6,7 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@@ -24,11 +24,6 @@ USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
_CPU_MOE_ACT = {
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def ref_fused_moe(
input: torch.Tensor,
@@ -106,6 +101,7 @@ def ref_fused_moe(
@pytest.mark.parametrize("act", ACT)
@pytest.mark.parametrize("isa", ISA)
def test_cpu_fused_moe(
default_vllm_config,
batch_size: int,
expert_num: int,
hidden_size: int,

View File

@@ -468,7 +468,12 @@ def test_fused_moe_wn16(
)
@torch.inference_mode()
def test_mixtral_moe(
dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
default_vllm_config,
dist_init,
dtype: torch.dtype,
padding: bool,
use_rocm_aiter: bool,
monkeypatch,
):
"""Make sure our Mixtral MoE implementation agrees with the one from
huggingface."""

View File

@@ -23,7 +23,12 @@ from vllm.utils.torch_utils import set_random_seed
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(
batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
default_vllm_config,
batch_size: int,
hidden_dim: int,
group_size: int,
seed: int,
use_ue8m0: bool,
) -> None:
"""Test QuantFP8 group quantization with various configurations.
@@ -82,7 +87,9 @@ def test_quantfp8_group_functionality(
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
def test_quantfp8_group_multidimensional(
default_vllm_config, seed: int, use_ue8m0: bool
) -> None:
set_random_seed(seed)
group_size = 64
@@ -135,7 +142,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
def test_quantfp8_group_edge_cases(default_vllm_config, seed: int) -> None:
set_random_seed(seed)
batch_size = 16

View File

@@ -102,7 +102,7 @@ SEEDS = [0]
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
def test_w8a8_fp8_fused_moe(default_vllm_config, M, N, K, E, topk, dtype, seed):
torch.manual_seed(seed)
# Initialize int8 quantization parameters
factor_for_scale = 1e-2

View File

@@ -31,6 +31,7 @@ BLOCK_SIZE = 16
@pytest.mark.parametrize("shape", SHAPES)
@torch.inference_mode()
def test_silu_mul_nvfp4_quant(
default_vllm_config,
dtype: torch.dtype,
shape: tuple[int, int],
) -> None:

View File

@@ -39,6 +39,7 @@ def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
default_vllm_config,
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,

View File

@@ -82,7 +82,7 @@ class DummyLoRAModel(nn.Sequential, SupportsLoRA):
@pytest.fixture
def dummy_model() -> nn.Module:
def dummy_model(default_vllm_config) -> nn.Module:
model = DummyLoRAModel(
OrderedDict(
[
@@ -115,7 +115,7 @@ def dummy_model() -> nn.Module:
@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
def dummy_model_gate_up(default_vllm_config) -> nn.Module:
model = DummyLoRAModel(
OrderedDict(
[

View File

@@ -252,7 +252,9 @@ def check_punica_wrapper(punica_wrapper) -> bool:
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
@pytest.mark.parametrize("stage", STAGES)
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
def test_embeddings(
default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
# device, see: https://github.com/triton-lang/triton/issues/2925
# Same below.
@@ -353,7 +355,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor(
dist_init, num_loras, device, vocab_size, stage
default_vllm_config, dist_init, num_loras, device, vocab_size, stage
) -> None:
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
@@ -470,6 +472,7 @@ def test_lm_head_logits_processor(
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(
default_vllm_config,
dist_init,
num_loras,
device,
@@ -580,7 +583,7 @@ def test_linear_replicated(
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_parallel(
dist_init, num_loras, orientation, fully_shard, device, stage
default_vllm_config, dist_init, num_loras, orientation, fully_shard, device, stage
) -> None:
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
@@ -705,7 +708,7 @@ def test_linear_parallel(
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_column_parallel_packed(
dist_init, num_loras, repeats, fully_shard, device, stage
default_vllm_config, dist_init, num_loras, repeats, fully_shard, device, stage
) -> None:
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)
@@ -851,7 +854,7 @@ def test_column_parallel_packed(
@pytest.mark.parametrize(
"seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))
)
def test_vocab_parallel_embedding_indices(tp_size, seed):
def test_vocab_parallel_embedding_indices(tp_size, seed, default_vllm_config):
random.seed(seed)
vocab_size = random.randint(4000, 64000)
added_vocab_size = random.randint(0, 1024)

View File

@@ -111,7 +111,7 @@ def create_packed_lora(
return LoRAModel(lora_id, 8, loras)
def test_replace_submodules(dist_init, dummy_model):
def test_replace_submodules(default_vllm_config, dist_init, dummy_model):
model = dummy_model
manager = LoRAModelManager(
model,
@@ -133,7 +133,7 @@ def test_replace_submodules(dist_init, dummy_model):
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
model = dummy_model
model_lora1 = create_lora(
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
@@ -199,7 +199,9 @@ def test_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
def test_lora_lru_cache_model_manager(
default_vllm_config, dist_init, dummy_model, device
):
model = dummy_model
model_lora1 = create_lora(
1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
@@ -289,7 +291,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
def test_lru_lora_model_manager(default_vllm_config, dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
model = dummy_model
@@ -415,7 +417,9 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_path):
def test_lru_cache_worker_adapter_manager(
default_vllm_config, dist_init, dummy_model, device, tmp_path
):
lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
)
@@ -529,7 +533,9 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path):
def test_worker_adapter_manager(
default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(
max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
@@ -636,7 +642,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model_lora = create_packed_lora(
1,

View File

@@ -55,7 +55,7 @@ def test_get_draft_quant_config_without_draft_model():
@torch.inference_mode()
@pytest.mark.parametrize("device", DEVICES)
def test_fc_layer_quant_config_usage(dist_init, device) -> None:
def test_fc_layer_quant_config_usage(default_vllm_config, dist_init, device) -> None:
import torch
from vllm.model_executor.layers.linear import ReplicatedLinear

View File

@@ -73,7 +73,9 @@ def run_intern_vit_test(
],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
def test_models(
default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
run_intern_vit_test(
image_assets,
model_id,

View File

@@ -92,7 +92,9 @@ def run_radio_test(
],
)
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
def test_radio(dist_init, image_assets, model_id, dtype: str) -> None:
def test_radio(
default_vllm_config, dist_init, image_assets, model_id, dtype: str
) -> None:
run_radio_test(
image_assets,
model_id,

View File

@@ -145,18 +145,18 @@ def initialize_dummy_model(
model_config: ModelConfig,
):
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=1)
current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config):
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)
initialize_model_parallel(tensor_model_parallel_size=1)
with set_default_torch_dtype(model_config.dtype):
torch.set_default_device(current_platform.device_type)
model = model_cls(vllm_config=vllm_config)

View File

@@ -31,7 +31,7 @@ def test_platform_plugins():
)
def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
def test_oot_custom_op(default_vllm_config, monkeypatch: pytest.MonkeyPatch):
# simulate workload by running an example
load_general_plugins()
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

View File

@@ -277,6 +277,7 @@ def test_scaled_fp8_quant(dtype) -> None:
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading(
default_vllm_config,
method_cls,
is_checkpoint_fp8_serialized,
weight_block_size,

View File

@@ -721,13 +721,34 @@ def init_test_distributed_environment(
distributed_init_port: str,
local_rank: int = -1,
) -> None:
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
# Note: This function is often called from Ray worker processes, so we
# can't rely on pytest fixtures to set the config. We check if the config
# is already set and only create a default one if needed.
from vllm.config import (
VllmConfig,
get_current_vllm_config_or_none,
set_current_vllm_config,
)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
if get_current_vllm_config_or_none() is not None:
# Config already set, use it directly
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
else:
# No config set, create a default one for the test
with set_current_vllm_config(VllmConfig()):
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank,
)
ensure_model_parallel_initialized(tp_size, pp_size)

View File

@@ -556,7 +556,7 @@ def _test_backend_correctness(
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
batch_spec_name: str, model: str, tensor_parallel_size: int
default_vllm_config, batch_spec_name: str, model: str, tensor_parallel_size: int
):
"""Test backend's correctness with causal attention."""

View File

@@ -79,7 +79,12 @@ from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
],
)
def test_mamba_layers_get_attn_backend(
dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type
default_vllm_config,
dist_init,
layer_class,
init_kwargs,
expected_backend,
expected_mamba_type,
):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)

View File

@@ -394,7 +394,11 @@ def run_attention_backend(
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
def test_backend_correctness(
dist_init, batch_spec_name: str, model: str, tensor_parallel_size: int
default_vllm_config,
dist_init,
batch_spec_name: str,
model: str,
tensor_parallel_size: int,
):
"""
Test that all backends produce similar outputs to a reference implementation

View File

@@ -124,7 +124,12 @@ def _quantize_dequantize_fp8_ds_mla(
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
default_vllm_config,
dist_init,
batch_name,
kv_cache_dtype,
tensor_parallel_size,
workspace_init,
):
if current_platform.is_rocm():
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")

View File

@@ -21,7 +21,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
def test_rms_norm_batch_invariant_vs_standard(
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
default_vllm_config,
batch_size: int,
hidden_size: int,
dtype: torch.dtype,
eps: float,
):
"""
Compare batch-invariant Triton RMS norm against standard CUDA implementation.
@@ -68,7 +72,9 @@ def test_rms_norm_batch_invariant_vs_standard(
@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096])
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
def test_rms_norm_3d_input(
default_vllm_config, batch_size: int, seq_len: int, hidden_size: int
):
"""
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
@@ -107,7 +113,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
@skip_unsupported
def test_rms_norm_numerical_stability():
def test_rms_norm_numerical_stability(default_vllm_config):
"""
Test RMS norm numerical stability with extreme values.
@@ -167,7 +173,7 @@ def test_rms_norm_numerical_stability():
@skip_unsupported
def test_rms_norm_formula():
def test_rms_norm_formula(default_vllm_config):
"""
Test that RMS norm follows the correct mathematical formula.
@@ -201,7 +207,7 @@ def test_rms_norm_formula():
@skip_unsupported
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int):
def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
"""
Test RMS norm with various hidden sizes to ensure block size handling.
@@ -238,7 +244,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
@skip_unsupported
def test_rms_norm_determinism():
def test_rms_norm_determinism(default_vllm_config):
"""
Test that batch-invariant RMS norm produces deterministic results.

View File

@@ -299,6 +299,7 @@ def test_prompt_less_than_block_size():
)
def test_kv_transfer_handshake(dist_init):
"""Unit test for basic NixlConnector interface functionality."""
from vllm.config import set_current_vllm_config
# Test setup, we creates a scheduler that contains a NixlConnector
# of role SCHEDULER, and expect it to be serving NixlAgentMetadata from
@@ -308,81 +309,82 @@ def test_kv_transfer_handshake(dist_init):
vllm_config.kv_transfer_config.kv_buffer_device = "cpu"
scheduler = create_scheduler(vllm_config)
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
with set_current_vllm_config(vllm_config):
# Create two NixlConnector of role WORKER, one is the worker of
# the scheduler (prefill), the other is a worker of decode instance.
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
# Prefill connector will register KV cache to populate proper handshake
# metadata.
prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
prefill_connector.register_kv_caches(kv_caches)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
# Simulate EngineCore initialization that would gather connector
# metadata from all workers
metadata = prefill_connector.get_handshake_metadata()
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
# metadata is a NixlHandshakePayload, decode it to get NixlAgentMetadata
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
expected_agent_metadata = decoder.decode(metadata.agent_metadata_bytes)
# The scheduler connector expects metadata to be in
# dict[int, KVConnectorHandshakeMetadata], where the first key is
# the dp_rank, the second key is the tp_rank.
scheduler_connector = scheduler.get_kv_connector()
scheduler_connector.set_xfer_handshake_metadata({0: metadata})
# Simulate a request that finishes prefill, which returns
# corresponding NixlConnectorMetadata for decode instance.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished(
request, [0, 1, 2]
)
assert delay
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
# to validate the metadata received is the same as the one in prefill_connector.
with patch.object(
decode_connector.connector_worker, "add_remote_agent"
) as mock_add_remote_agent:
mock_add_remote_agent.return_type = "remote_agent"
decode_connector.connector_worker._nixl_handshake(
kv_connector_metadata["remote_host"],
kv_connector_metadata["remote_port"],
kv_connector_metadata["tp_size"],
kv_connector_metadata["remote_engine_id"],
)
received_metadata = mock_add_remote_agent.call_args.args
assert received_metadata[0] == expected_agent_metadata
assert received_metadata[1] == 0 # remote_tp_rank
assert received_metadata[2] == 1 # remote_tp_size
# Need to shutdown the background thread to release NIXL side channel port
scheduler_connector.shutdown()
class FakeNixlConnectorWorker(NixlConnectorWorker):
@@ -458,6 +460,7 @@ class TestNixlHandshake:
)
def test_multi_xfer_one_engine(
self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment.
dist_init,
):
@@ -547,6 +550,7 @@ class TestNixlHandshake:
)
def test_async_load_kv(
self,
default_vllm_config,
# Fixture that initializes the distributed environment.
dist_init,
# Simulate consumer-producer TP sizes.
@@ -605,7 +609,7 @@ class TestNixlHandshake:
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
self, local_tp_size: int, default_vllm_config, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
@@ -670,7 +674,7 @@ class TestNixlHandshake:
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init
self, local_tp_size: int, default_vllm_config, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
@@ -770,6 +774,7 @@ class TestNixlHandshake:
)
def test_concurrent_load_kv(
self,
default_vllm_config,
# dist_init is a fixture that initializes the distributed environment.
dist_init,
):
@@ -830,7 +835,9 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
def test_handshake_fails_on_kv_cache_layout_mismatch(
self, default_vllm_config, dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
This test is only relevant for heterogeneous TP.
@@ -879,7 +886,7 @@ class TestNixlHandshake:
FakeNixlWrapper,
)
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
self, dist_init
self, default_vllm_config, dist_init
):
"""
Verify that adding a remote agent fails if kv_cache_layout differs.
@@ -934,7 +941,7 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_kv_connector_stats(dist_init):
def test_kv_connector_stats(default_vllm_config, dist_init):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config = create_vllm_config()
@@ -1357,7 +1364,7 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
"TRITON_ATTN",
],
)
def test_register_kv_caches(dist_init, attn_backend):
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
@@ -1518,7 +1525,9 @@ class FakePlatform(Platform):
("oot", "VRAM"),
],
)
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type):
def test_kv_buffer_to_nixl_memory_types(
default_vllm_config, dist_init, kv_buffer_device, nixl_memory_type
):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
@@ -1563,7 +1572,7 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_shutdown_cleans_up_resources(dist_init):
def test_shutdown_cleans_up_resources(default_vllm_config, dist_init):
"""Test that shutdown() properly cleans up all resources."""
vllm_config = create_vllm_config()
@@ -1622,7 +1631,7 @@ def test_shutdown_cleans_up_resources(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_init):
"""
Create and schedule a request so that P adds it to in-batch tracking via
the real scheduler, then simulate an abort (request not in next scheduler
@@ -1731,7 +1740,7 @@ class FailingNixlWrapper(FakeNixlWrapper):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper,
)
def test_handshake_failure_returns_finished(dist_init):
def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
"""Test that handshake failures mark blocks invalid and return via get_finished."""
vllm_config = create_vllm_config()
@@ -1780,7 +1789,7 @@ def test_handshake_failure_returns_finished(dist_init):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FailingNixlWrapper,
)
def test_transfer_setup_failure_returns_finished(dist_init):
def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init):
"""Test that transfer setup failures mark blocks invalid
and return via get_finished."""
vllm_config = create_vllm_config()
@@ -1855,6 +1864,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
FakeNixlWrapper,
)
def test_compatibility_hash_validation(
default_vllm_config,
dist_init,
mismatch_type,
config_overrides,
@@ -1967,7 +1977,7 @@ def test_compatibility_hash_validation(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_decode_errors(dist_init, error_scenario):
def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario):
"""
Test that msgspec decode errors are properly handled during handshake.

View File

@@ -50,6 +50,7 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
default_vllm_config,
gpu_to_cpu: bool,
num_mappings: int,
head_size: int,

View File

@@ -112,15 +112,16 @@ def get_vllm_config():
@pytest.fixture
def model_runner():
    """Yield a GPUModelRunner with one attention layer registered.

    The runner is built inside set_current_vllm_config() so that any
    CustomOp constructed during initialization can read the active config,
    and the fixture *yields* (rather than returns) so the config context
    stays active for the duration of the test.
    """
    vllm_config = get_vllm_config()
    with set_current_vllm_config(vllm_config):
        model_config = vllm_config.model_config
        num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
        head_size = model_config.get_head_size()
        vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
            num_heads, head_size, 0.1
        )
        runner = GPUModelRunner(vllm_config, DEVICE)
        initialize_kv_cache(runner)
        # Keep the `with` block open across the test body.
        yield runner
model_runner_2 = model_runner
@@ -546,7 +547,7 @@ def test_reload_weights_before_load_model(model_runner):
model_runner.reload_weights()
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
@@ -573,7 +574,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
@@ -600,7 +601,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
@@ -627,7 +628,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
assert fwd_context is not None
def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_without_kv_sharing(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
@@ -694,7 +695,7 @@ def test_init_kv_cache_without_kv_sharing():
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_init_kv_cache_with_kv_sharing_valid():
def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
@@ -1047,7 +1048,7 @@ def test_input_batch_with_kernel_block_sizes():
assert block_table.block_size == kernel_size
def test_hybrid_cache_integration(model_runner, dist_init):
def test_hybrid_cache_integration(default_vllm_config, dist_init):
"""Test hybrid cache architecture integration with GPUModelRunner."""
# Create a new model runner with hybrid cache configuration
vllm_config = get_vllm_config()

View File

@@ -6,14 +6,14 @@ import torch
from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache():
def test_bind_kv_cache(default_vllm_config):
from vllm.attention.layer import Attention
ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1),
"layers.1.self_attn": Attention(32, 128, 0.1),
"layers.2.self_attn": Attention(32, 128, 0.1),
"layers.3.self_attn": Attention(32, 128, 0.1),
"layers.0.self_attn": Attention(32, 128, 0.1, prefix="layers.0.self_attn"),
"layers.1.self_attn": Attention(32, 128, 0.1, prefix="layers.1.self_attn"),
"layers.2.self_attn": Attention(32, 128, 0.1, prefix="layers.2.self_attn"),
"layers.3.self_attn": Attention(32, 128, 0.1, prefix="layers.3.self_attn"),
}
kv_cache = {
"layers.0.self_attn": torch.zeros((1,)),
@@ -34,13 +34,13 @@ def test_bind_kv_cache():
assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"]
def test_bind_kv_cache_non_attention():
def test_bind_kv_cache_non_attention(default_vllm_config):
from vllm.attention.layer import Attention
# example from Jamba PP=2
ctx = {
"model.layers.20.attn": Attention(32, 128, 0.1),
"model.layers.28.attn": Attention(32, 128, 0.1),
"model.layers.20.attn": Attention(32, 128, 0.1, prefix="model.layers.20.attn"),
"model.layers.28.attn": Attention(32, 128, 0.1, prefix="model.layers.28.attn"),
}
kv_cache = {
"model.layers.20.attn": torch.zeros((1,)),

View File

@@ -59,10 +59,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
)
# 2. override if passed by environment or config
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config()
if vllm_config.attention_config.flash_attn_version is not None:
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.flash_attn_version is not None
):
fa_version = vllm_config.attention_config.flash_attn_version
# 3. fallback for unsupported combinations

View File

@@ -42,6 +42,7 @@ from vllm.config.vllm import (
VllmConfig,
get_cached_compilation_config,
get_current_vllm_config,
get_current_vllm_config_or_none,
get_layers_from_vllm_config,
set_current_vllm_config,
)
@@ -105,6 +106,7 @@ __all__ = [
"VllmConfig",
"get_cached_compilation_config",
"get_current_vllm_config",
"get_current_vllm_config_or_none",
"set_current_vllm_config",
"get_layers_from_vllm_config",
]

View File

@@ -1441,13 +1441,18 @@ def get_cached_compilation_config():
def get_current_vllm_config() -> VllmConfig:
    """Return the ``VllmConfig`` set by ``set_current_vllm_config``.

    Raises:
        AssertionError: if no config is currently set. Previously this fell
            back to a default ``VllmConfig()`` with a warning, which masked
            real bugs (e.g. CustomOps instantiated at import time); callers
            that can tolerate an unset config should use
            ``get_current_vllm_config_or_none()`` instead.
    """
    if _current_vllm_config is None:
        raise AssertionError(
            "Current vLLM config is not set. This typically means "
            "get_current_vllm_config() was called outside of a "
            "set_current_vllm_config() context, or a CustomOp was instantiated "
            "at module import time or model forward time when config is not set. "
            "For tests that directly test custom ops/modules, use the "
            "'default_vllm_config' pytest fixture from tests/conftest.py."
        )
    return _current_vllm_config
def get_current_vllm_config_or_none() -> VllmConfig | None:
    """Return the currently set ``VllmConfig``, or ``None`` if unset."""
    return _current_vllm_config

View File

@@ -117,9 +117,9 @@ class DeviceCommunicatorBase:
use_ep = False
all2all_backend = None
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config()
config = get_current_vllm_config_or_none()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),

View File

@@ -9,7 +9,7 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -184,7 +184,7 @@ class QuickAllReduce:
)
return
self.qr_quant_level = QuickReduceRegime[regime_str]
vllm_config = get_current_vllm_config()
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and hasattr(vllm_config, "model_config")

View File

@@ -1177,9 +1177,9 @@ def init_distributed_environment(
distributed_init_method,
backend,
)
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config()
config = get_current_vllm_config_or_none()
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1251,7 +1251,7 @@ def init_distributed_environment(
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
if config.parallel_config.nnodes > 1:
if config is not None and config.parallel_config.nnodes > 1:
_NODE_COUNT = config.parallel_config.nnodes
else:
_NODE_COUNT = _node_count(_WORLD.cpu_group)
@@ -1260,7 +1260,7 @@ def init_distributed_environment(
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size"
)
if config.parallel_config.nnodes_within_dp > 1:
if config is not None and config.parallel_config.nnodes_within_dp > 1:
if parallel_config.data_parallel_size > 1:
world_size_inner_dp = parallel_config.world_size
group_ranks = [
@@ -1316,9 +1316,9 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config()
config = get_current_vllm_config_or_none()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size

View File

@@ -13,10 +13,28 @@ from vllm.model_executor.layers.quantization.utils.layer_utils import replace_pa
from vllm.utils.torch_utils import direct_register_custom_op
_CPU_MOE_LAYER_CACHE = {}


class _LazyActivationDict(dict):
    """Lazily instantiate activation functions on first access.

    Avoids triggering CustomOp.__init__() at module import time,
    which would call get_current_vllm_config() before config is set.
    """

    # Map from activation name to its (uninstantiated) activation class.
    _factories: dict[str, type[SiluAndMul] | type[SwigluOAIAndMul]] = {
        "silu": SiluAndMul,
        "swigluoai": SwigluOAIAndMul,
    }

    def __missing__(self, key: str) -> SiluAndMul | SwigluOAIAndMul:
        """Construct and cache the activation on first lookup.

        Subsequent lookups hit the plain dict fast path.
        """
        if key not in self._factories:
            raise KeyError(f"{key} is not a supported activation")
        self[key] = self._factories[key]()
        return self[key]


# NOTE: instantiating the dict itself is cheap; the CustomOp activations
# inside are only created on first use, after the vLLM config is set.
_CPU_MOE_ACT = _LazyActivationDict()
def grouped_topk(
@@ -212,7 +230,7 @@ class CPUFusedMOE:
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
assert activation in _CPU_MOE_ACT._factories, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(

View File

@@ -540,6 +540,20 @@ class FusedMoE(CustomOp):
self.apply_router_weight_on_input = apply_router_weight_on_input
self.activation = activation
self._grouped_topk_impl: GroupedTopk | None = None
if self.use_grouped_topk:
assert self.num_expert_group is not None
assert self.topk_group is not None
self._grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError(
"Only softmax scoring function is supported for non-grouped topk."
@@ -1588,19 +1602,8 @@ class FusedMoE(CustomOp):
# DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk and valid_grouping():
assert self.topk_group is not None
assert self.num_expert_group is not None
grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
num_fused_shared_experts=self.num_fused_shared_experts,
)
topk_weights, topk_ids = grouped_topk_impl(
assert self._grouped_topk_impl is not None
topk_weights, topk_ids = self._grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
e_score_correction_bias=self.e_score_correction_bias,

View File

@@ -339,15 +339,11 @@ def apply_rotary_pos_emb_flashatt(
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
apply_rotary_emb: ApplyRotaryEmb,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin)
@@ -410,6 +406,11 @@ class KeyeSiglipAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward(
self,
hidden_states: torch.Tensor,
@@ -448,7 +449,7 @@ class KeyeSiglipAttention(nn.Module):
self.num_kv_heads,
self.head_dim,
)
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin)
q, k = apply_rotary_pos_emb_flashatt(q, k, cos, sin, self.apply_rotary_emb)
v = v.view(
*v.shape[:-1],
self.num_kv_heads,

View File

@@ -152,16 +152,12 @@ def apply_rotary_pos_emb(
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_flash_attn_backend: bool = False,
is_flash_attn_backend: bool,
apply_rotary_emb: ApplyRotaryEmb,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if is_flash_attn_backend and current_platform.is_cuda():
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
elif is_flash_attn_backend and current_platform.is_rocm():
@@ -235,6 +231,11 @@ class Siglip2Attention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward(
self,
hidden_states: torch.Tensor,
@@ -260,6 +261,7 @@ class Siglip2Attention(nn.Module):
cos,
sin,
self.attn.is_flash_attn_backend,
self.apply_rotary_emb,
)
queries = queries.squeeze(0)
keys = keys.squeeze(0)

View File

@@ -14,7 +14,7 @@ import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
from vllm.config.compilation import CompilationMode
from vllm.distributed import (
ensure_model_parallel_initialized,
@@ -268,7 +268,9 @@ class Worker(WorkerBase):
# to hijack tensor allocation.
def load_model(self) -> None:
    """Load model weights under the memory-pool and vLLM-config contexts.

    set_current_vllm_config() must be active while loading so that any
    CustomOp constructed during model instantiation can read the config.
    """
    eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
    # BUG FIX: `with A and B:` evaluates the boolean expression `A and B`
    # and enters only ONE of the two objects as a context manager (the
    # memory-pool context was silently skipped). Use the parenthesized
    # multi-context form so BOTH contexts are entered and exited properly.
    with (
        self._maybe_get_memory_pool_context(tag="weights"),
        set_current_vllm_config(self.vllm_config),
    ):
        self.model_runner.load_model(eep_scale_up=eep_scale_up)
def update_config(self, overrides: dict[str, Any]) -> None: