[Hardware] using current_platform.seed_everything (#9785)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm._custom_ops import scaled_int8_quant
|
||||
from vllm.utils import seed_everything
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
|
||||
@@ -45,7 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
seed_everything(seed)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
|
||||
@@ -68,7 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
seed_everything(seed)
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
@@ -112,7 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int,
|
||||
scale: float) -> None:
|
||||
seed_everything(seed)
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
@@ -138,7 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int,
|
||||
scale: float, azp: int) -> None:
|
||||
seed_everything(seed)
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
|
||||
Reference in New Issue
Block a user