Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,8 +9,14 @@ import torch.nn as nn
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
ParallelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationLevel,
|
||||
CUDAGraphMode,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
@@ -18,7 +24,6 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
|
||||
# Helper MLP for testing
|
||||
class SimpleMLP(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(10, 10)
|
||||
@@ -28,8 +33,9 @@ class SimpleMLP(nn.Module):
|
||||
return self.fc2(self.fc1(x))
|
||||
|
||||
|
||||
def _create_vllm_config(compilation_config: CompilationConfig,
|
||||
max_num_seqs: int = 8) -> MagicMock:
|
||||
def _create_vllm_config(
|
||||
compilation_config: CompilationConfig, max_num_seqs: int = 8
|
||||
) -> MagicMock:
|
||||
mock_config = MagicMock(spec=VllmConfig)
|
||||
mock_config.compilation_config = compilation_config
|
||||
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
|
||||
@@ -43,7 +49,6 @@ def _create_vllm_config(compilation_config: CompilationConfig,
|
||||
|
||||
|
||||
class TestCudagraphDispatcher:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case_id,cudagraph_mode_str,compilation_level",
|
||||
[
|
||||
@@ -55,18 +60,21 @@ class TestCudagraphDispatcher:
|
||||
(2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
|
||||
# Test case 3: Piecewise for all
|
||||
(3, "PIECEWISE", CompilationLevel.PIECEWISE),
|
||||
])
|
||||
],
|
||||
)
|
||||
def test_dispatcher(self, cudagraph_mode_str, compilation_level):
|
||||
# Setup dispatcher
|
||||
comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
|
||||
level=compilation_level,
|
||||
cudagraph_capture_sizes=[1, 8])
|
||||
comp_config = CompilationConfig(
|
||||
cudagraph_mode=cudagraph_mode_str,
|
||||
level=compilation_level,
|
||||
cudagraph_capture_sizes=[1, 8],
|
||||
)
|
||||
|
||||
config = _create_vllm_config(comp_config, max_num_seqs=8)
|
||||
dispatcher = CudagraphDispatcher(config)
|
||||
dispatcher.initialize_cudagraph_keys(
|
||||
cudagraph_mode=comp_config.cudagraph_mode,
|
||||
uniform_decode_query_len=1)
|
||||
cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
|
||||
)
|
||||
|
||||
# Verify the key is initialized correctly
|
||||
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
@@ -114,8 +122,7 @@ class TestCudagraphDispatcher:
|
||||
|
||||
# 4. Cascade attention should have a fall back mode
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
|
||||
rt_mode, key = dispatcher.dispatch(desc_full_exact,
|
||||
use_cascade_attn=True)
|
||||
rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
|
||||
if "PIECEWISE" in cudagraph_mode_str: # string contains check
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_full_exact.non_uniform
|
||||
@@ -125,7 +132,6 @@ class TestCudagraphDispatcher:
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
class TestCUDAGraphWrapper:
|
||||
|
||||
def setup_method(self):
|
||||
self.vllm_config = _create_vllm_config(CompilationConfig())
|
||||
self.model = SimpleMLP().to("cuda")
|
||||
@@ -134,26 +140,30 @@ class TestCUDAGraphWrapper:
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_capture_and_replay(self):
|
||||
wrapper = CUDAGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
wrapper = CUDAGraphWrapper(
|
||||
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
|
||||
)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||
|
||||
# 0. global warmup
|
||||
with set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
with set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None,
|
||||
):
|
||||
wrapper(self.input_tensor)
|
||||
|
||||
# 1. Capture
|
||||
with set_forward_context(
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
batch_descriptor=batch_descriptor),\
|
||||
patch("torch.cuda.graph",
|
||||
wraps=torch.cuda.graph) as mock_cuda_graph:
|
||||
batch_descriptor=batch_descriptor,
|
||||
),
|
||||
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
|
||||
):
|
||||
output1 = wrapper(self.input_tensor)
|
||||
# capturing phase should generate a zero output
|
||||
assert torch.allclose(output1, torch.zeros_like(output1))
|
||||
@@ -164,13 +174,17 @@ class TestCUDAGraphWrapper:
|
||||
assert entry.cudagraph is not None
|
||||
|
||||
# 2. Replay
|
||||
with set_forward_context(
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
batch_descriptor=batch_descriptor),\
|
||||
patch.object(entry.cudagraph, 'replay',
|
||||
wraps=entry.cudagraph.replay) as mock_replay:
|
||||
batch_descriptor=batch_descriptor,
|
||||
),
|
||||
patch.object(
|
||||
entry.cudagraph, "replay", wraps=entry.cudagraph.replay
|
||||
) as mock_replay,
|
||||
):
|
||||
output2 = wrapper(self.input_tensor)
|
||||
mock_replay.assert_called_once()
|
||||
|
||||
@@ -180,20 +194,23 @@ class TestCUDAGraphWrapper:
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_bypass_on_mode_mismatch(self):
|
||||
wrapper = CUDAGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
wrapper = CUDAGraphWrapper(
|
||||
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
|
||||
)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||
|
||||
with set_forward_context(
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
batch_descriptor=batch_descriptor), \
|
||||
patch('torch.cuda.graph',
|
||||
wraps=torch.cuda.graph) as mock_cuda_graph, \
|
||||
patch.object(self.model, 'forward',
|
||||
wraps=self.model.forward) as mock_forward:
|
||||
batch_descriptor=batch_descriptor,
|
||||
),
|
||||
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
|
||||
patch.object(
|
||||
self.model, "forward", wraps=self.model.forward
|
||||
) as mock_forward,
|
||||
):
|
||||
wrapper(self.input_tensor)
|
||||
mock_cuda_graph.assert_not_called()
|
||||
mock_forward.assert_called_once()
|
||||
@@ -201,18 +218,20 @@ class TestCUDAGraphWrapper:
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_bypass_on_mode_none(self):
|
||||
wrapper = CUDAGraphWrapper(self.model,
|
||||
self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
wrapper = CUDAGraphWrapper(
|
||||
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
|
||||
)
|
||||
batch_descriptor = BatchDescriptor(num_tokens=10)
|
||||
|
||||
with set_forward_context(
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=batch_descriptor), \
|
||||
patch('torch.cuda.graph',
|
||||
wraps=torch.cuda.graph) as mock_cuda_graph:
|
||||
batch_descriptor=batch_descriptor,
|
||||
),
|
||||
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
|
||||
):
|
||||
wrapper(self.input_tensor)
|
||||
mock_cuda_graph.assert_not_called()
|
||||
assert not wrapper.concrete_cudagraph_entries
|
||||
@@ -220,38 +239,44 @@ class TestCUDAGraphWrapper:
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
class TestCudagraphIntegration:
|
||||
|
||||
def setup_method(self):
|
||||
# only FULL mode for non-uniform batches
|
||||
self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
cudagraph_mode="FULL",
|
||||
cudagraph_capture_sizes=[10, 20])
|
||||
self.comp_config = CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
cudagraph_mode="FULL",
|
||||
cudagraph_capture_sizes=[10, 20],
|
||||
)
|
||||
self.vllm_config = _create_vllm_config(self.comp_config)
|
||||
self.dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
self.dispatcher.initialize_cudagraph_keys(
|
||||
self.comp_config.cudagraph_mode, uniform_decode_query_len=1)
|
||||
self.comp_config.cudagraph_mode, uniform_decode_query_len=1
|
||||
)
|
||||
|
||||
def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode,
|
||||
batch_descriptor):
|
||||
def _run_and_monitor_call(
|
||||
self, wrapper, input_tensor, runtime_mode, batch_descriptor
|
||||
):
|
||||
"""Helper to run a single call and monitor the action."""
|
||||
|
||||
with patch('torch.cuda.graph',
|
||||
wraps=torch.cuda.graph) as mock_graph_context, \
|
||||
patch.object(wrapper, 'runnable',
|
||||
wraps=wrapper.runnable) as mock_runnable:
|
||||
with (
|
||||
patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
|
||||
patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
|
||||
):
|
||||
entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)
|
||||
|
||||
entry = wrapper.concrete_cudagraph_entries.get(
|
||||
batch_descriptor, None)
|
||||
|
||||
context = set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=runtime_mode,
|
||||
batch_descriptor=batch_descriptor)
|
||||
context = set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
)
|
||||
mock_replay = MagicMock()
|
||||
if entry and entry.cudagraph:
|
||||
with context, \
|
||||
patch.object(entry.cudagraph, 'replay',
|
||||
new_callable=MagicMock) as mock_replay:
|
||||
with (
|
||||
context,
|
||||
patch.object(
|
||||
entry.cudagraph, "replay", new_callable=MagicMock
|
||||
) as mock_replay,
|
||||
):
|
||||
wrapper(input_tensor)
|
||||
else:
|
||||
with context:
|
||||
@@ -272,8 +297,7 @@ class TestCudagraphIntegration:
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_capture_replay_bypass_logic(self):
|
||||
model = SimpleMLP().to("cuda")
|
||||
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
|
||||
CUDAGraphMode.FULL)
|
||||
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
|
||||
max_bs = 16
|
||||
persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
|
||||
input_1 = persistent_input_buffer[:1]
|
||||
@@ -285,75 +309,79 @@ class TestCudagraphIntegration:
|
||||
desc_3_unseen = BatchDescriptor(num_tokens=3)
|
||||
|
||||
# 0. global warmup
|
||||
with set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
with set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None,
|
||||
):
|
||||
full_wrapper(input_1)
|
||||
|
||||
rt_mode, key = self.dispatcher.dispatch(desc_1)
|
||||
# 1. Capture first shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
|
||||
key)
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
|
||||
assert action == "capture_global"
|
||||
|
||||
# 2. Replay first shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode,
|
||||
key)
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
|
||||
assert action == "replay"
|
||||
|
||||
rt_mode, key = self.dispatcher.dispatch(desc_2)
|
||||
# 3. Capture second shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode,
|
||||
key)
|
||||
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
|
||||
assert action == "capture_global"
|
||||
|
||||
# 4. Replay second shape
|
||||
action = self._run_and_monitor_call(full_wrapper, input_2,
|
||||
CUDAGraphMode.FULL, desc_2)
|
||||
action = self._run_and_monitor_call(
|
||||
full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
|
||||
)
|
||||
assert action == "replay"
|
||||
|
||||
# 5. Bypass if no key match
|
||||
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode,
|
||||
key)
|
||||
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
|
||||
assert action == "bypass"
|
||||
|
||||
# capture unseen shape is not allowed after disable
|
||||
set_cudagraph_capturing_enabled(False)
|
||||
with pytest.raises(RuntimeError):
|
||||
self._run_and_monitor_call(full_wrapper, input_3,
|
||||
CUDAGraphMode.FULL, desc_3_unseen)
|
||||
self._run_and_monitor_call(
|
||||
full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
|
||||
)
|
||||
set_cudagraph_capturing_enabled(True)
|
||||
|
||||
@create_new_process_for_each_test("spawn")
|
||||
def test_nested_wrappers(self):
|
||||
"""Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
|
||||
model = SimpleMLP().to("cuda")
|
||||
full_wrapper = CUDAGraphWrapper(model, self.vllm_config,
|
||||
CUDAGraphMode.FULL)
|
||||
full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
|
||||
input_1 = torch.randn(1, 10, device="cuda")
|
||||
|
||||
# Setup: Inner model is wrapped with PIECEWISE, outer with FULL
|
||||
inner_model = SimpleMLP().to("cuda")
|
||||
piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config,
|
||||
CUDAGraphMode.PIECEWISE)
|
||||
piecewise_wrapper = CUDAGraphWrapper(
|
||||
inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
|
||||
)
|
||||
inner_model.forward = MagicMock(wraps=inner_model.forward)
|
||||
outer_model = SimpleMLP().to("cuda")
|
||||
# When outer model is called, it calls the piecewise_wrapper
|
||||
outer_model.forward = MagicMock(wraps=outer_model.forward,
|
||||
side_effect=piecewise_wrapper)
|
||||
full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config,
|
||||
CUDAGraphMode.FULL)
|
||||
outer_model.forward = MagicMock(
|
||||
wraps=outer_model.forward, side_effect=piecewise_wrapper
|
||||
)
|
||||
full_wrapper = CUDAGraphWrapper(
|
||||
outer_model, self.vllm_config, CUDAGraphMode.FULL
|
||||
)
|
||||
|
||||
desc_1 = BatchDescriptor(num_tokens=1)
|
||||
|
||||
# 0. global warmup
|
||||
with set_forward_context(attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
with set_forward_context(
|
||||
attn_metadata=None,
|
||||
vllm_config=self.vllm_config,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
batch_descriptor=None,
|
||||
):
|
||||
full_wrapper(input_1)
|
||||
|
||||
# --- Test runtime mode FULL---
|
||||
@@ -361,8 +389,9 @@ class TestCudagraphIntegration:
|
||||
# The inner mock should be called once inside the graph capture.
|
||||
outer_model.forward.reset_mock()
|
||||
inner_model.forward.reset_mock()
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.FULL, desc_1)
|
||||
action = self._run_and_monitor_call(
|
||||
full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
|
||||
)
|
||||
assert action == "capture_global"
|
||||
assert outer_model.forward.call_count == 1
|
||||
assert inner_model.forward.call_count == 1
|
||||
@@ -370,8 +399,9 @@ class TestCudagraphIntegration:
|
||||
# Run again. Expect outer wrapper to replay.
|
||||
# The outer model should NOT be called because the whole graph
|
||||
# is replayed.
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.FULL, desc_1)
|
||||
action = self._run_and_monitor_call(
|
||||
full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
|
||||
)
|
||||
assert action == "replay"
|
||||
assert outer_model.forward.call_count == 1 # No new call
|
||||
assert inner_model.forward.call_count == 1
|
||||
@@ -382,16 +412,18 @@ class TestCudagraphIntegration:
|
||||
# Run with PIECEWISE mode context.
|
||||
# Expect outer wrapper to bypass and call inner wrapper.
|
||||
# Inner wrapper should capture.
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.PIECEWISE, desc_1)
|
||||
action = self._run_and_monitor_call(
|
||||
full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
|
||||
)
|
||||
assert action == "capture_global"
|
||||
assert outer_model.forward.call_count == 1
|
||||
assert inner_model.forward.call_count == 1
|
||||
|
||||
# Run again with PIECEWISE.
|
||||
# Outer bypasses, inner replays.
|
||||
action = self._run_and_monitor_call(full_wrapper, input_1,
|
||||
CUDAGraphMode.PIECEWISE, desc_1)
|
||||
action = self._run_and_monitor_call(
|
||||
full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
|
||||
)
|
||||
assert action == "bypass"
|
||||
assert outer_model.forward.call_count == 2
|
||||
assert inner_model.forward.call_count == 1
|
||||
|
||||
Reference in New Issue
Block a user