Remove head_mask from Ultravox and Swin (#30764)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
     ) -> tuple[torch.Tensor, ...]:
         batch_size, dim, num_channels = hidden_states.shape
@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
     ) -> tuple[torch.Tensor]:
-        self_outputs = self.self(
-            hidden_states, attention_mask, head_mask, output_attentions
-        )
+        self_outputs = self.self(hidden_states, attention_mask, output_attentions)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]
         return outputs
@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
         self,
         hidden_states: torch.Tensor,
         input_dimensions: tuple[int, int],
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
         always_partition: bool | None = False,
     ) -> tuple[torch.Tensor]:
         height, width = input_dimensions
         for i, layer_module in enumerate(self.blocks):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-
             layer_outputs = layer_module(
                 hidden_states,
                 input_dimensions,
-                layer_head_mask,
                 output_attentions,
                 always_partition,
             )
@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
         self,
         hidden_states: torch.Tensor,
         input_dimensions: tuple[int, int],
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = False,
         always_partition: bool | None = False,
     ) -> tuple[torch.Tensor]:
         for i, layer_module in enumerate(self.layers):
-            layer_head_mask = head_mask[i] if head_mask is not None else None
-
             layer_outputs = layer_module(
                 hidden_states,
                 input_dimensions,
-                layer_head_mask,
                 output_attentions,
                 always_partition,
             )
@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
     def forward(
         self,
         pixel_values: torch.FloatTensor | None = None,
-        head_mask: torch.FloatTensor | None = None,
         output_attentions: bool | None = None,
     ) -> tuple[torch.Tensor]:
         embedding_output, input_dimensions = self.embeddings(pixel_values)
@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
         encoder_outputs = self.encoder(
             embedding_output,
             input_dimensions,
-            head_mask=head_mask,
             output_attentions=output_attentions,
         )
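For reference, a minimal sketch of SwinAttention.forward as it reads after this change, reassembled from the hunks above. The construction of self.self (windowed self-attention) and self.output (projection plus residual merge) happens in __init__ and is not shown in the diff, so it is omitted here too:

import torch
import torch.nn as nn


class SwinAttention(nn.Module):
    # Sketch only: self.self and self.output are assumed to be built in
    # __init__ exactly as in the real module.
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, ...]:
        # head_mask is gone: only the window attention mask and the
        # output_attentions flag are threaded through to self-attention.
        self_outputs = self.self(hidden_states, attention_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        # Keep any returned attention probabilities alongside the output.
        return (attention_output,) + self_outputs[1:]

Callers follow suit: SwinStage and SwinEncoder no longer index a per-layer head_mask, and SwinModel.forward no longer accepts or forwards the argument, so the head_mask plumbing disappears from the whole call chain.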