Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
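
The change below is purely mechanical: yapf's aligned hanging indents and single-quoted strings become ruff's black-style layout (parenthesized continuation blocks with trailing commas, double quotes, spaces around computed slice bounds, and no blank line directly after a class or def header). The exact invocation is not recorded in this commit; a minimal sketch of how such a conversion is typically run, assuming ruff is already configured in the repository's pyproject.toml:

    ruff check --select I --fix .   # import sorting (replaces isort)
    ruff format .                   # code formatting (replaces yapf)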
@@ -28,7 +28,6 @@ norm_t = Union[tuple[float, float, float], torch.Tensor]
 
 
 def _ntuple(n):
-
     def parse(x):
         if isinstance(x, Iterable) and not isinstance(x, str):
             return tuple(x)
@@ -45,7 +44,6 @@ to_ntuple = _ntuple
 
 
 class InputConditioner(nn.Module):
-
     def __init__(
         self,
         input_scale: float,
@@ -72,7 +70,6 @@ def _to_tensor(v: norm_t):
 
 
 class ClsToken(nn.Module):
-
     def __init__(
         self,
         ndim: int,
@@ -91,12 +88,14 @@ class ClsToken(nn.Module):
             if num_registers:
                 self.num_registers = num_registers
             elif register_multiple:
-                self.num_registers = register_multiple - (num_tokens %
-                                                          register_multiple)
+                self.num_registers = register_multiple - (
+                    num_tokens % register_multiple
+                )
 
             scale = ndim**-0.5
             self.token = nn.Parameter(
-                torch.randn(num_tokens + self.num_registers, ndim) * scale)
+                torch.randn(num_tokens + self.num_registers, ndim) * scale
+            )
 
         else:
             self.token = None
@@ -108,16 +107,18 @@ class ClsToken(nn.Module):
             return x
 
         token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
-        x = torch.cat([
-            token,
-            x,
-        ], dim=1)
+        x = torch.cat(
+            [
+                token,
+                x,
+            ],
+            dim=1,
+        )
 
         return x
 
 
 class ViTPatchGenerator(nn.Module):
-
     def __init__(
         self,
         # config: PretrainedConfig,
@@ -147,8 +148,8 @@ class ViTPatchGenerator(nn.Module):
             max_input_dims = (max_input_dims, max_input_dims)
 
         max_input_dims = tuple(
-            int(math.ceil(d / patch_size) * patch_size)
-            for d in max_input_dims)
+            int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims
+        )
 
         self.cpe_mode = max_input_dims != input_dims
         self.pos_dropout = pos_dropout
@@ -167,15 +168,15 @@ class ViTPatchGenerator(nn.Module):
         self.max_input_dims = max_input_dims
 
         self.im_to_patches = Im2Patches(patch_size)
-        self.embedder = ViTPatchLinear(patch_size,
-                                       embed_dim,
-                                       bias=patch_bias,
-                                       **factory)
+        self.embedder = ViTPatchLinear(
+            patch_size, embed_dim, bias=patch_bias, **factory
+        )
 
         if abs_pos:
             scale = embed_dim**-0.5
             self.pos_embed = nn.Parameter(
-                torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
+                torch.randn(1, self.num_patches, embed_dim, **factory) * scale
+            )
 
         self.cls_token = ClsToken(
             embed_dim,
@@ -185,8 +186,9 @@ class ViTPatchGenerator(nn.Module):
             num_registers=num_registers,
         )
 
-        self.patch_normalizer = nn.LayerNorm(
-            embed_dim) if normalize_patches else nn.Identity()
+        self.patch_normalizer = (
+            nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         patches = self.embed_patches(x)
@@ -221,42 +223,48 @@ class ViTPatchGenerator(nn.Module):
         if src_embed.shape != targ_embed.shape:
             src_size = int(math.sqrt(src_embed.shape[1]))
 
-            assert src_size**2 == src_embed.shape[
-                1], 'Unable to interpolate non-square embedding'
+            assert src_size**2 == src_embed.shape[1], (
+                "Unable to interpolate non-square embedding"
+            )
 
-            src_embed = rearrange(src_embed,
-                                  'b (h w) c -> b c h w',
-                                  h=src_size,
-                                  w=src_size)
-            src_embed = F.interpolate(src_embed,
-                                      size=(self.num_rows, self.num_cols),
-                                      mode='bicubic',
-                                      align_corners=True,
-                                      antialias=False)
-            src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
+            src_embed = rearrange(
+                src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size
+            )
+            src_embed = F.interpolate(
+                src_embed,
+                size=(self.num_rows, self.num_cols),
+                mode="bicubic",
+                align_corners=True,
+                antialias=False,
+            )
+            src_embed = rearrange(src_embed, "b c h w -> b (h w) c")
         targ_embed.data.copy_(src_embed)
 
-    def _load_projection(self, src_proj_weight: torch.Tensor,
-                         targ_proj_weight: torch.Tensor):
+    def _load_projection(
+        self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor
+    ):
         if src_proj_weight.shape != targ_proj_weight.shape:
             src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))
 
-            assert (src_patch_size**2) * 3 == src_proj_weight.shape[
-                1], 'Unable to interpolate non-square patch size'
+            assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], (
+                "Unable to interpolate non-square patch size"
+            )
 
-            src_proj_weight = rearrange(src_proj_weight,
-                                        'b (c h w) -> b c h w',
-                                        c=3,
-                                        h=src_patch_size,
-                                        w=src_patch_size)
-            src_proj_weight = F.interpolate(src_proj_weight,
-                                            size=(self.patch_size,
-                                                  self.patch_size),
-                                            mode='bicubic',
-                                            align_corners=True,
-                                            antialias=False)
-            src_proj_weight = rearrange(src_proj_weight,
-                                        'b c h w -> b (c h w)')
+            src_proj_weight = rearrange(
+                src_proj_weight,
+                "b (c h w) -> b c h w",
+                c=3,
+                h=src_patch_size,
+                w=src_patch_size,
+            )
+            src_proj_weight = F.interpolate(
+                src_proj_weight,
+                size=(self.patch_size, self.patch_size),
+                mode="bicubic",
+                align_corners=True,
+                antialias=False,
+            )
+            src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)")
         targ_proj_weight.data.copy_(src_proj_weight)
 
     def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
@@ -276,11 +284,12 @@ class ViTPatchGenerator(nn.Module):
         pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
 
         if self.training and self.pos_dropout > 0:
-            keeps = torch.rand(patches.shape[0],
-                               1,
-                               1,
-                               dtype=pos_enc.dtype,
-                               device=pos_enc.device) > self.pos_dropout
+            keeps = (
+                torch.rand(
+                    patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device
+                )
+                > self.pos_dropout
+            )
             pos_enc_drop = torch.where(keeps, pos_enc, 0)
         else:
             pos_enc_drop = pos_enc
@@ -303,56 +312,58 @@ class ViTPatchGenerator(nn.Module):
         if patch_idxs is None:
             return pos_embed
 
-        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(
-            -1, -1, pos_embed.shape[-1])
+        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
 
-        pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1),
-                                 dim=1,
-                                 index=exp_patch_idxs)
+        pos_embed = torch.gather(
+            pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
+        )
         return pos_embed
 
-    def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int,
-                                                                     int]):
+    def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]):
         if (self.num_rows, self.num_cols) == input_dims:
             return self.pos_embed
 
-        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols,
-                                           -1).permute(0, 3, 1, 2)
+        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(
+            0, 3, 1, 2
+        )
 
         def window_select(pos_embed):
             if input_dims[0] < pos_embed.shape[-2]:
-                pos_embed = pos_embed[..., :input_dims[0], :]
+                pos_embed = pos_embed[..., : input_dims[0], :]
             if input_dims[1] < pos_embed.shape[-1]:
-                pos_embed = pos_embed[..., :, :input_dims[1]]
+                pos_embed = pos_embed[..., :, : input_dims[1]]
             return pos_embed
 
         if self.cpe_mode:
            if self.training:
                 min_scale = math.sqrt(0.1)
-                scale = torch.rand(batch_size, 1, 1, device=pos_embed.device
-                                   ) * (1 - min_scale) + min_scale
+                scale = (
+                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
+                    * (1 - min_scale)
+                    + min_scale
+                )
                 aspect_min = math.log(3 / 4)
                 aspect_max = -aspect_min
                 aspect = torch.exp(
-                    torch.rand(batch_size, 1, 1, device=pos_embed.device) *
-                    (aspect_max - aspect_min) + aspect_min)
+                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
+                    * (aspect_max - aspect_min)
+                    + aspect_min
+                )
 
                 scale_x = scale * aspect
                 scale_y = scale * (1 / aspect)
                 scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
 
-                pos_xy = torch.rand(
-                    batch_size, 1, 1, 2,
-                    device=pos_embed.device) * (1 - scale_xy)
+                pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
+                    1 - scale_xy
+                )
 
                 lin_x = torch.linspace(
-                    0, 1, steps=input_dims[1],
-                    device=pos_embed.device)[None, None].expand(
-                        batch_size, input_dims[0], -1)
+                    0, 1, steps=input_dims[1], device=pos_embed.device
+                )[None, None].expand(batch_size, input_dims[0], -1)
                 lin_y = torch.linspace(
-                    0, 1, steps=input_dims[0],
-                    device=pos_embed.device)[None, :, None].expand(
-                        batch_size, -1, input_dims[1])
+                    0, 1, steps=input_dims[0], device=pos_embed.device
+                )[None, :, None].expand(batch_size, -1, input_dims[1])
 
                 lin_xy = torch.stack([lin_x, lin_y], dim=-1)
 
@@ -364,26 +375,27 @@ class ViTPatchGenerator(nn.Module):
                 pos_embed = F.grid_sample(
                     pos_embed.float().expand(batch_size, -1, -1, -1),
                     grid=grid_xy,
-                    mode='bilinear',
-                    padding_mode='zeros',
+                    mode="bilinear",
+                    padding_mode="zeros",
                     align_corners=True,
                 ).to(pos_embed.dtype)
             else:
                 max_dim = max(input_dims)
-                pos_embed = F.interpolate(pos_embed.float(),
-                                          size=(max_dim, max_dim),
-                                          align_corners=True,
-                                          mode='bilinear').to(pos_embed.dtype)
+                pos_embed = F.interpolate(
+                    pos_embed.float(),
+                    size=(max_dim, max_dim),
+                    align_corners=True,
+                    mode="bilinear",
+                ).to(pos_embed.dtype)
 
                 pos_embed = window_select(pos_embed)
         else:
             pos_embed = window_select(pos_embed)
 
         if pos_embed.shape[-2:] != input_dims:
-            pos_embed = F.interpolate(pos_embed.float(),
-                                      size=input_dims,
-                                      align_corners=True,
-                                      mode='bilinear').to(pos_embed.dtype)
+            pos_embed = F.interpolate(
+                pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear"
+            ).to(pos_embed.dtype)
 
         pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
 
@@ -391,7 +403,6 @@ class ViTPatchGenerator(nn.Module):
 
 
 class Im2Patches(nn.Module):
-
     def __init__(self, patch_size: int):
         super().__init__()
         self.patch_size = patch_size
@@ -406,7 +417,7 @@ class Im2Patches(nn.Module):
         px = x.shape[-1] // self.patch_size
         patches = rearrange(
             x,
-            'b c (py yy) (px xx) -> b (py px) (c yy xx)',
+            "b c (py yy) (px xx) -> b (py px) (c yy xx)",
             py=py,
             yy=self.patch_size,
             px=px,
@@ -416,12 +427,7 @@ class Im2Patches(nn.Module):
 
 
 class ViTPatchLinear(nn.Linear):
-
-    def __init__(self,
-                 patch_size: int,
-                 embed_dim: int,
-                 bias: bool = False,
-                 **factory):
+    def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
         super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory)
         self.patch_size = patch_size
 
@@ -444,16 +450,19 @@ class RadioInternVisionModel(nn.Module):
 
         self.config = config
         self.img_size, self.grid_size, self.num_patches = self._init_img_size(
-            to_2tuple(config.patch_size), config.image_size)
+            to_2tuple(config.patch_size), config.image_size
+        )
         max_img_size = int(
-            round(config.max_img_size / config.patch_size) * config.patch_size)
+            round(config.max_img_size / config.patch_size) * config.patch_size
+        )
         self.patch_generator = ViTPatchGenerator(
             config.patch_size,
             config.hidden_size,
             input_dims=self.img_size,
             max_input_dims=max_img_size,
             cls_token=True,
-            register_multiple=config.reg_tokens)
+            register_multiple=config.reg_tokens,
+        )
 
         self.encoder = InternVisionEncoder(
             config=config,
@@ -463,8 +472,7 @@ class RadioInternVisionModel(nn.Module):
             prefix=f"{prefix}.encoder",
         )
 
-    def _init_img_size(self, patch_size, img_size: Union[int, tuple[int,
-                                                                    int]]):
+    def _init_img_size(self, patch_size, img_size: Union[int, tuple[int, int]]):
         if img_size is None:
             return None, None, None
         img_size = to_2tuple(img_size)
@@ -509,7 +517,8 @@ class RadioModel(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             num_dummy_heads=num_dummy_heads,
-            prefix=prefix)
+            prefix=prefix,
+        )
 
     def forward(
         self,
@@ -534,7 +543,7 @@ class RadioModel(nn.Module):
                 # Skip non-radio weights
                 continue
 
-            sub = name[len("radio_model."):]  # drop "radio_model." prefix
+            sub = name[len("radio_model.") :]  # drop "radio_model." prefix
 
             # Skip buffers not used in vLLM
             if sub in {"summary_idxs"}:
@@ -553,15 +562,13 @@ class RadioModel(nn.Module):
                 layer_idx = parts[2]
                 suffix = ".".join(parts[3:])
                 # Skip layer-scale entries that vLLM doesn't use
-                if suffix in {"ls1", "ls2"} or suffix.startswith(
-                    ("ls1.", "ls2.")):
+                if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")):
                     continue
                 vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}"
 
             if vllm_key and vllm_key in params_dict:
                 param = params_dict[vllm_key]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, weight)
                 loaded_params.add(vllm_key)
 
@@ -571,6 +578,6 @@ class RadioModel(nn.Module):
         # Remove CLS + REGISTERS tokens
         patch_gen = getattr(self.model, "patch_generator", None)
         if patch_gen is not None:
-            all_feat = y[:, patch_gen.num_skip:]
+            all_feat = y[:, patch_gen.num_skip :]
 
         return all_feat