[Misc] Remove pad_for_cudagraphs from config (#30143)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -2,14 +2,20 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CUDAGraphMode,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.compilation import CompilationMode, PassConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
@@ -17,6 +23,7 @@ from vllm.utils.torch_utils import (
|
||||
_is_torch_equal_or_newer,
|
||||
is_torch_equal,
|
||||
)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
@@ -472,6 +479,19 @@ def test_cached_compilation_config(default_vllm_config):
|
||||
assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
|
||||
|
||||
|
||||
def _create_vllm_config_for_validation(
    compilation_config: CompilationConfig,
) -> MagicMock:
    """Build a mock ``VllmConfig`` suitable for padding-validation tests.

    The returned object is a ``MagicMock`` specced against ``VllmConfig``
    with the given compilation config attached and the remaining
    sub-configs set to small defaults (or ``None`` where no feature is
    under test).
    """
    cfg = MagicMock(spec=VllmConfig)
    cfg.compilation_config = compilation_config
    cfg.parallel_config = ParallelConfig()
    cfg.scheduler_config = SchedulerConfig.default_factory(max_num_seqs=8)
    cfg.speculative_config = None
    cfg.lora_config = None
    return cfg
|
||||
|
||||
|
||||
def test_compile_sizes_padding_validation():
|
||||
"""Test that compile_sizes with values that would be padded raises an error."""
|
||||
# cudagraph_capture_sizes=[1, 2, 4, 8] means:
|
||||
@@ -488,29 +508,39 @@ def test_compile_sizes_padding_validation():
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8],
|
||||
max_cudagraph_capture_size=8,
|
||||
compile_sizes=[3],
|
||||
cudagraph_mode=CUDAGraphMode.FULL,
|
||||
)
|
||||
config.post_init_cudagraph_sizes()
|
||||
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
|
||||
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)
|
||||
|
||||
with pytest.raises(ValueError, match="would be padded to"):
|
||||
config = CompilationConfig(
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8],
|
||||
max_cudagraph_capture_size=8,
|
||||
compile_sizes=[5],
|
||||
cudagraph_mode=CUDAGraphMode.FULL,
|
||||
)
|
||||
config.post_init_cudagraph_sizes()
|
||||
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
|
||||
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)
|
||||
|
||||
config = CompilationConfig(
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8],
|
||||
max_cudagraph_capture_size=8,
|
||||
compile_sizes=[1, 2, 4, 8],
|
||||
cudagraph_mode=CUDAGraphMode.FULL,
|
||||
)
|
||||
config.post_init_cudagraph_sizes()
|
||||
assert sorted(config.compile_sizes) == [1, 2, 4, 8]
|
||||
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
|
||||
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) # Should not raise
|
||||
|
||||
config = CompilationConfig(
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8],
|
||||
max_cudagraph_capture_size=8,
|
||||
compile_sizes=["cudagraph_capture_sizes"],
|
||||
cudagraph_mode=CUDAGraphMode.FULL,
|
||||
)
|
||||
config.post_init_cudagraph_sizes()
|
||||
assert sorted(config.compile_sizes) == [1, 2, 4, 8]
|
||||
@@ -535,3 +565,5 @@ def test_compile_sizes_padding_validation():
|
||||
)
|
||||
config.post_init_cudagraph_sizes()
|
||||
assert sorted(config.compile_sizes) == [3, 5, 7]
|
||||
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
|
||||
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
|
||||
|
||||
Reference in New Issue
Block a user