[Bugfix] Add missing encoder only guard for do_kv_cache_update (#33269)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
4197168ea5
commit
ab597c869a
@@ -572,6 +572,10 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
|
||||||
|
# For encoder attention,
|
||||||
|
# we use direct Q, K, V tensors without caching
|
||||||
|
return
|
||||||
# For decoder and cross-attention, use KV cache as before
|
# For decoder and cross-attention, use KV cache as before
|
||||||
key_cache, value_cache = kv_cache.unbind(1)
|
key_cache, value_cache = kv_cache.unbind(1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user