@@ -34,13 +34,16 @@ class SimpleMLP(nn.Module):
 
 
 def _create_vllm_config(
-    compilation_config: CompilationConfig, max_num_seqs: int = 8
+    compilation_config: CompilationConfig,
+    max_num_seqs: int = 8,
+    lora_config: bool = False,
 ) -> MagicMock:
     mock_config = MagicMock(spec=VllmConfig)
     mock_config.compilation_config = compilation_config
     mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
     mock_config.parallel_config = ParallelConfig()
-    mock_config.lora_config = None
+    if not lora_config:
+        mock_config.lora_config = None
     # Mimic the behavior of VllmConfig.__post_init__()
     if compilation_config.mode == CompilationMode.VLLM_COMPILE:
         compilation_config.set_splitting_ops_for_v1()
@@ -50,19 +53,21 @@ def _create_vllm_config(
 
 class TestCudagraphDispatcher:
     @pytest.mark.parametrize(
-        "case_id,cudagraph_mode_str,compilation_mode",
+        "cudagraph_mode_str,compilation_mode,lora_config",
         [
             # Test case 0: Full CG for mixed batches, no separate routine
-            (0, "FULL", CompilationMode.NONE),
+            ("FULL", CompilationMode.NONE, False),
             # Test case 1: Full CG for uniform batches, piecewise for mixed
-            (1, "FULL_AND_PIECEWISE", CompilationMode.NONE),
+            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
             # Test case 2: Full CG for uniform batches, no CG for mixed
-            (2, "FULL_DECODE_ONLY", CompilationMode.NONE),
+            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
             # Test case 3: PIECEWISE for all
-            (3, "PIECEWISE", CompilationMode.VLLM_COMPILE),
+            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
+            # Test case 4: PIECEWISE for all, specialize LoRA cases
+            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
         ],
     )
-    def test_dispatcher(self, cudagraph_mode_str, compilation_mode):
+    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
         # Setup dispatcher
         comp_config = CompilationConfig(
             cudagraph_mode=cudagraph_mode_str,
@@ -70,7 +75,17 @@ class TestCudagraphDispatcher:
             cudagraph_capture_sizes=[1, 8],
         )
 
-        config = _create_vllm_config(comp_config, max_num_seqs=8)
+        config = _create_vllm_config(
+            comp_config, max_num_seqs=8, lora_config=lora_config
+        )
+        if (
+            cudagraph_mode_str == "FULL_AND_PIECEWISE"
+            and compilation_mode == CompilationMode.NONE
+        ):
+            with pytest.raises(AssertionError):
+                dispatcher = CudagraphDispatcher(config)
+            return
+
         dispatcher = CudagraphDispatcher(config)
         dispatcher.initialize_cudagraph_keys(
             cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
@@ -78,17 +93,24 @@ class TestCudagraphDispatcher:
 
         # Verify the key is initialized correctly
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
-            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
+            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
+                4 if lora_config else 2
+            )
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
         if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
-            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
+            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
+                4 if lora_config else 2
+            )
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
 
         # Test dispatch logic
         # 1. non-uniform batch, size in cudagraph size list
-        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
+        desc_full_exact = BatchDescriptor(
+            num_tokens=8,
+            uniform_decode=False,
+        )
         rt_mode, key = dispatcher.dispatch(desc_full_exact)
         if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
@@ -138,7 +160,6 @@ class TestCUDAGraphWrapper:
         self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
         self.input_tensor = torch.randn(1, 10, device="cuda")
 
-    @create_new_process_for_each_test("spawn")
     def test_capture_and_replay(self):
         wrapper = CUDAGraphWrapper(
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
@@ -192,7 +213,6 @@ class TestCUDAGraphWrapper:
         eager_output = self.model(self.input_tensor)
         torch.testing.assert_close(eager_output, output2)
 
-    @create_new_process_for_each_test("spawn")
     def test_bypass_on_mode_mismatch(self):
         wrapper = CUDAGraphWrapper(
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
@@ -216,7 +236,6 @@ class TestCUDAGraphWrapper:
         mock_forward.assert_called_once()
         assert not wrapper.concrete_cudagraph_entries
 
-    @create_new_process_for_each_test("spawn")
     def test_bypass_on_mode_none(self):
         wrapper = CUDAGraphWrapper(
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
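
For context: the new `4 if lora_config else 2` assertions are consistent with the dispatcher keeping one cudagraph key per capture size per LoRA variant. A minimal sketch of that arithmetic, using illustrative names (`capture_sizes`, `lora_variants`) that do not appear in the diff itself:

    # Sketch only: the two capture sizes mirror cudagraph_capture_sizes=[1, 8];
    # enabling LoRA doubles the key count by specializing on a LoRA on/off flag.
    for lora_config in (False, True):
        capture_sizes = [1, 8]
        lora_variants = 2 if lora_config else 1
        expected_keys = len(capture_sizes) * lora_variants
        assert expected_keys == (4 if lora_config else 2)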