[Hardware] using current_platform.seed_everything (#9785)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from vllm import _custom_ops as ops # noqa: F401
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.utils import seed_everything
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def causal_conv1d_ref(
|
||||
@@ -70,7 +70,7 @@ def causal_conv1d_update_ref(x,
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the
|
||||
The conv_state will be updated by copying x to the
|
||||
conv_state starting at the index
|
||||
@cache_seqlens % state_len before performing the convolution.
|
||||
|
||||
@@ -161,7 +161,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
current_platform.seed_everything(0)
|
||||
x = torch.randn(batch, dim, seqlen, device=device,
|
||||
dtype=itype).contiguous()
|
||||
|
||||
@@ -223,7 +223,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
current_platform.seed_everything(0)
|
||||
batch = 2
|
||||
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
|
||||
x_ref = x.clone()
|
||||
@@ -270,7 +270,7 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
batch_size = 3
|
||||
padding = 5 if with_padding else 0
|
||||
@@ -343,7 +343,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
|
||||
if itype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
# set seed
|
||||
seed_everything(0)
|
||||
current_platform.seed_everything(0)
|
||||
seqlens = []
|
||||
batch_size = 4
|
||||
if seqlen < 10:
|
||||
|
||||
Reference in New Issue
Block a user