[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)

This commit is contained in:
SangBin Cho
2024-05-04 02:20:12 +09:00
committed by GitHub
parent 2d7bce9cd5
commit 3521ba4f25
27 changed files with 554 additions and 525 deletions

View File

@@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> None:
@@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention(
num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist()
context_lens = context_lens.cpu().tolist()
seq_lens = seq_lens.cpu().tolist()
for i in range(num_seqs):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
context_len = int(context_lens[i])
seq_len = int(seq_lens[i])
keys = []
values = []
for j in range(context_len):
for j in range(seq_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size
@@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float()
position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@@ -149,13 +149,13 @@ def test_paged_attention(
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int)
seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
@@ -186,16 +186,15 @@ def test_paged_attention(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
@@ -218,9 +217,9 @@ def test_paged_attention(
num_kv_heads,
scale,
block_tables,
context_lens,
seq_lens,
block_size,
max_context_len,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@@ -255,7 +254,7 @@ def test_paged_attention(
key_cache,
value_cache,
block_tables,
context_lens,
seq_lens,
scale,
alibi_slopes,
)

View File

@@ -51,12 +51,12 @@ def test_contexted_kv_attention(
cache_size = 640
block_size = 32
max_block_per_request = 64
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(subquery_lens)
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
@@ -75,15 +75,15 @@ def test_contexted_kv_attention(
num_kv_heads,
head_size,
dtype=dtype)
k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
@@ -92,7 +92,7 @@ def test_contexted_kv_attention(
dtype=torch.long),
dim=0)
for i in range(BS):
for j in range(subquery_lens[i]):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
@@ -178,7 +178,7 @@ def test_contexted_kv_attention(
value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
subquery_lens, seq_lens)
query_lens, seq_lens)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(
sliding_window)