[Bugfix] Fix Voxtral streaming slot_mapping (#33073)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2026-01-26 19:40:40 +01:00
Committed by: GitHub
Parent: 67fe677c53
Commit: 19ab0f7ce5

@@ -105,6 +105,7 @@ def create_whisper_attention_backend_with_block_pooling(
 ) -> type[AttentionBackend]:
     prefix = "WhisperCausalAttentionWithBlockPooling_"
     underlying_builder = underlying_attn_backend.get_builder_cls()
+    underlying_impl = underlying_attn_backend.get_impl_cls()
 
     class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder):  # type: ignore
def __init__(
@@ -151,6 +152,43 @@ def create_whisper_attention_backend_with_block_pooling(
                 common_prefix_len, new_common_attn_metadata, fast_build
             )
+
+    # NOTE: We need a custom impl so we can use the transformed slot_mapping
+    # computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
+    # the one from `forward_context.slot_mapping` (gpu_model_runner).
+    # This follows the same pattern as CrossAttentionImpl.
+    class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl):  # type: ignore[valid-type,misc]
+        def forward(
+            self,
+            layer: torch.nn.Module,
+            query: torch.Tensor,
+            key: torch.Tensor,
+            value: torch.Tensor,
+            kv_cache: torch.Tensor,
+            attn_metadata: AttentionMetadata,
+            output: torch.Tensor | None = None,
+            output_scale: torch.Tensor | None = None,
+            output_block_scale: torch.Tensor | None = None,
+        ) -> torch.Tensor:
+            if (
+                not underlying_attn_backend.forward_includes_kv_cache_update
+                and attn_metadata is not None
+            ):
+                self.do_kv_cache_update(
+                    layer, key, value, kv_cache, attn_metadata.slot_mapping
+                )
+            return super().forward(
+                layer,
+                query,
+                key,
+                value,
+                kv_cache,
+                attn_metadata,
+                output,
+                output_scale,
+                output_block_scale,
+            )
+
 
     if not issubclass(underlying_attn_backend, FlashAttentionBackend):
         raise NotImplementedError(
             f"{underlying_attn_backend} is not yet supported."
@@ -163,6 +201,7 @@ def create_whisper_attention_backend_with_block_pooling(
         attention_backend_cls=underlying_attn_backend,
         overrides={
             "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
+            "get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
             "get_kv_cache_shape": lambda num_blocks,
             block_size,
             num_kv_heads,
@@ -175,6 +214,7 @@ def create_whisper_attention_backend_with_block_pooling(
                 num_kv_heads // block_pool_size,
                 head_size,
             ),  # TODO: generalize to other backends
+            "forward_includes_kv_cache_update": True,
         },
     )
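The last two hunks wire the new impl in: the subclassed backend now returns the pooling-aware impl class from get_impl_cls, and it declares forward_includes_kv_cache_update = True so the generic KV-cache write (which would use the untransformed forward_context.slot_mapping) is not performed outside the impl; the pooled write happens inside the impl's forward instead. As a rough illustration of how a subclass-with-overrides helper can work (the real subclass_attention_backend in vLLM may differ), a hypothetical sketch:

def subclass_backend_sketch(name_prefix, attention_backend_cls, overrides):
    # Hypothetical helper, not vLLM's actual implementation: build a new
    # backend class whose class attributes shadow the base's, so that
    # get_impl_cls() / get_builder_cls() return the pooling-aware classes
    # and forward_includes_kv_cache_update reads True on the subclass.
    ns = {
        # Wrap callables as staticmethods so they behave the same whether
        # looked up on the class or on an instance.
        k: staticmethod(v) if callable(v) else v
        for k, v in overrides.items()
    }
    return type(name_prefix + attention_backend_cls.__name__,
                (attention_backend_cls,), ns)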