[Minor] More fix of test_cache.py CI test failure (#2750)
This commit is contained in:
@@ -181,16 +181,15 @@ def test_swap_blocks(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
src_device = f"{direction[0]}:{device}" if direction[
|
||||
0] == "cuda" else direction[0]
|
||||
dst_device = f"{direction[1]}:{device}" if direction[
|
||||
1] == "cuda" else direction[1]
|
||||
|
||||
src_device = device if direction[0] == "cuda" else 'cpu'
|
||||
dst_device = device if direction[1] == "cuda" else 'cpu'
|
||||
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
# For the same device, mapping must not overlap
|
||||
|
||||
Reference in New Issue
Block a user