[Bugfix] Ensure calculated KV scales are applied in attention. (#27232)

Signed-off-by: adabeyta <aabeyta@redhat.com>
This commit is contained in:
Adrian Abeyta
2025-11-10 17:42:37 -06:00
committed by GitHub
parent b30372cbd0
commit a5a790eea6
4 changed files with 29 additions and 36 deletions

View File

@@ -745,6 +745,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -752,12 +755,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
# Mirror Attention.forward scale calculation path
if self.calculate_kv_scales and getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
self.calc_kv_scales(q, kv_c_normed, k_pe)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
@@ -786,14 +783,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
return output
else:
# We can still access forward context to check calculation flag
if self.calculate_kv_scales:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if getattr(attn_metadata, "enable_kv_scales_calculation", False):
self.calc_kv_scales(q, kv_c_normed, k_pe)
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
@@ -881,17 +870,13 @@ def maybe_calc_kv_scales(
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
if attn_metadata is None or not getattr(
attn_metadata, "enable_kv_scales_calculation", False
):
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)

View File

@@ -279,6 +279,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
self.max_model_len = model_config.max_model_len
# Initialized from the cache config; reset to False after the first forward pass
self.calculate_kv_scales = self.cache_config.calculate_kv_scales
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
@@ -2625,16 +2628,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
# Set cudagraph mode to none if calc_kv_scales is true.
if attn_metadata is not None:
metadata_list = (
attn_metadata.values()
if isinstance(attn_metadata, dict)
else [attn_metadata]
)
if any(
getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list
):
cudagraph_runtime_mode = CUDAGraphMode.NONE
# KV scales calculation involves dynamic operations that are incompatible
# with CUDA graph capture.
if self.calculate_kv_scales:
cudagraph_runtime_mode = CUDAGraphMode.NONE
# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False
# Run the model.
# Use persistent buffers for CUDA graphs.