From 255913fba44b03c0d670fe7735ac99bade0e8dab Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 15:48:16 +0000 Subject: [PATCH] Vectorize paged KV cache read/write, kill container --- vllm/patches/layers/csa_attention.py | 40 ++++++++++++++++++---------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/vllm/patches/layers/csa_attention.py b/vllm/patches/layers/csa_attention.py index 1e6e4e2b..f46bba47 100644 --- a/vllm/patches/layers/csa_attention.py +++ b/vllm/patches/layers/csa_attention.py @@ -69,24 +69,36 @@ def paged_kv_write(kv_data, slot_mapping, cache, block_size): else: kv_to_write = kv_data - for t in range(kv_data.shape[0]): - slot = slot_mapping[t].item() - block_idx = slot // block_size - offset = slot % block_size - if block_idx < cache.shape[0] and offset < cache.shape[1]: - cache[block_idx, offset] = kv_to_write[t] + # Vectorized write using advanced indexing + block_indices = slot_mapping // block_size + offsets = slot_mapping % block_size + # Clamp to valid range (safety) + valid = (block_indices < cache.shape[0]) & (offsets < cache.shape[1]) + if valid.all(): + cache[block_indices, offsets] = kv_to_write + else: + # Fall back to per-token for partial writes + for t in range(kv_data.shape[0]): + bi = block_indices[t].item() + oi = offsets[t].item() + if bi < cache.shape[0] and oi < cache.shape[1]: + cache[bi, oi] = kv_to_write[t] def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim): - """Read KV from paged cache. Returns fp8 or uint8.""" + """Read KV from paged cache. Returns fp8 or uint8. + + Vectorized version — uses advanced indexing instead of Python for loop. + """ device = cache.device - kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=device) - for t in range(num_tokens): - slot = slot_mapping[t].item() - block_idx = slot // block_size - offset = slot % block_size - if block_idx < cache.shape[0] and offset < cache.shape[1]: - kv[t] = cache[block_idx, offset] + # Compute block indices and offsets + slots = slot_mapping # (num_tokens,) + block_indices = slots // block_size + offsets = slots % block_size + + # Advanced indexing: cache[block_indices, offsets] -> (num_tokens, head_dim) + kv = cache[block_indices, offsets] + # If cache is uint8, reinterpret as fp8 if cache.dtype == torch.uint8: kv = kv.view(torch.float8_e4m3fn)