Remove head_mask from Ultravox and Swin (#30764)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.FloatTensor | None = None,
|
attention_mask: torch.FloatTensor | None = None,
|
||||||
head_mask: torch.FloatTensor | None = None,
|
|
||||||
output_attentions: bool | None = False,
|
output_attentions: bool | None = False,
|
||||||
) -> tuple[torch.Tensor, ...]:
|
) -> tuple[torch.Tensor, ...]:
|
||||||
batch_size, dim, num_channels = hidden_states.shape
|
batch_size, dim, num_channels = hidden_states.shape
|
||||||
@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: torch.FloatTensor | None = None,
|
attention_mask: torch.FloatTensor | None = None,
|
||||||
head_mask: torch.FloatTensor | None = None,
|
|
||||||
output_attentions: bool | None = False,
|
output_attentions: bool | None = False,
|
||||||
) -> tuple[torch.Tensor]:
|
) -> tuple[torch.Tensor]:
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
|
||||||
hidden_states, attention_mask, head_mask, output_attentions
|
|
||||||
)
|
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
outputs = (attention_output,) + self_outputs[1:]
|
outputs = (attention_output,) + self_outputs[1:]
|
||||||
return outputs
|
return outputs
|
||||||
@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_dimensions: tuple[int, int],
|
input_dimensions: tuple[int, int],
|
||||||
head_mask: torch.FloatTensor | None = None,
|
|
||||||
output_attentions: bool | None = False,
|
output_attentions: bool | None = False,
|
||||||
always_partition: bool | None = False,
|
always_partition: bool | None = False,
|
||||||
) -> tuple[torch.Tensor]:
|
) -> tuple[torch.Tensor]:
|
||||||
height, width = input_dimensions
|
height, width = input_dimensions
|
||||||
for i, layer_module in enumerate(self.blocks):
|
for i, layer_module in enumerate(self.blocks):
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
input_dimensions,
|
input_dimensions,
|
||||||
layer_head_mask,
|
|
||||||
output_attentions,
|
output_attentions,
|
||||||
always_partition,
|
always_partition,
|
||||||
)
|
)
|
||||||
@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_dimensions: tuple[int, int],
|
input_dimensions: tuple[int, int],
|
||||||
head_mask: torch.FloatTensor | None = None,
|
|
||||||
output_attentions: bool | None = False,
|
output_attentions: bool | None = False,
|
||||||
always_partition: bool | None = False,
|
always_partition: bool | None = False,
|
||||||
) -> tuple[torch.Tensor]:
|
) -> tuple[torch.Tensor]:
|
||||||
for i, layer_module in enumerate(self.layers):
|
for i, layer_module in enumerate(self.layers):
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
input_dimensions,
|
input_dimensions,
|
||||||
layer_head_mask,
|
|
||||||
output_attentions,
|
output_attentions,
|
||||||
always_partition,
|
always_partition,
|
||||||
)
|
)
|
||||||
@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor | None = None,
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
head_mask: torch.FloatTensor | None = None,
|
|
||||||
output_attentions: bool | None = None,
|
output_attentions: bool | None = None,
|
||||||
) -> tuple[torch.Tensor]:
|
) -> tuple[torch.Tensor]:
|
||||||
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
embedding_output, input_dimensions = self.embeddings(pixel_values)
|
||||||
@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
|
|||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
input_dimensions,
|
input_dimensions,
|
||||||
head_mask=head_mask,
|
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
"""PyTorch Ultravox model."""
|
"""PyTorch Ultravox model."""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Annotated, Any, Literal, TypeAlias
|
from typing import Annotated, Any, Literal, TypeAlias
|
||||||
@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
|
|||||||
)
|
)
|
||||||
hidden_states = hidden_states + positions
|
hidden_states = hidden_states + positions
|
||||||
|
|
||||||
|
# Backward compatibility for Transformers v4 where layer_head_mask
|
||||||
|
# was a required argument for WhisperEncoderLayer.forward
|
||||||
|
kwargs = {}
|
||||||
|
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
|
||||||
|
kwargs["layer_head_mask"] = None
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
layer_outputs = layer(
|
layer_outputs = layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=extended_attention_mask,
|
attention_mask=extended_attention_mask,
|
||||||
layer_head_mask=None,
|
**kwargs,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
|
|||||||
|
|
||||||
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
|
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
|
||||||
|
|
||||||
|
# Backward compatibility for Transformers v4 where layer_head_mask
|
||||||
|
# was a required argument for WhisperEncoderLayer.forward
|
||||||
|
kwargs = {}
|
||||||
|
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
|
||||||
|
kwargs["layer_head_mask"] = None
|
||||||
|
|
||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
layer_outputs = encoder_layer(
|
layer_outputs = encoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
layer_head_mask=None,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user