Vectorize paged KV cache read/write, kill container

This commit is contained in:
2026-05-19 15:48:16 +00:00
parent 8b2cb41160
commit 255913fba4

View File

@@ -69,24 +69,36 @@ def paged_kv_write(kv_data, slot_mapping, cache, block_size):
else:
kv_to_write = kv_data
for t in range(kv_data.shape[0]):
slot = slot_mapping[t].item()
block_idx = slot // block_size
offset = slot % block_size
if block_idx < cache.shape[0] and offset < cache.shape[1]:
cache[block_idx, offset] = kv_to_write[t]
# Vectorized write using advanced indexing
block_indices = slot_mapping // block_size
offsets = slot_mapping % block_size
# Bounds check (no clamping): mark which slots fall inside the cache
valid = (block_indices < cache.shape[0]) & (offsets < cache.shape[1])
if valid.all():
cache[block_indices, offsets] = kv_to_write
else:
# Some slots are out of range: write token by token, skipping the invalid ones
for t in range(kv_data.shape[0]):
bi = block_indices[t].item()
oi = offsets[t].item()
if bi < cache.shape[0] and oi < cache.shape[1]:
cache[bi, oi] = kv_to_write[t]
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
"""Read KV from paged cache. Returns fp8 or uint8."""
"""Read KV from paged cache. Returns fp8 or uint8.
Vectorized version — uses advanced indexing instead of Python for loop.
"""
device = cache.device
kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=device)
for t in range(num_tokens):
slot = slot_mapping[t].item()
block_idx = slot // block_size
offset = slot % block_size
if block_idx < cache.shape[0] and offset < cache.shape[1]:
kv[t] = cache[block_idx, offset]
# Compute block indices and offsets
slots = slot_mapping # (num_tokens,)
block_indices = slots // block_size
offsets = slots % block_size
# Advanced indexing: cache[block_indices, offsets] -> (num_tokens, head_dim)
kv = cache[block_indices, offsets]
# If cache is uint8, reinterpret as fp8
if cache.dtype == torch.uint8:
kv = kv.view(torch.float8_e4m3fn)