[Performance] Extract KV-cache update from TreeAttention backend (#35384)
Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
This commit is contained in:
@@ -31,6 +31,7 @@ logger = init_logger(__name__)
|
|||||||
class TreeAttentionBackend(AttentionBackend):
|
class TreeAttentionBackend(AttentionBackend):
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||||
|
forward_includes_kv_cache_update: bool = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||||
@@ -326,6 +327,33 @@ class TreeAttentionImpl(AttentionImpl):
|
|||||||
"TreeAttentionImpl."
|
"TreeAttentionImpl."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def do_kv_cache_update(
    self,
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write the input keys/values for this layer into the paged KV cache.

    Args:
        layer: The attention layer module; its ``_k_scale`` / ``_v_scale``
            attributes are passed through to the cache-write kernel.
        key: Key tensor for the current batch of tokens.
        value: Value tensor for the current batch of tokens.
        kv_cache: Combined cache tensor whose leading dimension stacks the
            key cache and the value cache (split via ``unbind(0)`` below).
        slot_mapping: Per-token cache-slot indices; its length determines
            how many tokens the kernel actually writes.
    """
    # Split the stacked cache into its key and value halves (views, no copy).
    key_cache, value_cache = kv_cache.unbind(0)

    # Reshape the input keys and values and store them in the cache.
    # NOTE(woosuk): Here, key and value are padded while slot_mapping is
    # not padded. However, we don't need to do key[:num_actual_tokens]
    # and value[:num_actual_tokens] because the reshape_and_cache_flash
    # op uses the slot_mapping's shape to determine the number of
    # actual tokens.
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -361,26 +389,7 @@ class TreeAttentionImpl(AttentionImpl):
|
|||||||
# Profiling run.
|
# Profiling run.
|
||||||
return output.fill_(0)
|
return output.fill_(0)
|
||||||
|
|
||||||
# Cache the input KVs.
|
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
if self.kv_sharing_target_layer_name is None:
|
|
||||||
# Reshape the input keys and values and store them in the cache.
|
|
||||||
# Skip this if sharing KV cache with an earlier attention layer.
|
|
||||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
|
||||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
|
||||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
|
||||||
# op uses the slot_mapping's shape to determine the number of
|
|
||||||
# actual tokens.
|
|
||||||
ops.reshape_and_cache_flash(
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
attn_metadata.slot_mapping,
|
|
||||||
self.kv_cache_dtype,
|
|
||||||
layer._k_scale,
|
|
||||||
layer._v_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user