[CI/Build] Avoid CUDA initialization (#8534)
This commit is contained in:
@@ -5,6 +5,7 @@ from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.utils import seed_everything
|
||||
|
||||
|
||||
def selective_state_update_ref(state,
|
||||
@@ -186,7 +187,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
|
||||
rtolw = max(rtolw, rtol)
|
||||
atolw = max(atolw, atol)
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
seed_everything(0)
|
||||
batch_size = 2
|
||||
dim = 4
|
||||
dstate = 8
|
||||
@@ -287,7 +288,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
if torch.version.hip:
|
||||
atol *= 2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
seed_everything(0)
|
||||
batch_size = 1
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
|
||||
Reference in New Issue
Block a user