feat(attention): extract KV-cache update from FlashAttentionDiffKV ba… (#36466)
Signed-off-by: Prathmesh Bhatt <71340361+Prathmesh234@users.noreply.github.com>
This commit is contained in:
@@ -85,6 +85,40 @@ class FlashAttentionDiffKVBackend(FlashAttentionBackend):
|
||||
|
||||
|
||||
class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
||||
def do_kv_cache_update(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||
# For encoder attention,
|
||||
# we use direct Q, K, V tensors without caching
|
||||
return
|
||||
|
||||
# Unlike standard FlashAttn which splits kv_cache via unbind(0),
|
||||
# DiffKV packs K and V into a single tensor along the last dim:
|
||||
# kv_cache shape: [num_blocks, block_size, num_kv_heads,
|
||||
# head_size_k + head_size_v]
|
||||
# The triton kernel handles this combined layout directly.
|
||||
#
|
||||
# NOTE(woosuk): key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
triton_reshape_and_cache_flash_diffkv(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -157,33 +191,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
||||
key_cache = kv_cache[..., : self.head_size]
|
||||
value_cache = kv_cache[..., self.head_size :]
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
|
||||
# kv_cache update for different head_size K and V
|
||||
triton_reshape_and_cache_flash_diffkv(
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
# queries are quantized in the attention layer
|
||||
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
|
||||
|
||||
Reference in New Issue
Block a user