diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py
index c547d5d3f..8438b460f 100644
--- a/vllm/model_executor/models/whisper_causal.py
+++ b/vllm/model_executor/models/whisper_causal.py
@@ -105,6 +105,7 @@ def create_whisper_attention_backend_with_block_pooling(
 ) -> type[AttentionBackend]:
     prefix = "WhisperCausalAttentionWithBlockPooling_"
     underlying_builder = underlying_attn_backend.get_builder_cls()
+    underlying_impl = underlying_attn_backend.get_impl_cls()
 
     class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder):  # type: ignore
         def __init__(
@@ -151,6 +152,43 @@ def create_whisper_attention_backend_with_block_pooling(
                 common_prefix_len, new_common_attn_metadata, fast_build
             )
 
+    # NOTE: We need a custom impl so we can use the transformed slot_mapping
+    # computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
+    # the one from `forward_context.slot_mapping` (gpu_model_runner).
+    # This follows the same pattern as CrossAttentionImpl.
+    class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl):  # type: ignore[valid-type,misc]
+        def forward(
+            self,
+            layer: torch.nn.Module,
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            kv_cache: torch.Tensor,
+            attn_metadata: AttentionMetadata,
+            output: torch.Tensor | None = None,
+            output_scale: torch.Tensor | None = None,
+            output_block_scale: torch.Tensor | None = None,
+        ) -> torch.Tensor:
+            if (
+                not underlying_attn_backend.forward_includes_kv_cache_update
+                and attn_metadata is not None
+            ):
+                self.do_kv_cache_update(
+                    layer, key, value, kv_cache, attn_metadata.slot_mapping
+                )
+
+            return super().forward(
+                layer,
+                query,
+                key,
+                value,
+                kv_cache,
+                attn_metadata,
+                output,
+                output_scale,
+                output_block_scale,
+            )
+
     if not issubclass(underlying_attn_backend, FlashAttentionBackend):
         raise NotImplementedError(
             f"{underlying_attn_backend} is not yet supported."
@@ -163,6 +201,7 @@ def create_whisper_attention_backend_with_block_pooling(
         attention_backend_cls=underlying_attn_backend,
         overrides={
             "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
+            "get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
             "get_kv_cache_shape": lambda num_blocks,
             block_size,
             num_kv_heads,
             head_size: (
@@ -175,6 +214,7 @@ def create_whisper_attention_backend_with_block_pooling(
                 num_kv_heads // block_pool_size,
                 head_size,
             ),  # TODO: generalize to other backends
+            "forward_includes_kv_cache_update": True,
         },
     )
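
For reviewers unfamiliar with the overrides mechanism this patch extends: the sketch below illustrates the dynamic-subclassing pattern it relies on, where a new backend class is created at runtime with selected class-level hooks swapped out via an `overrides` dict. This is a simplified stand-in, not vLLM's actual helper; `subclass_with_overrides`, `DummyBackend`, and `PooledBackend` are hypothetical names, and the `block_pool_size` value is made up for the demo. Only the `get_kv_cache_shape` and `forward_includes_kv_cache_update` overrides mirror the ones added in the patch.

from typing import Any


class DummyBackend:
    # Hypothetical stand-in for the underlying AttentionBackend
    # (e.g. FlashAttentionBackend).
    forward_includes_kv_cache_update = False

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int, block_size: int, num_kv_heads: int, head_size: int
    ) -> tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads, head_size)


def subclass_with_overrides(
    prefix: str, backend_cls: type, overrides: dict[str, Any]
) -> type:
    # type() builds the subclass at runtime; callables are wrapped as
    # staticmethods so they shadow the parent's class-level hooks, while
    # plain values (like the bool below) become class attributes.
    namespace = {
        name: staticmethod(value) if callable(value) else value
        for name, value in overrides.items()
    }
    return type(prefix + backend_cls.__name__, (backend_cls,), namespace)


block_pool_size = 4  # illustrative value, not from the patch
PooledBackend = subclass_with_overrides(
    "Pooled_",
    DummyBackend,
    overrides={
        # Mirrors the patch: fold block_pool_size KV heads into wider blocks.
        "get_kv_cache_shape": lambda num_blocks, block_size, num_kv_heads, head_size: (
            num_blocks,
            2,
            block_size * block_pool_size,
            num_kv_heads // block_pool_size,
            head_size,
        ),
        # Mirrors the patch: signal that forward() now performs the
        # KV-cache write itself.
        "forward_includes_kv_cache_update": True,
    },
)

print(PooledBackend.get_kv_cache_shape(8, 16, 8, 64))  # -> (8, 2, 64, 2, 64)
print(PooledBackend.forward_includes_kv_cache_update)  # -> True

The same mechanism is what lets the patch's new `get_impl_cls` override return `WhisperCausalAttentionWithBlockPoolingImpl` without mutating the underlying backend class, so other users of that backend are unaffected.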