feat(attention): extract KV-cache update from FlexAttention backend (#36263)

Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
Author: cong-or
Date: 2026-03-09 03:40:12 +00:00
Committed by: GitHub
Parent: d62856b928
Commit: 747431044d


@@ -82,6 +82,8 @@ class FlexAttentionBackend(AttentionBackend):
     ]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "FLEX_ATTENTION"
@@ -827,6 +829,29 @@ class FlexAttentionImpl(AttentionImpl):
         assert tensor.ndim == 3
         return tensor[None, :, :, :]
 
+    def do_kv_cache_update(
+        self,
+        layer: torch.nn.Module,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> None:
+        if self.attn_type == AttentionType.ENCODER_ONLY:
+            return
+
+        key_cache, value_cache = kv_cache.unbind(0)
+        torch.ops._C_cache_ops.reshape_and_cache_flash(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping,
+            self.kv_cache_dtype,
+            layer._k_scale,
+            layer._v_scale,
+        )
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -908,17 +933,6 @@ class FlexAttentionImpl(AttentionImpl):
             assert self.attn_type == AttentionType.DECODER
             key_cache, value_cache = kv_cache.unbind(0)
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
             # View out the block_size dim
             key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
             value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
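
With the cache write removed from forward(), updating the KV cache and running attention become two separate calls on the impl. A rough usage sketch, assuming the caller already holds the layer module, the packed kv_cache tensor, and the attention metadata; the exact forward() argument list beyond what this diff shows is an assumption.

# Illustrative only: call the extracted do_kv_cache_update() first, then the
# (now cache-free) forward(); argument order for forward() is assumed.
impl.do_kv_cache_update(
    layer,
    key,
    value,
    kv_cache,
    attn_metadata.slot_mapping,
)
output = impl.forward(layer, query, key, value, kv_cache, attn_metadata)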