[Bugfix] Fix Voxtral streaming slot_mapping (#33073)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -105,6 +105,7 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = "WhisperCausalAttentionWithBlockPooling_"
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
underlying_impl = underlying_attn_backend.get_impl_cls()
|
||||
|
||||
class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
|
||||
def __init__(
|
||||
@@ -151,6 +152,43 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
common_prefix_len, new_common_attn_metadata, fast_build
|
||||
)
|
||||
|
||||
# NOTE: We need a custom impl so we can use the transformed slot_mapping
|
||||
# computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
|
||||
# the one from `forward_context.slot_mapping` (gpu_model_runner).
|
||||
# This follows the same pattern as CrossAttentionImpl.
|
||||
class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl): # type: ignore[valid-type,misc]
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
not underlying_attn_backend.forward_includes_kv_cache_update
|
||||
and attn_metadata is not None
|
||||
):
|
||||
self.do_kv_cache_update(
|
||||
layer, key, value, kv_cache, attn_metadata.slot_mapping
|
||||
)
|
||||
|
||||
return super().forward(
|
||||
layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output,
|
||||
output_scale,
|
||||
output_block_scale,
|
||||
)
|
||||
|
||||
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
|
||||
raise NotImplementedError(
|
||||
f"{underlying_attn_backend} is not yet supported."
|
||||
@@ -163,6 +201,7 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
overrides={
|
||||
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
|
||||
"get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
|
||||
"get_kv_cache_shape": lambda num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
@@ -175,6 +214,7 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
num_kv_heads // block_pool_size,
|
||||
head_size,
|
||||
), # TODO: generalize to other backends
|
||||
"forward_includes_kv_cache_update": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user