Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
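For context, replacing yapf + isort with ruff generally means moving both code formatting and import sorting into ruff's own tooling, after which "ruff format" rewrites the code and "ruff check --fix" applies the import-sorting fixes; that is what produces the purely mechanical diff below. The pyproject.toml sketch that follows is illustrative only: the table names and keys are real ruff options, but the selected rule set and line length are assumptions, not values taken from this commit.

    [tool.ruff]
    line-length = 88  # ruff's default line length; assumed, not read from this commit

    [tool.ruff.lint]
    select = ["I"]  # "I" enables isort-compatible import sorting (replaces isort)

    [tool.ruff.format]
    # ruff's formatter is Black-compatible and replaces yapf; defaults assumed here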
@@ -7,21 +7,21 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from transformers import SwinConfig
-from transformers.models.swin.modeling_swin import SwinEmbeddings
+from transformers.models.swin.modeling_swin import SwinEmbeddings, SwinPatchMerging
 from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer
-from transformers.models.swin.modeling_swin import SwinPatchMerging
 from transformers.pytorch_utils import meshgrid

 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader


 class SwinSelfAttention(nn.Module):
-
     def __init__(
         self,
         config: SwinConfig,
@@ -35,35 +35,40 @@ class SwinSelfAttention(nn.Module):
         if dim % num_heads != 0:
             raise ValueError(
                 f"The hidden size ({dim}) is not a multiple of the number of "
-                f"attention heads ({num_heads})")
+                f"attention heads ({num_heads})"
+            )

         self.num_attention_heads = num_heads
         self.attention_head_size = int(dim / num_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.window_size = (window_size if isinstance(window_size, Iterable)
-                            else (window_size, window_size))
+        self.window_size = (
+            window_size
+            if isinstance(window_size, Iterable)
+            else (window_size, window_size)
+        )
         self.scale = self.attention_head_size**-0.5

         self.relative_position_bias_table = nn.Parameter(
             torch.zeros(
-                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
-                num_heads))
+                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
+            )
+        )

         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
         coords_w = torch.arange(self.window_size[1])
         coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
         coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:,
-                                                                      None, :]
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()
         relative_coords[:, :, 0] += self.window_size[0] - 1
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)

-        self.relative_position_index = nn.Parameter(relative_position_index,
-                                                    requires_grad=False)
+        self.relative_position_index = nn.Parameter(
+            relative_position_index, requires_grad=False
+        )

         self.qkv = QKVParallelLinear(
             hidden_size=dim,
@@ -75,19 +80,23 @@ class SwinSelfAttention(nn.Module):
         )

     def transpose_for_scores(self, x):
-        new_x_shape = x.size()[:-1] + (self.num_attention_heads,
-                                       self.attention_head_size)
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads,
+            self.attention_head_size,
+        )
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)

     def _get_rel_pos_bias(self) -> torch.Tensor:
         relative_position_bias = self.relative_position_bias_table[
-            self.relative_position_index.view(-1)]
+            self.relative_position_index.view(-1)
+        ]
         relative_position_bias = relative_position_bias.view(
             self.window_size[0] * self.window_size[1],
-            self.window_size[0] * self.window_size[1], -1)
-        relative_position_bias = relative_position_bias.permute(
-            2, 0, 1).contiguous()
+            self.window_size[0] * self.window_size[1],
+            -1,
+        )
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
         return relative_position_bias.unsqueeze(0)

     def forward(
@@ -110,38 +119,38 @@ class SwinSelfAttention(nn.Module):
         if attention_mask is not None:
             mask_shape = attention_mask.shape[0]
             attention_mask_expanded = attention_mask.view(
-                1, mask_shape, 1, dim,
-                dim).expand(batch_size // mask_shape, mask_shape,
-                            self.num_attention_heads, dim, dim)
-            attention_scores = attention_scores + \
-                attention_mask_expanded.unsqueeze(
-                    1).unsqueeze(0)
-            attention_scores = attention_scores.view(-1,
-                                                     self.num_attention_heads,
-                                                     dim, dim)
+                1, mask_shape, 1, dim, dim
+            ).expand(
+                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+            )
+            attention_scores = attention_scores + attention_mask_expanded.unsqueeze(
+                1
+            ).unsqueeze(0)
+            attention_scores = attention_scores.view(
+                -1, self.num_attention_heads, dim, dim
+            )

         context_layer = torch.nn.functional.scaled_dot_product_attention(
             query_layer,
             key_layer,
             value_layer,
             attn_mask=attention_scores,
-            dropout_p=0.,
+            dropout_p=0.0,
         )
         attention_probs = None

         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (
-            self.all_head_size, )
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)

-        outputs = (context_layer,
-                   attention_probs) if output_attentions else (context_layer, )
+        outputs = (
+            (context_layer, attention_probs) if output_attentions else (context_layer,)
+        )

         return outputs


 class SwinSelfOutput(nn.Module):
-
     def __init__(
         self,
         config: SwinConfig,
@@ -157,33 +166,36 @@ class SwinSelfOutput(nn.Module):
             prefix=f"{prefix}.dense",
         )

-    def forward(self, hidden_states: torch.Tensor,
-                input_tensor: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
+    ) -> torch.Tensor:
         hidden_states, _ = self.dense(hidden_states)

         return hidden_states


 class SwinAttention(nn.Module):
-
-    def __init__(self,
-                 config: SwinConfig,
-                 dim: int,
-                 num_heads: int,
-                 window_size: int,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: SwinConfig,
+        dim: int,
+        num_heads: int,
+        window_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
-        self.self = SwinSelfAttention(config,
-                                      dim,
-                                      num_heads,
-                                      window_size,
-                                      quant_config=quant_config,
-                                      prefix=f"{prefix}.self")
-        self.output = SwinSelfOutput(config,
-                                     dim,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.output")
+        self.self = SwinSelfAttention(
+            config,
+            dim,
+            num_heads,
+            window_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self",
+        )
+        self.output = SwinSelfOutput(
+            config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
+        )
         self.pruned_heads = set()

     def forward(
@@ -193,25 +205,29 @@ class SwinAttention(nn.Module):
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = False,
     ) -> tuple[torch.Tensor]:
-        self_outputs = self.self(hidden_states, attention_mask, head_mask,
-                                 output_attentions)
+        self_outputs = self.self(
+            hidden_states, attention_mask, head_mask, output_attentions
+        )
         attention_output = self.output(self_outputs[0], hidden_states)
-        outputs = (attention_output, ) + self_outputs[1:]
+        outputs = (attention_output,) + self_outputs[1:]
         return outputs


 class SwinIntermediate(nn.Module):
-
-    def __init__(self,
-                 config: SwinConfig,
-                 dim: int,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: SwinConfig,
+        dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
-        self.dense = ColumnParallelLinear(dim,
-                                          int(config.mlp_ratio * dim),
-                                          quant_config=quant_config,
-                                          prefix=f"{prefix}.dense")
+        self.dense = ColumnParallelLinear(
+            dim,
+            int(config.mlp_ratio * dim),
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )
         self.intermediate_act_fn = get_act_fn(config.hidden_act)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -221,17 +237,20 @@ class SwinIntermediate(nn.Module):


 class SwinOutput(nn.Module):
-
-    def __init__(self,
-                 config: SwinConfig,
-                 dim: int,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
+    def __init__(
+        self,
+        config: SwinConfig,
+        dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
-        self.dense = RowParallelLinear(int(config.mlp_ratio * dim),
-                                       dim,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.dense")
+        self.dense = RowParallelLinear(
+            int(config.mlp_ratio * dim),
+            dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.dense",
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.dense(hidden_states)
@@ -239,7 +258,6 @@ class SwinOutput(nn.Module):


 class SwinLayer(HFSwinLayer):
-
     def __init__(
         self,
         config: SwinConfig,
@@ -260,24 +278,23 @@ class SwinLayer(HFSwinLayer):
             shift_size=shift_size,
         )

-        self.attention = SwinAttention(config,
-                                       dim,
-                                       num_heads,
-                                       window_size=self.window_size,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.attention")
-        self.intermediate = SwinIntermediate(config,
-                                             dim,
-                                             quant_config=quant_config,
-                                             prefix=f"{prefix}.intermediate")
-        self.output = SwinOutput(config,
-                                 dim,
-                                 quant_config=quant_config,
-                                 prefix=f"{prefix}.output")
+        self.attention = SwinAttention(
+            config,
+            dim,
+            num_heads,
+            window_size=self.window_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attention",
+        )
+        self.intermediate = SwinIntermediate(
+            config, dim, quant_config=quant_config, prefix=f"{prefix}.intermediate"
+        )
+        self.output = SwinOutput(
+            config, dim, quant_config=quant_config, prefix=f"{prefix}.output"
+        )


 class SwinStage(nn.Module):
-
     def __init__(
         self,
         config: SwinConfig,
@@ -293,24 +310,27 @@ class SwinStage(nn.Module):
         super().__init__()
         self.config = config
         self.dim = dim
-        self.blocks = nn.ModuleList([
-            SwinLayer(config=config,
-                      dim=dim,
-                      input_resolution=input_resolution,
-                      num_heads=num_heads,
-                      drop_path_rate=drop_path[layer_idx],
-                      shift_size=0 if
-                      (layer_idx % 2 == 0) else config.window_size // 2,
-                      quant_config=quant_config,
-                      prefix=f"{prefix}.blocks.{layer_idx}")
-            for layer_idx in range(depth)
-        ])
+        self.blocks = nn.ModuleList(
+            [
+                SwinLayer(
+                    config=config,
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    drop_path_rate=drop_path[layer_idx],
+                    shift_size=0 if (layer_idx % 2 == 0) else config.window_size // 2,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.blocks.{layer_idx}",
+                )
+                for layer_idx in range(depth)
+            ]
+        )

         # patch merging layer
         if downsample is not None:
-            self.downsample = downsample(input_resolution,
-                                         dim=dim,
-                                         norm_layer=nn.LayerNorm)
+            self.downsample = downsample(
+                input_resolution, dim=dim, norm_layer=nn.LayerNorm
+            )
         else:
             self.downsample = None

@@ -328,25 +348,31 @@ class SwinStage(nn.Module):
         for i, layer_module in enumerate(self.blocks):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            layer_outputs = layer_module(hidden_states, input_dimensions,
-                                         layer_head_mask, output_attentions,
-                                         always_partition)
+            layer_outputs = layer_module(
+                hidden_states,
+                input_dimensions,
+                layer_head_mask,
+                output_attentions,
+                always_partition,
+            )

             hidden_states = layer_outputs[0]

         hidden_states_before_downsampling = hidden_states
         if self.downsample is not None:
-            height_downsampled, width_downsampled = (height + 1) // 2, (width +
-                                                                        1) // 2
-            output_dimensions = (height, width, height_downsampled,
-                                 width_downsampled)
-            hidden_states = self.downsample(hidden_states_before_downsampling,
-                                            input_dimensions)
+            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+            output_dimensions = (height, width, height_downsampled, width_downsampled)
+            hidden_states = self.downsample(
+                hidden_states_before_downsampling, input_dimensions
+            )
         else:
             output_dimensions = (height, width, height, width)

-        stage_outputs = (hidden_states, hidden_states_before_downsampling,
-                         output_dimensions)
+        stage_outputs = (
+            hidden_states,
+            hidden_states_before_downsampling,
+            output_dimensions,
+        )

         if output_attentions:
             stage_outputs += layer_outputs[1:]
@@ -354,7 +380,6 @@ class SwinStage(nn.Module):


 class SwinEncoder(nn.Module):
-
     def __init__(
         self,
         config: SwinConfig,
@@ -366,24 +391,36 @@ class SwinEncoder(nn.Module):
         self.num_layers = len(config.depths)
         self.config = config
         dpr = [
-            x.item() for x in torch.linspace(
-                0, config.drop_path_rate, sum(config.depths), device="cpu")
+            x.item()
+            for x in torch.linspace(
+                0, config.drop_path_rate, sum(config.depths), device="cpu"
+            )
         ]
-        self.layers = nn.ModuleList([
-            SwinStage(config=config,
-                      dim=int(config.embed_dim * 2**layer_idx),
-                      input_resolution=(grid_size[0] // (2**layer_idx),
-                                        grid_size[1] // (2**layer_idx)),
-                      depth=config.depths[layer_idx],
-                      num_heads=config.num_heads[layer_idx],
-                      drop_path=dpr[sum(config.depths[:layer_idx]
-                                        ):sum(config.depths[:layer_idx + 1])],
-                      downsample=SwinPatchMerging if
-                      (layer_idx < self.num_layers - 1) else None,
-                      quant_config=quant_config,
-                      prefix=f"{prefix}.layers.{layer_idx}")
-            for layer_idx in range(self.num_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                SwinStage(
+                    config=config,
+                    dim=int(config.embed_dim * 2**layer_idx),
+                    input_resolution=(
+                        grid_size[0] // (2**layer_idx),
+                        grid_size[1] // (2**layer_idx),
+                    ),
+                    depth=config.depths[layer_idx],
+                    num_heads=config.num_heads[layer_idx],
+                    drop_path=dpr[
+                        sum(config.depths[:layer_idx]) : sum(
+                            config.depths[: layer_idx + 1]
+                        )
+                    ],
+                    downsample=SwinPatchMerging
+                    if (layer_idx < self.num_layers - 1)
+                    else None,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(self.num_layers)
+            ]
+        )

     def forward(
         self,
@@ -396,9 +433,13 @@ class SwinEncoder(nn.Module):
         for i, layer_module in enumerate(self.layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            layer_outputs = layer_module(hidden_states, input_dimensions,
-                                         layer_head_mask, output_attentions,
-                                         always_partition)
+            layer_outputs = layer_module(
+                hidden_states,
+                input_dimensions,
+                layer_head_mask,
+                output_attentions,
+                always_partition,
+            )

             hidden_states = layer_outputs[0]
             output_dimensions = layer_outputs[2]
@@ -420,13 +461,15 @@ class SwinModel(nn.Module):
         super().__init__()
         self.config = config
         self.num_layers = len(config.depths)
-        self.num_features = int(config.embed_dim * 2**(self.num_layers - 1))
+        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

         self.embeddings = SwinEmbeddings(config)
-        self.encoder = SwinEncoder(config,
-                                   self.embeddings.patch_grid,
-                                   quant_config=quant_config,
-                                   prefix=f"{prefix}.encoder")
+        self.encoder = SwinEncoder(
+            config,
+            self.embeddings.patch_grid,
+            quant_config=quant_config,
+            prefix=f"{prefix}.encoder",
+        )

     def forward(
         self,
@@ -445,8 +488,7 @@ class SwinModel(nn.Module):

         return encoder_outputs

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                    torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             ("qkv", "query", "q"),
             ("qkv", "key", "k"),
@@ -456,8 +498,7 @@ class SwinModel(nn.Module):
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -468,8 +509,7 @@ class SwinModel(nn.Module):
                 break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params