[compile] raise on compile_size implicit padding (#32343)
Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com>
This commit is contained in:
@@ -470,3 +470,68 @@ def test_cached_compilation_config(default_vllm_config):
|
||||
|
||||
code = " ".join(code)
|
||||
assert "torch.ops._C.static_scaled_fp8_quant.default(" in code
|
||||
|
||||
|
||||
def test_compile_sizes_padding_validation():
    """Check how compile_sizes is validated against cudagraph padding.

    With capture sizes [1, 2, 4, 8], a size that sits between two capture
    sizes gets padded up to the next one (3 -> 4, 5 -> 8). Asking for such a
    size in compile_sizes is expected to raise a ValueError, while sizes that
    map onto themselves are accepted. The check is expected to be skipped
    whenever cudagraphs are disabled.
    """

    def finalize(compile_sizes, capture_sizes=(1, 2, 4, 8), **extra):
        # Every case gets its own config and its own capture-size list so
        # that nothing is shared between scenarios.
        cfg = CompilationConfig(
            cudagraph_capture_sizes=list(capture_sizes),
            max_cudagraph_capture_size=max(capture_sizes, default=0),
            compile_sizes=compile_sizes,
            **extra,
        )
        cfg.post_init_cudagraph_sizes()
        return cfg

    # 3 would be padded to 4 and 5 would be padded to 8, so both must fail.
    for padded_size in (3, 5):
        with pytest.raises(ValueError, match="would be padded to"):
            finalize([padded_size])

    # Sizes that coincide with the capture sizes are never padded.
    exact = finalize([1, 2, 4, 8])
    assert sorted(exact.compile_sizes) == [1, 2, 4, 8]

    # The symbolic entry resolves to the capture sizes themselves.
    symbolic = finalize(["cudagraph_capture_sizes"])
    assert sorted(symbolic.compile_sizes) == [1, 2, 4, 8]

    # Cudagraphs disabled via max_cudagraph_capture_size=0: no padding check,
    # so sizes that would otherwise be rejected pass through.
    disabled = finalize([3, 5, 7], capture_sizes=())
    assert sorted(disabled.compile_sizes) == [3, 5, 7]

    # cudagraph_mode=NONE with non-empty capture sizes: the padding check
    # must still be skipped.
    mode_none = finalize([3, 5, 7], cudagraph_mode=CUDAGraphMode.NONE)
    assert sorted(mode_none.compile_sizes) == [3, 5, 7]
|
||||
|
||||
Reference in New Issue
Block a user