[Tests] Standardize RNG seed utility across test files (#32982)
Signed-off-by: 7. Sun <jhao.sun@gmail.com>
This commit is contained in:
@@ -2,13 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for FlexAttention backend vs default backend"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from tests.utils import set_random_seed
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
@@ -27,15 +25,6 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
|
||||
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""Set seeds for reproducibility"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
@@ -57,7 +46,7 @@ def test_flex_attention_vs_default_backend(vllm_runner):
|
||||
]
|
||||
|
||||
# Run with flex attention
|
||||
set_seed(seed)
|
||||
set_random_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
@@ -71,7 +60,7 @@ def test_flex_attention_vs_default_backend(vllm_runner):
|
||||
)
|
||||
|
||||
# Run with default backend
|
||||
set_seed(seed)
|
||||
set_random_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
|
||||
@@ -59,7 +59,10 @@ from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.mem_constants import GB_bytes
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.utils.torch_utils import (
|
||||
cuda_device_count_stateless,
|
||||
set_random_seed, # noqa: F401 - re-exported for use in test files
|
||||
)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from tests.utils import create_new_process_for_each_test, set_random_seed
|
||||
from tests.v1.logits_processors.utils import (
|
||||
DUMMY_LOGITPROC_ARG,
|
||||
DUMMY_LOGITPROC_FQCN,
|
||||
@@ -135,7 +134,7 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
|
||||
|
||||
# Test that logitproc info is passed to workers
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
|
||||
random.seed(40)
|
||||
set_random_seed(40)
|
||||
|
||||
# Choose LLM args based on logitproc source
|
||||
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE:
|
||||
@@ -194,7 +193,7 @@ def test_custom_logitsprocs_req(monkeypatch):
|
||||
|
||||
# Test that logitproc info is passed to workers
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
|
||||
random.seed(40)
|
||||
set_random_seed(40)
|
||||
_run_test(
|
||||
{"logits_processors": [WrappedPerReqLogitsProcessor]}, logitproc_loaded=True
|
||||
)
|
||||
@@ -237,7 +236,7 @@ def test_rejects_custom_logitsprocs(
|
||||
logitproc from
|
||||
"""
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
random.seed(40)
|
||||
set_random_seed(40)
|
||||
|
||||
test_params: dict[str, dict[str, Any]] = {
|
||||
"pooling": {
|
||||
|
||||
@@ -333,6 +333,8 @@ def set_random_seed(seed: int | None) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def create_kv_caches_with_random_flash(
|
||||
|
||||
Reference in New Issue
Block a user