[Hardware] using current_platform.seed_everything (#9785)

Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
wangshuai09
2024-10-29 22:47:44 +08:00
committed by GitHub
parent 09500f7dde
commit 622b7ab955
27 changed files with 104 additions and 105 deletions

View File

@@ -7,7 +7,6 @@ import gc
import inspect
import ipaddress
import os
import random
import socket
import subprocess
import sys
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total
def seed_everything(seed: int) -> None:
"""
Set the seed of each random module.
Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
"""
random.seed(seed)
np.random.seed(seed)
if current_platform.is_cuda_alike():
torch.cuda.manual_seed_all(seed)
if current_platform.is_xpu():
torch.xpu.manual_seed_all(seed)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
seed_everything(seed)
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}"
)
seed_everything(seed)
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)