Implement single_query_cached_kv_attention kernel (#3)
@@ -3,7 +3,8 @@ from typing import List, Optional
 import torch
 import torch.nn as nn

-from cacheflow import ops
+from cacheflow import attention_ops
+from cacheflow import cache_ops
 from cacheflow.models import InputMetadata


@@ -11,7 +12,7 @@ class OPTCacheFlowAttention(nn.Module):

     def __init__(self, scale: float) -> None:
         super().__init__()
-        self.scale = scale
+        self.scale = float(scale)

     def _masked_attention(
         self,
@@ -57,38 +58,21 @@ class OPTCacheFlowAttention(nn.Module):
         output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
         query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
         key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
-        num_heads = value_cache.shape[1]
-        head_size = value_cache.shape[3]
-        block_size = value_cache.shape[2]
-        block_tables = input_metadata.block_tables
-
-        # FIXME(woosuk): Replace the following with a custom op.
-        for i in range(input_metadata.num_generation_tokens):
-            q = query[i].unsqueeze(0)
-            block_table = block_tables[i]
-            context_len = int(input_metadata.context_lens[i])
-
-            keys = []
-            values = []
-            for j in range(context_len):
-                block_number = int(block_table[j // block_size])
-                block_offset = j % block_size
-
-                k = key_cache[block_number, :, :, block_offset, :]
-                k = k.reshape(num_heads, head_size)
-                keys.append(k)
-
-                v = value_cache[block_number, :, block_offset, :]
-                values.append(v)
-            keys = torch.stack(keys, dim=0)
-            values = torch.stack(values, dim=0)
-
-            out = self._masked_attention(q, keys, values)
-            out = out.view(num_heads, head_size)
-            output[i].copy_(out, non_blocking=True)
+        block_size = value_cache.shape[3]
+        attention_ops.single_query_cached_kv_attention(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            self.scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+        )

     def forward(
         self,
@@ -96,7 +80,7 @@ class OPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,          # [num_tokens, num_heads * head_size]
         value: torch.Tensor,        # [num_tokens, num_heads * head_size]
         key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, block_size, head_size]
+        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:              # [num_tokens, num_heads * head_size]
@@ -110,7 +94,7 @@ class OPTCacheFlowAttention(nn.Module):

         # Reshape the input tensors.
         num_heads = value_cache.shape[1]
-        head_size = value_cache.shape[3]
+        head_size = value_cache.shape[2]
         query = query.view(-1, num_heads, head_size)
         key = key.view(-1, num_heads, head_size)
         value = value.view(-1, num_heads, head_size)
@@ -125,7 +109,7 @@ class OPTCacheFlowAttention(nn.Module):
             cache_event.wait()

         # Reshape the keys and values and store them in the cache.
-        ops.reshape_and_cache(
+        cache_ops.reshape_and_cache(
             key, value, key_cache, value_cache, input_metadata.slot_mapping)

         if input_metadata.num_generation_tokens > 0:
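For reference, below is a rough usage sketch of the new kernel entry point with the tensor layouts documented in the shape comments above. Only the argument order and layouts come from the diff; the concrete sizes, dtype, device, the integer dtype for block_tables and context_lens, and the random test data are illustrative assumptions, not part of this commit.

import torch

from cacheflow import attention_ops

# Illustrative sizes (assumptions, not taken from this commit).
num_generation_tokens = 4
num_heads = 12
head_size = 64
block_size = 8
num_blocks = 128
x = 8                                   # vector width of the key-cache layout
max_context_len = 4 * block_size
max_num_blocks_per_seq = max_context_len // block_size

dtype = torch.half
device = torch.device('cuda')

# Tensor layouts follow the shape comments in the diff.
output = torch.empty(num_generation_tokens, num_heads, head_size, dtype=dtype, device=device)
query = torch.randn(num_generation_tokens, num_heads, head_size, dtype=dtype, device=device)
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x, dtype=dtype, device=device)
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size, dtype=dtype, device=device)

# One block table per generation token; context_lens holds how many cached
# tokens each query attends to.
block_tables = torch.randint(0, num_blocks, (num_generation_tokens, max_num_blocks_per_seq), dtype=torch.int, device=device)
context_lens = torch.randint(1, max_context_len + 1, (num_generation_tokens,), dtype=torch.int, device=device)

scale = float(head_size) ** -0.5
attention_ops.single_query_cached_kv_attention(
    output,
    query,
    key_cache,
    value_cache,
    scale,
    block_tables,
    context_lens,
    block_size,
    max_context_len,
)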