Remove head_mask from Ultravox and Swin (#30764)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-12-16 16:02:41 +00:00
Committed by: GitHub
Parent: af506fd76a
Commit: 0b0acc758e
2 changed files with 16 additions and 17 deletions

@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
     ) -> tuple[torch.Tensor, ...]:
         batch_size, dim, num_channels = hidden_states.shape
@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
     ) -> tuple[torch.Tensor]:
-        self_outputs = self.self(
-            hidden_states, attention_mask, head_mask, output_attentions
-        )
+        self_outputs = self.self(hidden_states, attention_mask, output_attentions)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]
         return outputs
@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
         self,
         hidden_states: torch.Tensor,
         input_dimensions: tuple[int, int],
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
         always_partition: bool | None = False,
     ) -> tuple[torch.Tensor]:
         height, width = input_dimensions
         for i, layer_module in enumerate(self.blocks):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
             layer_outputs = layer_module(
                 hidden_states,
                 input_dimensions,
-                layer_head_mask,
                 output_attentions,
                 always_partition,
             )
@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
         self,
         hidden_states: torch.Tensor,
         input_dimensions: tuple[int, int],
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
         always_partition: bool | None = False,
     ) -> tuple[torch.Tensor]:
         for i, layer_module in enumerate(self.layers):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
             layer_outputs = layer_module(
                 hidden_states,
                 input_dimensions,
-                layer_head_mask,
                 output_attentions,
                 always_partition,
             )
@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
     def forward(
         self,
         pixel_values: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = None,
     ) -> tuple[torch.Tensor]:
         embedding_output, input_dimensions = self.embeddings(pixel_values)
@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
         encoder_outputs = self.encoder(
             embedding_output,
             input_dimensions,
-            head_mask=head_mask,
             output_attentions=output_attentions,
         )
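With head_mask removed across the Swin modules above, the per-layer slicing (head_mask[i]) and the extra positional argument drop out of the stage and encoder loops, so each layer is called with only the tensors it actually consumes. A minimal, hypothetical sketch of the resulting loop shape (TinyEncoder and its nn.Linear layers are illustrative stand-ins, not the vLLM SwinEncoder):

import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Illustrative stand-in for the post-change encoder loop."""

    def __init__(self, num_layers: int = 2, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # No per-layer head_mask slicing and no extra positional argument:
            # each layer receives only the inputs it uses.
            hidden_states = layer(hidden_states)
        return hidden_states


print(TinyEncoder()(torch.randn(1, 4, 8)).shape)  # torch.Size([1, 4, 8])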

@@ -5,6 +5,7 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import copy import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
         )
         hidden_states = hidden_states + positions
+        # Backward compatibility for Transformers v4 where layer_head_mask
+        # was a required argument for WhisperEncoderLayer.forward
+        kwargs = {}
+        if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
+            kwargs["layer_head_mask"] = None
         for layer in self.layers:
             layer_outputs = layer(
                 hidden_states,
                 attention_mask=extended_attention_mask,
-                layer_head_mask=None,
+                **kwargs,
             )
             hidden_states = layer_outputs[0]
@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
         attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
+        # Backward compatibility for Transformers v4 where layer_head_mask
+        # was a required argument for WhisperEncoderLayer.forward
+        kwargs = {}
+        if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
+            kwargs["layer_head_mask"] = None
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
                 attention_mask,
-                layer_head_mask=None,
+                **kwargs,
             )
             hidden_states = layer_outputs[0]
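Both Ultravox hunks use the same compatibility trick: inspect the layer's forward signature once and pass layer_head_mask only when it is still declared (Transformers v4), so the identical call also works on newer versions where the argument no longer exists. A self-contained sketch of that pattern, where call_layer and the two forward_* functions are hypothetical stand-ins for WhisperEncoderLayer.forward before and after the argument was dropped:

import inspect


def call_layer(forward, hidden_states, attention_mask=None):
    # Pass layer_head_mask only if the target forward() still declares it.
    kwargs = {}
    if "layer_head_mask" in inspect.signature(forward).parameters:
        kwargs["layer_head_mask"] = None
    return forward(hidden_states, attention_mask, **kwargs)


def forward_v4(hidden_states, attention_mask, layer_head_mask):
    return "v4 path", layer_head_mask


def forward_v5(hidden_states, attention_mask):
    return ("v5 path",)


print(call_layer(forward_v4, "h"))  # ('v4 path', None)
print(call_layer(forward_v5, "h"))  # ('v5 path',)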