[Hardware] using current_platform.seed_everything (#9785)

Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
wangshuai09
2024-10-29 22:47:44 +08:00
committed by GitHub
parent 09500f7dde
commit 622b7ab955
27 changed files with 104 additions and 105 deletions

View File

@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
from vllm.platforms import current_platform
def selective_state_update_ref(state,
@@ -235,7 +235,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
seed_everything(0)
current_platform.seed_everything(0)
batch_size = 1
dim = 4
dstate = 8
@@ -358,7 +358,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
if torch.version.hip:
atol *= 2
# set seed
seed_everything(0)
current_platform.seed_everything(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)