diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 0215ec1a0..c77fd09de 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -929,13 +929,15 @@ def unified_mla_kv_cache_update(
     the data dependency between them to ensure torch.compile preserves
     ordering.
     """
     forward_context = get_forward_context()
-    if forward_context.attn_metadata is None:
-        # Dummy/profile forwards should not update live KV cache pages.
-        return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
-
     attn_layer = forward_context.no_compile_layers[layer_name]
     kv_cache = attn_layer.kv_cache
+    # This needs to run even when we don't have metadata yet, so that the op
+    # is correctly captured.
+    if kv_cache.numel() == 0:
+        # Can't update an empty KV cache.
+        return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
     slot_mapping = forward_context.slot_mapping
     assert isinstance(slot_mapping, dict), (
         f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
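
For context on the guard change, here is a minimal, self-contained sketch of the early-return contract the new check enforces. It is not vLLM's actual implementation: the function name and arguments (`kv_cache_update_sketch`, `kv_c`, `kv_cache`, `slot_mapping`) are simplified stand-ins for `unified_mla_kv_cache_update` and its paged-cache handling. The point it illustrates is that the op keys off whether cache pages exist (`kv_cache.numel() == 0`) rather than off request metadata, and always returns a tensor, so it still executes and produces an output during dummy/profile forwards when the cache is an empty placeholder.

```python
import torch


def kv_cache_update_sketch(kv_c: torch.Tensor, kv_cache: torch.Tensor,
                           slot_mapping: torch.Tensor) -> torch.Tensor:
    # Guard on the cache itself, not on attention metadata: a dummy/profile
    # forward reaches this op before any cache pages are allocated, and it
    # still has to run and return a tensor so the op is captured with a
    # consistent, tensor-valued output.
    if kv_cache.numel() == 0:
        return torch.empty(0, device=kv_c.device, dtype=kv_c.dtype)
    # Real forward: scatter the new entries into their assigned cache slots.
    kv_cache[slot_mapping] = kv_c
    return kv_cache[slot_mapping]


# Profile-style call: empty placeholder cache, the op runs and returns an empty tensor.
kv_c = torch.randn(4, 64)
print(kv_cache_update_sketch(kv_c, torch.empty(0), torch.arange(4)).shape)
# -> torch.Size([0])

# Decode-style call: allocated cache, entries land in the mapped slots.
cache = torch.zeros(16, 64)
print(kv_cache_update_sketch(kv_c, cache, torch.tensor([3, 7, 8, 12])).shape)
# -> torch.Size([4, 64])
```

Returning an empty tensor rather than skipping the call entirely lines up with the note in the function's docstring about the data dependency that keeps torch.compile from reordering the update.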