Vectorize paged KV cache read/write, kill container
This commit is contained in:
@@ -69,24 +69,36 @@ def paged_kv_write(kv_data, slot_mapping, cache, block_size):
|
||||
else:
|
||||
kv_to_write = kv_data
|
||||
|
||||
for t in range(kv_data.shape[0]):
|
||||
slot = slot_mapping[t].item()
|
||||
block_idx = slot // block_size
|
||||
offset = slot % block_size
|
||||
if block_idx < cache.shape[0] and offset < cache.shape[1]:
|
||||
cache[block_idx, offset] = kv_to_write[t]
|
||||
# Vectorized write using advanced indexing
|
||||
block_indices = slot_mapping // block_size
|
||||
offsets = slot_mapping % block_size
|
||||
# Check that every slot maps inside the cache bounds (safety; nothing is clamped)
|
||||
valid = (block_indices < cache.shape[0]) & (offsets < cache.shape[1])
|
||||
if valid.all():
|
||||
cache[block_indices, offsets] = kv_to_write
|
||||
else:
|
||||
# Fall back to per-token for partial writes
|
||||
for t in range(kv_data.shape[0]):
|
||||
bi = block_indices[t].item()
|
||||
oi = offsets[t].item()
|
||||
if bi < cache.shape[0] and oi < cache.shape[1]:
|
||||
cache[bi, oi] = kv_to_write[t]
|
||||
|
||||
|
||||
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
|
||||
"""Read KV from paged cache. Returns fp8 or uint8."""
|
||||
"""Read KV from paged cache. Returns fp8 or uint8.
|
||||
|
||||
Vectorized version — uses advanced indexing instead of a Python for loop.
|
||||
"""
|
||||
device = cache.device
|
||||
kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=device)
|
||||
for t in range(num_tokens):
|
||||
slot = slot_mapping[t].item()
|
||||
block_idx = slot // block_size
|
||||
offset = slot % block_size
|
||||
if block_idx < cache.shape[0] and offset < cache.shape[1]:
|
||||
kv[t] = cache[block_idx, offset]
|
||||
# Compute block indices and offsets
|
||||
slots = slot_mapping # (num_tokens,)
|
||||
block_indices = slots // block_size
|
||||
offsets = slots % block_size
|
||||
|
||||
# Advanced indexing: cache[block_indices, offsets] -> (num_tokens, head_dim)
|
||||
kv = cache[block_indices, offsets]
|
||||
|
||||
# If cache is uint8, reinterpret as fp8
|
||||
if cache.dtype == torch.uint8:
|
||||
kv = kv.view(torch.float8_e4m3fn)
|
||||
|
||||
Reference in New Issue
Block a user