[Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518)
This commit is contained in:
@@ -51,12 +51,12 @@ def test_contexted_kv_attention(
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
|
||||
num_tokens = sum(subquery_lens)
|
||||
num_tokens = sum(query_lens)
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-1e-3, 1e-3)
|
||||
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
@@ -75,15 +75,15 @@ def test_contexted_kv_attention(
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:BS * max_block_per_request].view(
|
||||
BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
@@ -92,7 +92,7 @@ def test_contexted_kv_attention(
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
for i in range(BS):
|
||||
for j in range(subquery_lens[i]):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
|
||||
j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
|
||||
@@ -178,7 +178,7 @@ def test_contexted_kv_attention(
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||
subquery_lens, seq_lens)
|
||||
query_lens, seq_lens)
|
||||
if sliding_window > 0:
|
||||
attn_bias = attn_bias.make_local_attention_from_bottomright(
|
||||
sliding_window)
|
||||
|
||||
Reference in New Issue
Block a user