[Misc] Fix "Current vLLM config is not set." warnings; assert to avoid issues in the future (#31747)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
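For context, the warning named in the title is emitted when get_current_vllm_config() is called while no config has been set as current; the test change below avoids it by entering the set_current_vllm_config context manager before constructing TestBackend. A minimal sketch of that pattern follows; it assumes get_current_vllm_config is importable from the same vllm.config module as set_current_vllm_config and that a default-constructed VllmConfig is enough for illustration:

    # Sketch only: the bare VllmConfig() construction is an assumption for illustration.
    from vllm.config import VllmConfig, get_current_vllm_config, set_current_vllm_config

    vllm_config = VllmConfig()

    # Calling get_current_vllm_config() outside the context manager is what
    # triggers the "Current vLLM config is not set." warning.
    with set_current_vllm_config(vllm_config):
        # Code constructed inside the context (e.g. TestBackend) can call
        # get_current_vllm_config() and will see vllm_config.
        assert get_current_vllm_config() is vllm_config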
@@ -15,6 +15,7 @@ from vllm.config import (
     ModelConfig,
     PassConfig,
     VllmConfig,
+    set_current_vllm_config,
 )
 from vllm.distributed import (
     tensor_model_parallel_all_gather,
@@ -340,38 +341,42 @@ def async_tp_pass_on_test_model(
     )
 
     async_tp_pass = AsyncTPPass(vllm_config)
-    backend = TestBackend(async_tp_pass)
 
     assert (
         async_tp_pass.compilation_config.splitting_ops
         == vllm_config.compilation_config.splitting_ops
     )
     assert (
         async_tp_pass.compilation_config.use_inductor_graph_partition
         == vllm_config.compilation_config.use_inductor_graph_partition
     )
 
+    # Set the global vllm_config for TestBackend which calls
+    # get_current_vllm_config()
+    with set_current_vllm_config(vllm_config):
+        backend = TestBackend(async_tp_pass)
+
     model = test_model_cls(hidden_size, dtype)  # Pass dtype to model constructor
 
     hidden_states = torch.randn(
         (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
     )
 
     if dynamic:
         torch._dynamo.mark_dynamic(hidden_states, 0)
 
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states)
 
     assert async_tp_pass.matched_count == 1
 
     # In pre-nodes, all gather or reduce scatter should exist,
     # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
 
     # In post-nodes, fused_matmul_reduce_scatter or \
     # fused_all_gather_matmul should exist
     backend.check_after_ops(model.ops_in_model_after())
 
 
 @create_new_process_for_each_test()