[FIX] Support non-zero CUDA devices in custom kernels (#1959)

This commit is contained in:
Jee Li
2024-01-03 11:09:59 +08:00
committed by GitHub
parent 4934d49274
commit 77af974b40
12 changed files with 74 additions and 30 deletions

View File

@@ -14,6 +14,7 @@ BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -24,6 +25,7 @@ SEEDS = [0]
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_copy_blocks(
kv_cache_factory,
@@ -35,11 +37,12 @@ def test_copy_blocks(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
@@ -56,7 +59,7 @@ def test_copy_blocks(
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
num_layers, num_heads,
head_size, dtype, seed)
head_size, dtype, seed, gpu_id)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -88,6 +91,7 @@ def test_copy_blocks(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
kv_cache_factory,
@@ -98,28 +102,29 @@ def test_reshape_and_cache(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device="cuda")
device=gpu_id)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size, dtype,
seed)
seed, gpu_id)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.