fix(security): Add VLLM_MAX_N_SEQUENCES environment variable and enforce limit (#37952)

Signed-off-by: jperezde <jperezde@redhat.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Juan Pérez de Algaba
2026-03-27 14:02:10 +01:00
committed by GitHub
parent 497e234d38
commit b111f8a61f
5 changed files with 193 additions and 0 deletions

View File

@@ -454,3 +454,55 @@ class TestVllmConfigureLogging:
with pytest.raises(ValueError, match="invalid literal for int"):
_ = envs.VLLM_CONFIGURE_LOGGING
class TestVllmMaxNSequences:
def test_default_value(self):
"""Test that VLLM_MAX_N_SEQUENCES defaults to 64."""
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("VLLM_MAX_N_SEQUENCES", None)
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
assert envs.VLLM_MAX_N_SEQUENCES == 16384
def test_custom_value(self, monkeypatch: pytest.MonkeyPatch):
"""Test that VLLM_MAX_N_SEQUENCES can be overridden."""
monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
assert envs.VLLM_MAX_N_SEQUENCES == 128
def test_sampling_params_respects_limit(
self,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that SamplingParams rejects n above the limit."""
from vllm.sampling_params import SamplingParams
monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
max_n = envs.VLLM_MAX_N_SEQUENCES
SamplingParams(n=max_n)
with pytest.raises(ValueError, match="n must be at most"):
SamplingParams(n=max_n + 1)
def test_sampling_params_respects_custom_limit(
self,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that SamplingParams uses the overridden env var limit."""
from vllm.sampling_params import SamplingParams
monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
SamplingParams(n=128)
with pytest.raises(ValueError, match="n must be at most 128"):
SamplingParams(n=129)