[Hardware] using current_platform.seed_everything (#9785)

Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
wangshuai09
2024-10-29 22:47:44 +08:00
committed by GitHub
parent 09500f7dde
commit 622b7ab955
27 changed files with 104 additions and 105 deletions

View File

@@ -9,7 +9,7 @@ from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
from vllm.platforms import current_platform
def causal_conv1d_ref(
@@ -70,7 +70,7 @@ def causal_conv1d_update_ref(x,
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
@@ -161,7 +161,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
current_platform.seed_everything(0)
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
@@ -223,7 +223,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
current_platform.seed_everything(0)
batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
@@ -270,7 +270,7 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
current_platform.seed_everything(0)
batch_size = 3
padding = 5 if with_padding else 0
@@ -343,7 +343,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
current_platform.seed_everything(0)
seqlens = []
batch_size = 4
if seqlen < 10: