Implement single_query_cached_kv_attention kernel (#3)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from cacheflow import ops
|
||||
from cacheflow import cache_ops
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@@ -57,20 +57,22 @@ class CacheEngine:
|
||||
def get_value_block_shape(self) -> Tuple[int, int, int]:
|
||||
return (
|
||||
self.num_heads,
|
||||
self.block_size,
|
||||
self.head_size,
|
||||
self.block_size,
|
||||
)
|
||||
|
||||
def allocate_gpu_cache(self) -> List[KVCache]:
|
||||
gpu_cache: List[KVCache] = []
|
||||
key_block_shape = self.get_key_block_shape()
|
||||
value_block_shape = self.get_value_block_shape()
|
||||
for _ in range(self.num_layers):
|
||||
key_blocks = torch.empty(
|
||||
size=(self.num_gpu_blocks, *self.get_key_block_shape()),
|
||||
size=(self.num_gpu_blocks, *key_block_shape),
|
||||
dtype=self.dtype,
|
||||
device=self.gpu_id,
|
||||
)
|
||||
value_blocks = torch.empty(
|
||||
size=(self.num_gpu_blocks, *self.get_value_block_shape()),
|
||||
size=(self.num_gpu_blocks, *value_block_shape),
|
||||
dtype=self.dtype,
|
||||
device=self.gpu_id,
|
||||
)
|
||||
@@ -79,14 +81,16 @@ class CacheEngine:
|
||||
|
||||
def allocate_cpu_cache(self) -> List[KVCache]:
|
||||
cpu_cache: List[KVCache] = []
|
||||
key_block_shape = self.get_key_block_shape()
|
||||
value_block_shape = self.get_value_block_shape()
|
||||
for _ in range(self.num_layers):
|
||||
key_blocks = torch.empty(
|
||||
size=(self.num_cpu_blocks, *self.get_key_block_shape()),
|
||||
size=(self.num_cpu_blocks, *key_block_shape),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
)
|
||||
value_blocks = torch.empty(
|
||||
size=(self.num_cpu_blocks, *self.get_value_block_shape()),
|
||||
size=(self.num_cpu_blocks, *value_block_shape),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
)
|
||||
@@ -104,10 +108,10 @@ class CacheEngine:
|
||||
src_key_cache, src_value_cache = src[i]
|
||||
dst_key_cache, dst_value_cache = dst[i]
|
||||
# Copy the key blocks.
|
||||
ops.copy_cache_blocks(
|
||||
cache_ops.copy_cache_blocks(
|
||||
src_key_cache, dst_key_cache, src_to_dst)
|
||||
# Copy the value blocks.
|
||||
ops.copy_cache_blocks(
|
||||
cache_ops.copy_cache_blocks(
|
||||
src_value_cache, dst_value_cache, src_to_dst)
|
||||
event = self.events[i]
|
||||
event.record(stream=self.cache_stream)
|
||||
|
||||
@@ -118,7 +118,7 @@ class Worker:
|
||||
_pad_to_max(block_table, max_num_blocks_per_seq)
|
||||
for block_table in generation_block_tables]
|
||||
block_tables_tensor = torch.tensor(
|
||||
padded_block_tables, dtype=int, device=self.device)
|
||||
padded_block_tables, dtype=torch.int, device=self.device)
|
||||
|
||||
input_metadata = InputMetadata(
|
||||
seq_ids=prompt_seq_ids + generation_seq_ids,
|
||||
|
||||
Reference in New Issue
Block a user