Fix MLA runs when use_inductor_graph_partition=True (#38631)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@@ -929,13 +929,15 @@ def unified_mla_kv_cache_update(
|
||||
the data dependency between them to ensure torch.compile preserves ordering.
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
if forward_context.attn_metadata is None:
|
||||
# Dummy/profile forwards should not update live KV cache pages.
|
||||
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
|
||||
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache
|
||||
|
||||
# This needs to run even when we don't have metadata yet, so that the op
|
||||
# is correctly captured.
|
||||
if kv_cache.numel() == 0:
|
||||
# Can't update an empty KV cache.
|
||||
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
|
||||
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
|
||||
Reference in New Issue
Block a user