feat(attention): extract KV-cache update from FlashAttentionDiffKV backend (#36466)

Signed-off-by: Prathmesh Bhatt <71340361+Prathmesh234@users.noreply.github.com>
This commit is contained in:
Prathmesh Bhatt
2026-03-30 16:16:09 -07:00
committed by GitHub
parent e812bf70bd
commit 93b3ec1585

View File

@@ -85,6 +85,40 @@ class FlashAttentionDiffKVBackend(FlashAttentionBackend):
class FlashAttentionDiffKVImpl(FlashAttentionImpl):
def do_kv_cache_update(
    self,
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write ``key``/``value`` into the packed DiffKV cache.

    Encoder-style attention (ENCODER / ENCODER_ONLY) consumes Q, K, V
    directly and never reads the KV cache, so this is a no-op there.

    Unlike the standard FlashAttention backend, which splits ``kv_cache``
    via ``unbind(0)``, DiffKV packs K and V into a single tensor along
    the last dimension::

        kv_cache: [num_blocks, block_size, num_kv_heads,
                   head_size_k + head_size_v]

    The triton kernel handles this combined layout directly.

    NOTE: ``key`` and ``value`` are padded while ``slot_mapping`` is not.
    No explicit ``[:num_actual_tokens]`` slicing is needed because the
    reshape-and-cache op derives the actual token count from
    ``slot_mapping``'s shape.
    """
    uncached_types = (AttentionType.ENCODER_ONLY, AttentionType.ENCODER)
    if self.attn_type in uncached_types:
        # Pure encoder attention: nothing is ever cached.
        return

    triton_reshape_and_cache_flash_diffkv(
        key,
        value,
        kv_cache,
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
def forward(
self,
layer: torch.nn.Module,
@@ -157,33 +191,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
# kv_cache update for different head_size K and V
triton_reshape_and_cache_flash_diffkv(
key,
value,
kv_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(