feat(attention): extract KV-cache update from FlexAttention backend (#36263)
Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
This commit is contained in:
@@ -82,6 +82,8 @@ class FlexAttentionBackend(AttentionBackend):
|
|||||||
]
|
]
|
||||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
|
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
|
||||||
|
|
||||||
|
forward_includes_kv_cache_update: bool = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "FLEX_ATTENTION"
|
return "FLEX_ATTENTION"
|
||||||
@@ -827,6 +829,29 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
assert tensor.ndim == 3
|
assert tensor.ndim == 3
|
||||||
return tensor[None, :, :, :]
|
return tensor[None, :, :, :]
|
||||||
|
|
||||||
|
def do_kv_cache_update(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
if self.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
return
|
||||||
|
|
||||||
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -908,17 +933,6 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
assert self.attn_type == AttentionType.DECODER
|
assert self.attn_type == AttentionType.DECODER
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
|
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
attn_metadata.slot_mapping,
|
|
||||||
self.kv_cache_dtype,
|
|
||||||
layer._k_scale,
|
|
||||||
layer._v_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
# View out the block_size dim
|
# View out the block_size dim
|
||||||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||||
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
|
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user