fix(security): Add VLLM_MAX_N_SEQUENCES environment variable and enforce limit (#37952)
Signed-off-by: jperezde <jperezde@redhat.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
committed by
GitHub
parent
497e234d38
commit
b111f8a61f
@@ -1020,3 +1020,114 @@ def test_chat_completion_request_n_parameter_various_values():
|
||||
assert sampling_params.n == n_value, (
|
||||
f"Expected n={n_value}, got n={sampling_params.n}"
|
||||
)
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_exceeds_default_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """Verify that a request with n above the default limit raises ValueError."""
    import vllm.envs as envs

    # Make sure the built-in default is in effect: remove any env override
    # and drop the cached lookup so the next read re-resolves the value.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    cache_clear = getattr(envs.__getattr__, "cache_clear", None)
    if cache_clear is not None:
        cache_clear()

    limit = envs.VLLM_MAX_N_SEQUENCES
    over_limit_request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=limit + 1,
        max_tokens=10,
    )

    # One past the limit must be rejected at sampling-param construction time.
    with pytest.raises(ValueError, match="n must be at most"):
        over_limit_request.to_sampling_params(
            max_tokens=10,
            default_sampling_params={},
        )
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_at_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """Verify that n equal to the configured limit is accepted."""
    import vllm.envs as envs

    # Ensure the default limit applies: no env override, no stale cached value.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    cache_clear = getattr(envs.__getattr__, "cache_clear", None)
    if cache_clear is not None:
        cache_clear()

    limit = envs.VLLM_MAX_N_SEQUENCES
    boundary_request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=limit,
        max_tokens=10,
    )

    # Exactly at the limit: conversion must succeed and preserve n.
    params = boundary_request.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )
    assert params.n == limit
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_custom_limit(
    monkeypatch: pytest.MonkeyPatch,
):
    """Verify that VLLM_MAX_N_SEQUENCES overrides the default limit on n."""
    import vllm.envs as envs

    # Install a custom limit of 128 and invalidate the cached env lookup.
    monkeypatch.setenv("VLLM_MAX_N_SEQUENCES", "128")
    cache_clear = getattr(envs.__getattr__, "cache_clear", None)
    if cache_clear is not None:
        cache_clear()

    # n == custom limit: accepted.
    at_limit_request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=128,
        max_tokens=10,
    )
    params = at_limit_request.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )
    assert params.n == 128

    # n == custom limit + 1: rejected, with the limit named in the message.
    request_over = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=129,
        max_tokens=10,
    )
    with pytest.raises(ValueError, match="n must be at most 128"):
        request_over.to_sampling_params(
            max_tokens=10,
            default_sampling_params={},
        )
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_massive_value(
    monkeypatch: pytest.MonkeyPatch,
):
    """Verify that astronomically large n values are rejected (CVE fix)."""
    import vllm.envs as envs

    # Fall back to the built-in default limit: clear the override and
    # the cached env lookup before reading.
    monkeypatch.delenv("VLLM_MAX_N_SEQUENCES", raising=False)
    cache_clear = getattr(envs.__getattr__, "cache_clear", None)
    if cache_clear is not None:
        cache_clear()

    huge_request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=100_000_000,
        max_tokens=1,
    )

    # A resource-exhaustion-sized n must never reach sampling.
    with pytest.raises(ValueError, match="n must be at most"):
        huge_request.to_sampling_params(
            max_tokens=1,
            default_sampling_params={},
        )
|
||||
|
||||
Reference in New Issue
Block a user