[Platform] Deprecate seed_everything (#31659)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.quantization.awq_triton import (
|
||||
awq_dequantize_triton,
|
||||
awq_gemm_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
device = "cuda"
|
||||
|
||||
@@ -86,7 +86,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
|
||||
zeros_cols = qweight_cols
|
||||
zeros_dtype = torch.int32
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
set_random_seed(0)
|
||||
|
||||
qweight = torch.randint(
|
||||
0,
|
||||
@@ -141,7 +141,7 @@ def test_gemm(N, K, M, splitK, group_size):
|
||||
qzeros_rows = scales_rows
|
||||
qzeros_cols = qweight_cols
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
set_random_seed(0)
|
||||
|
||||
input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device)
|
||||
qweight = torch.randint(
|
||||
|
||||
Reference in New Issue
Block a user