[Performance] Extract KV cache update op from flashinfer forward (#35422)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
ElizaWszola
2026-02-26 22:29:01 +01:00
committed by GitHub
parent 967572dd5f
commit 98217b09f9

View File

@@ -381,6 +381,8 @@ class FlashInferBackend(AttentionBackend):
return "HND"
return None
forward_includes_kv_cache_update: bool = False
@dataclass
class FIPrefill:
@@ -1330,32 +1332,15 @@ class FlashInferImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
"fp8"
):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
kv_cache = kv_cache.view(torch_dtype)
kv_cache = kv_cache.view(torch_dtype)
# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
@@ -1599,6 +1584,33 @@ class FlashInferImpl(AttentionImpl):
)
return output_padded
def do_kv_cache_update(
self,
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def fast_plan_decode(
self, # decode wrapper