[compile] raise on compile_size implicit padding (#32343)
Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com>
This commit is contained in:
@@ -470,3 +470,68 @@ def test_cached_compilation_config(default_vllm_config):
|
||||
|
||||
code = " ".join(code)
|
||||
assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
|
||||
|
||||
|
||||
def test_compile_sizes_padding_validation():
    """Check validation of compile_sizes against cudagraph padding.

    A compile size that cudagraph padding would change must raise a
    ValueError; sizes that are left unchanged are accepted, and the check
    is skipped entirely when cudagraphs are not in use.
    """
    capture_sizes = (1, 2, 4, 8)

    # With cudagraph_capture_sizes=[1, 2, 4, 8] the padding is:
    #   1 -> 1, 2 -> 2, 3 -> 4, 4 -> 4, 5 -> 8, etc.
    # so compile sizes 3 and 5 would be padded (to 4 and 8) and must be
    # rejected.
    for padded_size in (3, 5):
        with pytest.raises(ValueError, match="would be padded to"):
            config = CompilationConfig(
                cudagraph_capture_sizes=list(capture_sizes),
                max_cudagraph_capture_size=8,
                compile_sizes=[padded_size],
            )
            config.post_init_cudagraph_sizes()

    # Sizes that coincide with the capture sizes are accepted, whether they
    # are spelled out or requested via the "cudagraph_capture_sizes" entry.
    for accepted_sizes in ([1, 2, 4, 8], ["cudagraph_capture_sizes"]):
        config = CompilationConfig(
            cudagraph_capture_sizes=list(capture_sizes),
            max_cudagraph_capture_size=8,
            compile_sizes=accepted_sizes,
        )
        config.post_init_cudagraph_sizes()
        assert sorted(config.compile_sizes) == [1, 2, 4, 8]

    # When cudagraphs are disabled (max_cudagraph_capture_size=0),
    # padding validation should be skipped, so sizes that would otherwise
    # be invalid pass through untouched.
    config = CompilationConfig(
        cudagraph_capture_sizes=[],
        max_cudagraph_capture_size=0,
        compile_sizes=[3, 5, 7],
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [3, 5, 7]

    # When cudagraph_mode is NONE but capture sizes are non-empty,
    # padding validation should still be skipped.
    config = CompilationConfig(
        cudagraph_capture_sizes=list(capture_sizes),
        max_cudagraph_capture_size=8,
        cudagraph_mode=CUDAGraphMode.NONE,
        compile_sizes=[3, 5, 7],
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [3, 5, 7]
|
||||
|
||||
@@ -909,6 +909,20 @@ class CompilationConfig:
|
||||
# May get recomputed in the model runner if adjustment is needed for spec-decode
|
||||
self.compute_bs_to_padded_graph_size()
|
||||
|
||||
# Validate that compile_sizes won't be changed by padding.
|
||||
# Only validate when cudagraphs are actually being used.
|
||||
if self.compile_sizes and self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
for size in self.compile_sizes:
|
||||
if size <= self.max_cudagraph_capture_size:
|
||||
padded = self.bs_to_padded_graph_size[size]
|
||||
if padded != size:
|
||||
raise ValueError(
|
||||
f"compile_sizes contains {size} which would be "
|
||||
f"padded to {padded}. All compile_sizes must be "
|
||||
"values that won't be changed by cudagraph padding. "
|
||||
"Use values from cudagraph_capture_sizes."
|
||||
)
|
||||
|
||||
def set_splitting_ops_for_v1(
|
||||
self, all2all_backend: str, data_parallel_size: int = 1
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user