[Hardware] using current_platform.seed_everything (#9785)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
@@ -7,7 +7,6 @@ import gc
|
||||
import inspect
|
||||
import ipaddress
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
|
||||
return psutil.virtual_memory().total
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
"""
|
||||
Set the seed of each random module.
|
||||
|
||||
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
if current_platform.is_xpu():
|
||||
torch.xpu.manual_seed_all(seed)
|
||||
|
||||
|
||||
def random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
|
||||
seed: int = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
seed_everything(seed)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
|
||||
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
|
||||
f"Does not support key cache of type fp8 with head_size {head_size}"
|
||||
)
|
||||
|
||||
seed_everything(seed)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user