[Bugfix] Fix Voxtral streaming slot_mapping (#33073)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -105,6 +105,7 @@ def create_whisper_attention_backend_with_block_pooling(
|
|||||||
) -> type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
prefix = "WhisperCausalAttentionWithBlockPooling_"
|
prefix = "WhisperCausalAttentionWithBlockPooling_"
|
||||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
|
underlying_impl = underlying_attn_backend.get_impl_cls()
|
||||||
|
|
||||||
class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
|
class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -151,6 +152,43 @@ def create_whisper_attention_backend_with_block_pooling(
|
|||||||
common_prefix_len, new_common_attn_metadata, fast_build
|
common_prefix_len, new_common_attn_metadata, fast_build
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE: We need a custom impl so we can use the transformed slot_mapping
|
||||||
|
# computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
|
||||||
|
# the one from `forward_context.slot_mapping` (gpu_model_runner).
|
||||||
|
# This follows the same pattern as CrossAttentionImpl.
|
||||||
|
class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl):  # type: ignore[valid-type,misc]
    """Attention impl that performs the KV-cache write with the metadata's slot mapping.

    The cache update is driven by ``attn_metadata.slot_mapping`` rather than
    by a slot mapping obtained elsewhere; the attention computation itself is
    delegated unchanged to the underlying implementation.
    """

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Optionally update the KV cache, then run the parent's forward pass.

        The cache is written here only when the underlying backend does not
        already do so inside its own ``forward`` and metadata is available.
        All arguments are forwarded positionally, and in the same order, to
        the parent implementation, whose result is returned as-is.
        """
        # Read the flag from the *underlying* backend: it tells us whether the
        # parent's forward takes care of the cache write on its own.
        parent_writes_cache = underlying_attn_backend.forward_includes_kv_cache_update
        must_write_cache = not parent_writes_cache and attn_metadata is not None

        if must_write_cache:
            # Use the slot mapping carried by the attention metadata.
            self.do_kv_cache_update(
                layer, key, value, kv_cache, attn_metadata.slot_mapping
            )

        forwarded_args = (
            layer,
            query,
            key,
            value,
            kv_cache,
            attn_metadata,
            output,
            output_scale,
            output_block_scale,
        )
        return super().forward(*forwarded_args)
|
||||||
|
|
||||||
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
|
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{underlying_attn_backend} is not yet supported."
|
f"{underlying_attn_backend} is not yet supported."
|
||||||
@@ -163,6 +201,7 @@ def create_whisper_attention_backend_with_block_pooling(
|
|||||||
attention_backend_cls=underlying_attn_backend,
|
attention_backend_cls=underlying_attn_backend,
|
||||||
overrides={
|
overrides={
|
||||||
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
|
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
|
||||||
|
"get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
|
||||||
"get_kv_cache_shape": lambda num_blocks,
|
"get_kv_cache_shape": lambda num_blocks,
|
||||||
block_size,
|
block_size,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
@@ -175,6 +214,7 @@ def create_whisper_attention_backend_with_block_pooling(
|
|||||||
num_kv_heads // block_pool_size,
|
num_kv_heads // block_pool_size,
|
||||||
head_size,
|
head_size,
|
||||||
), # TODO: generalize to other backends
|
), # TODO: generalize to other backends
|
||||||
|
"forward_includes_kv_cache_update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user