[Misc] Fix "Current vLLM config is not set." warnings; assert to avoid issues in the future (#31747)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Lucas Wilkinson
Date: 2026-01-08 18:20:49 -05:00 (committed by GitHub)
Parent: 5d3b6097ad
Commit: 6cdf015c3c
48 changed files with 380 additions and 240 deletions
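
The warning in question comes from vLLM's process-global config: set_current_vllm_config(...) installs a VllmConfig for the duration of a with block, and get_current_vllm_config() reads it back, previously logging "Current vLLM config is not set." and falling back to a default when called outside such a block. A minimal sketch of that mechanism, simplified from vllm.config and not the exact implementation:

    from contextlib import contextmanager

    _current_vllm_config = None  # process-global slot read by get_current_vllm_config()

    @contextmanager
    def set_current_vllm_config(vllm_config):
        # Install the config for the duration of the with-block, then restore
        # the previous value so nested uses compose correctly.
        global _current_vllm_config
        old = _current_vllm_config
        _current_vllm_config = vllm_config
        try:
            yield
        finally:
            _current_vllm_config = old

    def get_current_vllm_config():
        # Pre-commit behavior in vLLM: log a warning and return a default config.
        # Post-commit behavior, per the title: fail fast instead.
        assert _current_vllm_config is not None, "Current vLLM config is not set."
        return _current_vllm_config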


@@ -15,6 +15,7 @@ from vllm.config import (
     ModelConfig,
     PassConfig,
     VllmConfig,
+    set_current_vllm_config,
 )
 from vllm.distributed import (
     tensor_model_parallel_all_gather,
@@ -340,38 +341,42 @@ def async_tp_pass_on_test_model(
     )
     async_tp_pass = AsyncTPPass(vllm_config)
-    backend = TestBackend(async_tp_pass)
-    assert (
-        async_tp_pass.compilation_config.splitting_ops
-        == vllm_config.compilation_config.splitting_ops
-    )
-    assert (
-        async_tp_pass.compilation_config.use_inductor_graph_partition
-        == vllm_config.compilation_config.use_inductor_graph_partition
-    )
+    # Set the global vllm_config for TestBackend which calls
+    # get_current_vllm_config()
+    with set_current_vllm_config(vllm_config):
+        backend = TestBackend(async_tp_pass)
+        assert (
+            async_tp_pass.compilation_config.splitting_ops
+            == vllm_config.compilation_config.splitting_ops
+        )
+        assert (
+            async_tp_pass.compilation_config.use_inductor_graph_partition
+            == vllm_config.compilation_config.use_inductor_graph_partition
+        )
 
-    model = test_model_cls(hidden_size, dtype)  # Pass dtype to model constructor
+        model = test_model_cls(hidden_size, dtype)  # Pass dtype to model constructor
 
-    hidden_states = torch.randn(
-        (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
-    )
+        hidden_states = torch.randn(
+            (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
+        )
 
-    if dynamic:
-        torch._dynamo.mark_dynamic(hidden_states, 0)
+        if dynamic:
+            torch._dynamo.mark_dynamic(hidden_states, 0)
 
-    compiled_model = torch.compile(model, backend=backend)
-    compiled_model(hidden_states)
+        compiled_model = torch.compile(model, backend=backend)
+        compiled_model(hidden_states)
 
-    assert async_tp_pass.matched_count == 1
+        assert async_tp_pass.matched_count == 1
 
-    # In pre-nodes, all gather or reduce scatter should exist,
-    # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
-    backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
+        # In pre-nodes, all gather or reduce scatter should exist,
+        # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
+        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
 
-    # In post-nodes, fused_matmul_reduce_scatter or \
-    # fused_all_gather_matmul should exist
-    backend.check_after_ops(model.ops_in_model_after())
+        # In post-nodes, fused_matmul_reduce_scatter or \
+        # fused_all_gather_matmul should exist
+        backend.check_after_ops(model.ops_in_model_after())
 
 
 @create_new_process_for_each_test()
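
For callers, the recipe this hunk applies is simple: anything that may end up calling get_current_vllm_config() during construction or compilation (here TestBackend and the torch.compile run) must execute inside the context manager. A hedged usage sketch, reusing the names from the hunk above with the surrounding test scaffolding elided:

    # Outside the with-block, get_current_vllm_config() now asserts instead of
    # silently returning a default config:
    #   AssertionError: Current vLLM config is not set.
    with set_current_vllm_config(vllm_config):
        backend = TestBackend(async_tp_pass)           # reads the global config
        compiled_model = torch.compile(model, backend=backend)
        compiled_model(hidden_states)                  # compilation reads it too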