Remove hardcoded device="cuda" to support more devices (#2503)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2024-02-02 07:46:39 +08:00
committed by GitHub
parent c410f5d020
commit 96b6f475dd
32 changed files with 343 additions and 292 deletions

View File

@@ -27,7 +27,9 @@ BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def ref_masked_attention(
@@ -91,7 +93,7 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device=query.device).int()
position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@@ -110,7 +112,7 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
kv_cache_factory,
version: str,
@@ -122,33 +124,28 @@ def test_paged_attention(
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device=gpu_id)
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device=gpu_id)
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
context_lens = torch.tensor(context_lens, dtype=torch.int)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -159,13 +156,13 @@ def test_paged_attention(
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
block_tables = torch.tensor(block_tables, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
gpu_id)
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Call the paged attention kernel.
@@ -193,12 +190,10 @@ def test_paged_attention(
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
@@ -229,14 +224,14 @@ def test_paged_attention(
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=gpu_id)
device=device)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=gpu_id)
device=device)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
@@ -283,7 +278,7 @@ def ref_multi_query_kv_attention(
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device=query.device)
attn_mask = attn_mask.to(dtype=dtype)
ref_output = ref_masked_attention(
query[start_idx:end_idx],
@@ -303,7 +298,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
@@ -311,12 +306,13 @@ def test_multi_query_kv_attention(
head_size: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
@@ -329,8 +325,7 @@ def test_multi_query_kv_attention(
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype,
device=gpu_id)
dtype=dtype)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)