[Platform] Deprecate seed_everything (#31659)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn,
|
||||
selective_state_update,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
|
||||
def selective_state_update_ref(
|
||||
@@ -271,7 +271,7 @@ def test_selective_scan(
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
set_random_seed(0)
|
||||
batch_size = 1
|
||||
dim = 4
|
||||
dstate = 8
|
||||
@@ -401,7 +401,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
set_random_seed(0)
|
||||
batch_size = 1
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
@@ -438,7 +438,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
current_platform.seed_everything(0)
|
||||
set_random_seed(0)
|
||||
batch_size = 4
|
||||
token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
|
||||
total_tokens = int(token_counts.sum().item())
|
||||
@@ -857,7 +857,7 @@ def test_selective_state_update_with_num_accepted_tokens(
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
set_random_seed(0)
|
||||
batch_size = 4
|
||||
|
||||
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
|
||||
@@ -983,7 +983,7 @@ def test_selective_state_update_varlen_with_num_accepted(
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
set_random_seed(0)
|
||||
batch_size = 4
|
||||
|
||||
tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
|
||||
|
||||
Reference in New Issue
Block a user