[CI/Build] Avoid CUDA initialization (#8534)
This commit is contained in:
@@ -7,6 +7,7 @@ from einops import rearrange
|
||||
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.utils import seed_everything
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
@@ -104,7 +105,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
seed_everything(0)
|
||||
if not channel_last:
|
||||
x = torch.randn(batch,
|
||||
4096 + dim + 64,
|
||||
@@ -175,7 +176,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
seed_everything(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, device=device, dtype=itype)
|
||||
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
|
||||
|
||||
Reference in New Issue
Block a user