Enable GQA support in the prefix prefill kernels (#3007)
Signed-off-by: Tao He <sighingnow@gmail.com>
@@ -8,7 +8,8 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
 
-NUM_HEADS = [12]
+NUM_HEADS = [64]
+NUM_QUERIES_PER_KV = [1, 8, 64]
 HEAD_SIZES = [128]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
@@ -17,12 +18,14 @@ CUDA_DEVICES = [
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
+    num_queries_per_kv: int,
     head_size: int,
     dtype: torch.dtype,
     device: str,
@@ -41,28 +44,29 @@ def test_contexted_kv_attention(
     subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
     ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
     seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
+    num_kv_heads = num_heads // num_queries_per_kv
 
     num_tokens = sum(subquery_lens)
     query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
     query.uniform_(-1e-3, 1e-3)
     output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
 
-    kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype)
+    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
     kv.uniform_(-1e-3, 1e-3)
     key, value = kv.unbind(dim=1)
 
     k_cache = torch.zeros(cache_size,
                           block_size,
-                          num_heads,
+                          num_kv_heads,
                           head_size,
                           dtype=dtype)
     v_cache = torch.zeros(cache_size,
                           block_size,
-                          num_heads,
+                          num_kv_heads,
                           head_size,
                           dtype=dtype)
-    k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
-    v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
+    k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
+    v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
     values = torch.arange(0, cache_size, dtype=torch.long)
     values = values[torch.randperm(cache_size)]
     block_table = values[:BS * max_block_per_request].view(
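
The new num_kv_heads derivation is what drives the rest of the change: with NUM_HEADS = [64], the three values of num_queries_per_kv exercise plain multi-head attention (1), grouped-query attention (8), and multi-query attention (64). A minimal sketch of that bookkeeping, including the query-head-to-KV-head mapping the prefix prefill kernel is expected to use (the mapping is stated here as an assumption for illustration, not quoted from the kernel):

num_heads = 64  # NUM_HEADS in the test

# num_queries_per_kv = 1 -> MHA, 8 -> GQA, 64 -> MQA
for num_queries_per_kv in (1, 8, 64):
    num_kv_heads = num_heads // num_queries_per_kv
    assert num_kv_heads * num_queries_per_kv == num_heads
    # Query head h is assumed to share the KV head h // num_queries_per_kv.
    kv_head_of_query_head_42 = 42 // num_queries_per_kv
    print(num_queries_per_kv, num_kv_heads, kv_head_of_query_head_42)
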
@@ -93,19 +97,21 @@ def test_contexted_kv_attention(
             end_loc = start_loc + block_size
             start_slot = block_table[i, block_id] * block_size
             end_slot = start_slot + end_loc - start_loc
-            k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
-                key[start_loc:end_loc])
-            v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(
-                value[start_loc:end_loc])
+            k_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             key[start_loc:end_loc])
+            v_cache.view(-1, num_kv_heads,
+                         head_size)[start_slot:end_slot].copy_(
+                             value[start_loc:end_loc])
             cur_ctx += block_size
             block_id += 1
     # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
     # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
-    k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
+    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                            8).permute(0, 2, 3, 1, 4).contiguous()
     # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
     # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
-    v_cache = v_cache.view(-1, block_size, num_heads,
+    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                            head_size).permute(0, 2, 3, 1).contiguous()
 
     # Warm up the Triton kernel by calling it once before actually measuring generation time
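
This hunk only swaps num_heads for num_kv_heads; the paged-cache layout described in the comments is unchanged. As a sanity check, here is a standalone sketch (illustrative shapes, not taken from the test) of where an element lands after the K-cache view + permute:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 128

k_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)

# [num_blocks, block_size, num_kv_heads, head_size]
#   -> [num_blocks, num_kv_heads, head_size // 8, block_size, 8]
k_packed = k_cache.view(num_blocks, block_size, num_kv_heads, head_size // 8,
                        8).permute(0, 2, 3, 1, 4).contiguous()

# Element (block b, slot s, kv head h, dim d) moves to
# k_packed[b, h, d // 8, s, d % 8].
b, s, h, d = 1, 3, 1, 42
assert torch.equal(k_cache[b, s, h, d], k_packed[b, h, d // 8, s, d % 8])
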
@@ -123,12 +129,29 @@ def test_contexted_kv_attention(
 
     attn_op = xops.fmha.cutlass.FwOp()
 
+    if num_kv_heads != num_heads:
+        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+        # project the key and value tensors to the desired number of
+        # heads.
+        #
+        # see also: vllm/model_executor/layers/attention.py
+        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
+                           query.shape[-1])
+        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
+                                        num_queries_per_kv, key.shape[-1])
+        value = value[:, :,
+                      None, :].expand(value.shape[0], num_kv_heads,
+                                      num_queries_per_kv, value.shape[-1])
+    query = query.unsqueeze(0)
+    key = key.unsqueeze(0)
+    value = value.unsqueeze(0)
+
     attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
         subquery_lens, seq_lens)
     output_ref = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
+        query,
+        key,
+        value,
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
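
The new branch mirrors the workaround referenced from vllm/model_executor/layers/attention.py: the xformers reference path computes plain MHA, so each KV head is broadcast across its query group and the query is regrouped into a [tokens, kv_heads, queries_per_kv, head_size] layout before the batch dimension is added. A self-contained sketch of just that expansion, under assumed shapes (no xformers call; all names are illustrative):

import torch

num_tokens, num_heads, num_queries_per_kv, head_size = 10, 64, 8, 128
num_kv_heads = num_heads // num_queries_per_kv

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Group query heads by the KV head they attend with, and broadcast each
# KV head across its group; expand creates a view, so no data is copied.
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)

assert query.shape == key.shape == value.shape
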
@@ -137,9 +160,9 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     start_time = time.time()
     output_ref = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
+        query,
+        key,
+        value,
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
@@ -148,5 +171,5 @@ def test_contexted_kv_attention(
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
-    output_ref = output_ref.squeeze(0)
+    output_ref = output_ref.reshape(output.shape)
     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
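
For the grouped (GQA/MQA) cases the xformers output keeps the [1, tokens, kv_heads, queries_per_kv, head_size] grouping, so it must be folded back to the Triton kernel's [tokens, num_heads, head_size] layout before the allclose check. A short sketch of that fold, assuming (as the test's query.view implies) that query head h corresponds to (h // num_queries_per_kv, h % num_queries_per_kv) in the grouped layout (illustrative shapes):

import torch

num_tokens, num_kv_heads, num_queries_per_kv, head_size = 10, 8, 8, 128
num_heads = num_kv_heads * num_queries_per_kv

# Stand-in for the grouped 5-D reference output.
output_ref = torch.randn(1, num_tokens, num_kv_heads, num_queries_per_kv,
                         head_size)

# Collapse back to the kernel's layout; query head h maps to
# (kv_head, offset) = divmod(h, num_queries_per_kv).
flat = output_ref.reshape(num_tokens, num_heads, head_size)
h = 42
kv_head, offset = divmod(h, num_queries_per_kv)
assert torch.equal(flat[:, h], output_ref[0, :, kv_head, offset])
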