Implement single_query_cached_kv_attention kernel (#3)

This commit is contained in:
Woosuk Kwon
2023-03-01 15:02:19 -08:00
committed by GitHub
parent cbf8779afa
commit 0deacbce6e
12 changed files with 2140 additions and 60 deletions

View File

@@ -1,7 +1,7 @@
from typing import Dict, List, Tuple
import torch
from cacheflow import ops
from cacheflow import cache_ops
KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -57,20 +57,22 @@ class CacheEngine:
def get_value_block_shape(self) -> Tuple[int, int, int]:
return (
self.num_heads,
self.block_size,
self.head_size,
self.block_size,
)
def allocate_gpu_cache(self) -> List[KVCache]:
gpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_key_block_shape()),
size=(self.num_gpu_blocks, *key_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
value_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_value_block_shape()),
size=(self.num_gpu_blocks, *value_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
@@ -79,14 +81,16 @@ class CacheEngine:
def allocate_cpu_cache(self) -> List[KVCache]:
cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_key_block_shape()),
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=True,
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_value_block_shape()),
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=True,
)
@@ -104,10 +108,10 @@ class CacheEngine:
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_value_cache, dst_value_cache, src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)

View File

@@ -118,7 +118,7 @@ class Worker:
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=int, device=self.device)
padded_block_tables, dtype=torch.int, device=self.device)
input_metadata = InputMetadata(
seq_ids=prompt_seq_ids + generation_seq_ids,