[Tests] Standardize RNG seed utility across test files (#32982)
Signed-off-by: 7. Sun <jhao.sun@gmail.com>
This commit is contained in:
@@ -2,13 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for FlexAttention backend vs default backend"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from tests.utils import set_random_seed
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
@@ -27,15 +25,6 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
|
||||
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""Set seeds for reproducibility"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
@@ -57,7 +46,7 @@ def test_flex_attention_vs_default_backend(vllm_runner):
|
||||
]
|
||||
|
||||
# Run with flex attention
|
||||
set_seed(seed)
|
||||
set_random_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
@@ -71,7 +60,7 @@ def test_flex_attention_vs_default_backend(vllm_runner):
|
||||
)
|
||||
|
||||
# Run with default backend
|
||||
set_seed(seed)
|
||||
set_random_seed(seed)
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="generate",
|
||||
|
||||
Reference in New Issue
Block a user