fix(security): Add VLLM_MAX_N_SEQUENCES environment variable and enforce limit (#37952)
Signed-off-by: jperezde <jperezde@redhat.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
committed by
GitHub
parent
497e234d38
commit
b111f8a61f
@@ -454,3 +454,55 @@ class TestVllmConfigureLogging:
|
||||
|
||||
with pytest.raises(ValueError, match="invalid literal for int"):
|
||||
_ = envs.VLLM_CONFIGURE_LOGGING
|
||||
|
||||
|
||||
class TestVllmMaxNSequences:
    """Tests for the ``VLLM_MAX_N_SEQUENCES`` environment variable.

    Covers the value exposed through ``envs`` (default and override) and the
    upper bound that ``SamplingParams`` places on ``n`` based on that value.
    """

    @staticmethod
    def _reset_envs_cache() -> None:
        """Clear cached ``envs`` lookups, if any, so the next read is fresh.

        ``envs.__getattr__`` does not always carry a ``cache_clear``
        attribute, so it is probed rather than called unconditionally.
        """
        cache_clear = getattr(envs.__getattr__, "cache_clear", None)
        if cache_clear is not None:
            cache_clear()

    def test_default_value(self):
        """Test that VLLM_MAX_N_SEQUENCES defaults to 16384."""
        # patch.dict snapshots os.environ and restores it on exit, so the
        # pop below cannot leak into other tests.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("VLLM_MAX_N_SEQUENCES", None)
            self._reset_envs_cache()

            assert envs.VLLM_MAX_N_SEQUENCES == 16384

    def test_custom_value(self, monkeypatch: pytest.MonkeyPatch):
        """Test that VLLM_MAX_N_SEQUENCES can be overridden."""
        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
        self._reset_envs_cache()

        assert envs.VLLM_MAX_N_SEQUENCES == 128

    def test_sampling_params_respects_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Test that SamplingParams rejects n above the limit."""
        from vllm.sampling_params import SamplingParams

        # Make sure the limit in effect is the built-in default, not a value
        # inherited from the environment running the tests.
        monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
        self._reset_envs_cache()

        max_n = envs.VLLM_MAX_N_SEQUENCES
        # Exactly at the limit must be accepted (construction must not raise).
        SamplingParams(n=max_n)

        # One past the limit must be rejected.
        with pytest.raises(ValueError, match="n must be at most"):
            SamplingParams(n=max_n + 1)

    def test_sampling_params_respects_custom_limit(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Test that SamplingParams uses the overridden env var limit."""
        from vllm.sampling_params import SamplingParams

        monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
        self._reset_envs_cache()

        # Boundary value is accepted ...
        SamplingParams(n=128)

        # ... and the error message reports the overridden limit.
        with pytest.raises(ValueError, match="n must be at most 128"):
            SamplingParams(n=129)
|
||||
|
||||
Reference in New Issue
Block a user