feat(attention): extract KV-cache update from FlexAttention backend (#36263)
Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
@@ -82,6 +82,8 @@ class FlexAttentionBackend(AttentionBackend):
     ]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
 
+    forward_includes_kv_cache_update: bool = False
+
     @staticmethod
     def get_name() -> str:
         return "FLEX_ATTENTION"
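
Note on the new class attribute: `forward_includes_kv_cache_update = False` advertises that this backend's `forward` no longer writes K/V into the paged cache itself, so the caller must run the update as a separate step before attention. A minimal caller-side sketch; `backend_cls`, `impl`, and both argument lists are illustrative placeholders, not vLLM's actual runner code:

    # Hypothetical dispatch in the code that drives an attention layer.
    # Names and signatures here are assumptions for illustration only.
    if not backend_cls.forward_includes_kv_cache_update:
        # This backend expects the cache to be populated up front.
        impl.do_kv_cache_update(
            layer, key, value, kv_cache, attn_metadata.slot_mapping
        )
    output = impl.forward(layer, query, key, value, kv_cache, attn_metadata)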
@@ -827,6 +829,29 @@ class FlexAttentionImpl(AttentionImpl):
         assert tensor.ndim == 3
         return tensor[None, :, :, :]
 
+    def do_kv_cache_update(
+        self,
+        layer: torch.nn.Module,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> None:
+        if self.attn_type == AttentionType.ENCODER_ONLY:
+            return
+
+        key_cache, value_cache = kv_cache.unbind(0)
+        torch.ops._C_cache_ops.reshape_and_cache_flash(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            slot_mapping,
+            self.kv_cache_dtype,
+            layer._k_scale,
+            layer._v_scale,
+        )
+
     def forward(
         self,
         layer: torch.nn.Module,
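
The `kv_cache.unbind(0)` call above relies on K and V being stacked along dim 0 of a single cache tensor. A self-contained sketch of that layout; the concrete shapes are assumptions chosen to be consistent with the `view` calls later in this diff, not values taken from vLLM:

    import torch

    # Assumed layout: (2, num_blocks, block_size, num_kv_heads, head_size),
    # with the key cache at index 0 and the value cache at index 1.
    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64
    kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

    # unbind(0) returns views, not copies, so writes hit shared storage.
    key_cache, value_cache = kv_cache.unbind(0)
    key_cache[0, 0, 0, 0] = 1.0
    assert kv_cache[0, 0, 0, 0, 0] == 1.0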
@@ -908,17 +933,6 @@ class FlexAttentionImpl(AttentionImpl):
             assert self.attn_type == AttentionType.DECODER
             key_cache, value_cache = kv_cache.unbind(0)
 
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
             # View out the block_size dim
             key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
             value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
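
The surviving context keeps the "view out the block_size dim" step: the paged cache is flattened so each slot index addresses exactly one cached token. A small sketch of that reshape, under the same assumed shapes as the layout example above:

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

    # Collapse (num_blocks, block_size) into one flat token dimension, so
    # row i holds the (num_kv_heads, head_size) entry for slot index i.
    flat = key_cache.view(-1, num_kv_heads, head_size)
    assert flat.shape == (num_blocks * block_size, num_kv_heads, head_size)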