Remove hardcoded device="cuda" to support more devices (#2503)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2024-02-02 07:46:39 +08:00
committed by GitHub
parent c410f5d020
commit 96b6f475dd
32 changed files with 343 additions and 292 deletions

View File

@@ -17,7 +17,9 @@ BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
@@ -29,7 +31,7 @@ KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
@@ -42,13 +44,14 @@ def test_copy_blocks(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
@@ -66,7 +69,7 @@ def test_copy_blocks(
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
num_layers, num_heads,
head_size, kv_cache_dtype,
dtype, seed, gpu_id)
dtype, seed, device)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -98,7 +101,7 @@ def test_copy_blocks(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
kv_cache_factory,
@@ -109,29 +112,25 @@ def test_reshape_and_cache(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device=gpu_id)
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size, dtype,
None, seed, gpu_id)
None, seed, device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.
@@ -166,7 +165,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_swap_blocks(
kv_cache_factory,
@@ -182,7 +181,8 @@ def test_swap_blocks(
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
src_device = f"{direction[0]}:{device}" if direction[
0] == "cuda" else direction[0]
dst_device = f"{direction[1]}:{device}" if direction[