Fix attention: cast block-table entries to int and prune invalid tokens

This commit is contained in:
Woosuk Kwon
2023-02-23 23:02:25 +00:00
parent ba84b8728a
commit 932844f1cd
2 changed files with 21 additions and 6 deletions

View File

@@ -53,20 +53,19 @@ class OPTCacheFlowAttention(nn.Module):
context_len = int(input_metadata.context_lens[i])
keys = []
values = []
for j in range(context_len):
block_number = block_table[j // block_size]
block_number = int(block_table[j // block_size])
block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size)
keys.append(k)
keys = torch.stack(keys, dim=0)
values = []
for j in range(context_len):
block_number = block_table[j // block_size]
block_offset = j % block_size
v = value_cache[block_number, :, block_offset, :]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)
q = q.unsqueeze(0)
@@ -87,6 +86,11 @@ class OPTCacheFlowAttention(nn.Module):
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Prune out invalid tokens.
query = query[:input_metadata.num_valid_tokens]
key = key[:input_metadata.num_valid_tokens]
value = value[:input_metadata.num_valid_tokens]
# Reshape the input tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]