[Hardware] using current_platform.seed_everything (#9785)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.utils import seed_everything
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def selective_state_update_ref(state,
|
||||
@@ -235,7 +235,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
current_platform.seed_everything(0)
|
||||
batch_size = 1
|
||||
dim = 4
|
||||
dstate = 8
|
||||
@@ -358,7 +358,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
current_platform.seed_everything(0)
|
||||
batch_size = 1
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
|
||||
Reference in New Issue
Block a user