Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -28,7 +28,6 @@ norm_t = Union[tuple[float, float, float], torch.Tensor]
def _ntuple(n):
def parse(x):
if isinstance(x, Iterable) and not isinstance(x, str):
return tuple(x)
@@ -45,7 +44,6 @@ to_ntuple = _ntuple
class InputConditioner(nn.Module):
def __init__(
self,
input_scale: float,
@@ -72,7 +70,6 @@ def _to_tensor(v: norm_t):
class ClsToken(nn.Module):
def __init__(
self,
ndim: int,
@@ -91,12 +88,14 @@ class ClsToken(nn.Module):
if num_registers:
self.num_registers = num_registers
elif register_multiple:
self.num_registers = register_multiple - (num_tokens %
register_multiple)
self.num_registers = register_multiple - (
num_tokens % register_multiple
)
scale = ndim**-0.5
self.token = nn.Parameter(
torch.randn(num_tokens + self.num_registers, ndim) * scale)
torch.randn(num_tokens + self.num_registers, ndim) * scale
)
else:
self.token = None
@@ -108,16 +107,18 @@ class ClsToken(nn.Module):
return x
token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
x = torch.cat([
token,
x,
], dim=1)
x = torch.cat(
[
token,
x,
],
dim=1,
)
return x
class ViTPatchGenerator(nn.Module):
def __init__(
self,
# config: PretrainedConfig,
@@ -147,8 +148,8 @@ class ViTPatchGenerator(nn.Module):
max_input_dims = (max_input_dims, max_input_dims)
max_input_dims = tuple(
int(math.ceil(d / patch_size) * patch_size)
for d in max_input_dims)
int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims
)
self.cpe_mode = max_input_dims != input_dims
self.pos_dropout = pos_dropout
@@ -167,15 +168,15 @@ class ViTPatchGenerator(nn.Module):
self.max_input_dims = max_input_dims
self.im_to_patches = Im2Patches(patch_size)
self.embedder = ViTPatchLinear(patch_size,
embed_dim,
bias=patch_bias,
**factory)
self.embedder = ViTPatchLinear(
patch_size, embed_dim, bias=patch_bias, **factory
)
if abs_pos:
scale = embed_dim**-0.5
self.pos_embed = nn.Parameter(
torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
torch.randn(1, self.num_patches, embed_dim, **factory) * scale
)
self.cls_token = ClsToken(
embed_dim,
@@ -185,8 +186,9 @@ class ViTPatchGenerator(nn.Module):
num_registers=num_registers,
)
self.patch_normalizer = nn.LayerNorm(
embed_dim) if normalize_patches else nn.Identity()
self.patch_normalizer = (
nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
patches = self.embed_patches(x)
@@ -221,42 +223,48 @@ class ViTPatchGenerator(nn.Module):
if src_embed.shape != targ_embed.shape:
src_size = int(math.sqrt(src_embed.shape[1]))
assert src_size**2 == src_embed.shape[
1], 'Unable to interpolate non-square embedding'
assert src_size**2 == src_embed.shape[1], (
"Unable to interpolate non-square embedding"
)
src_embed = rearrange(src_embed,
'b (h w) c -> b c h w',
h=src_size,
w=src_size)
src_embed = F.interpolate(src_embed,
size=(self.num_rows, self.num_cols),
mode='bicubic',
align_corners=True,
antialias=False)
src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
src_embed = rearrange(
src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size
)
src_embed = F.interpolate(
src_embed,
size=(self.num_rows, self.num_cols),
mode="bicubic",
align_corners=True,
antialias=False,
)
src_embed = rearrange(src_embed, "b c h w -> b (h w) c")
targ_embed.data.copy_(src_embed)
def _load_projection(self, src_proj_weight: torch.Tensor,
targ_proj_weight: torch.Tensor):
def _load_projection(
self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor
):
if src_proj_weight.shape != targ_proj_weight.shape:
src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))
assert (src_patch_size**2) * 3 == src_proj_weight.shape[
1], 'Unable to interpolate non-square patch size'
assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], (
"Unable to interpolate non-square patch size"
)
src_proj_weight = rearrange(src_proj_weight,
'b (c h w) -> b c h w',
c=3,
h=src_patch_size,
w=src_patch_size)
src_proj_weight = F.interpolate(src_proj_weight,
size=(self.patch_size,
self.patch_size),
mode='bicubic',
align_corners=True,
antialias=False)
src_proj_weight = rearrange(src_proj_weight,
'b c h w -> b (c h w)')
src_proj_weight = rearrange(
src_proj_weight,
"b (c h w) -> b c h w",
c=3,
h=src_patch_size,
w=src_patch_size,
)
src_proj_weight = F.interpolate(
src_proj_weight,
size=(self.patch_size, self.patch_size),
mode="bicubic",
align_corners=True,
antialias=False,
)
src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)")
targ_proj_weight.data.copy_(src_proj_weight)
def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
@@ -276,11 +284,12 @@ class ViTPatchGenerator(nn.Module):
pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
if self.training and self.pos_dropout > 0:
keeps = torch.rand(patches.shape[0],
1,
1,
dtype=pos_enc.dtype,
device=pos_enc.device) > self.pos_dropout
keeps = (
torch.rand(
patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device
)
> self.pos_dropout
)
pos_enc_drop = torch.where(keeps, pos_enc, 0)
else:
pos_enc_drop = pos_enc
@@ -303,56 +312,58 @@ class ViTPatchGenerator(nn.Module):
if patch_idxs is None:
return pos_embed
exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(
-1, -1, pos_embed.shape[-1])
exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1),
dim=1,
index=exp_patch_idxs)
pos_embed = torch.gather(
pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
)
return pos_embed
def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int,
int]):
def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]):
if (self.num_rows, self.num_cols) == input_dims:
return self.pos_embed
pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols,
-1).permute(0, 3, 1, 2)
pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(
0, 3, 1, 2
)
def window_select(pos_embed):
if input_dims[0] < pos_embed.shape[-2]:
pos_embed = pos_embed[..., :input_dims[0], :]
pos_embed = pos_embed[..., : input_dims[0], :]
if input_dims[1] < pos_embed.shape[-1]:
pos_embed = pos_embed[..., :, :input_dims[1]]
pos_embed = pos_embed[..., :, : input_dims[1]]
return pos_embed
if self.cpe_mode:
if self.training:
min_scale = math.sqrt(0.1)
scale = torch.rand(batch_size, 1, 1, device=pos_embed.device
) * (1 - min_scale) + min_scale
scale = (
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (1 - min_scale)
+ min_scale
)
aspect_min = math.log(3 / 4)
aspect_max = -aspect_min
aspect = torch.exp(
torch.rand(batch_size, 1, 1, device=pos_embed.device) *
(aspect_max - aspect_min) + aspect_min)
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (aspect_max - aspect_min)
+ aspect_min
)
scale_x = scale * aspect
scale_y = scale * (1 / aspect)
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
pos_xy = torch.rand(
batch_size, 1, 1, 2,
device=pos_embed.device) * (1 - scale_xy)
pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
1 - scale_xy
)
lin_x = torch.linspace(
0, 1, steps=input_dims[1],
device=pos_embed.device)[None, None].expand(
batch_size, input_dims[0], -1)
0, 1, steps=input_dims[1], device=pos_embed.device
)[None, None].expand(batch_size, input_dims[0], -1)
lin_y = torch.linspace(
0, 1, steps=input_dims[0],
device=pos_embed.device)[None, :, None].expand(
batch_size, -1, input_dims[1])
0, 1, steps=input_dims[0], device=pos_embed.device
)[None, :, None].expand(batch_size, -1, input_dims[1])
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
@@ -364,26 +375,27 @@ class ViTPatchGenerator(nn.Module):
pos_embed = F.grid_sample(
pos_embed.float().expand(batch_size, -1, -1, -1),
grid=grid_xy,
mode='bilinear',
padding_mode='zeros',
mode="bilinear",
padding_mode="zeros",
align_corners=True,
).to(pos_embed.dtype)
else:
max_dim = max(input_dims)
pos_embed = F.interpolate(pos_embed.float(),
size=(max_dim, max_dim),
align_corners=True,
mode='bilinear').to(pos_embed.dtype)
pos_embed = F.interpolate(
pos_embed.float(),
size=(max_dim, max_dim),
align_corners=True,
mode="bilinear",
).to(pos_embed.dtype)
pos_embed = window_select(pos_embed)
else:
pos_embed = window_select(pos_embed)
if pos_embed.shape[-2:] != input_dims:
pos_embed = F.interpolate(pos_embed.float(),
size=input_dims,
align_corners=True,
mode='bilinear').to(pos_embed.dtype)
pos_embed = F.interpolate(
pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear"
).to(pos_embed.dtype)
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
@@ -391,7 +403,6 @@ class ViTPatchGenerator(nn.Module):
class Im2Patches(nn.Module):
def __init__(self, patch_size: int):
super().__init__()
self.patch_size = patch_size
@@ -406,7 +417,7 @@ class Im2Patches(nn.Module):
px = x.shape[-1] // self.patch_size
patches = rearrange(
x,
'b c (py yy) (px xx) -> b (py px) (c yy xx)',
"b c (py yy) (px xx) -> b (py px) (c yy xx)",
py=py,
yy=self.patch_size,
px=px,
@@ -416,12 +427,7 @@ class Im2Patches(nn.Module):
class ViTPatchLinear(nn.Linear):
def __init__(self,
patch_size: int,
embed_dim: int,
bias: bool = False,
**factory):
def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory)
self.patch_size = patch_size
@@ -444,16 +450,19 @@ class RadioInternVisionModel(nn.Module):
self.config = config
self.img_size, self.grid_size, self.num_patches = self._init_img_size(
to_2tuple(config.patch_size), config.image_size)
to_2tuple(config.patch_size), config.image_size
)
max_img_size = int(
round(config.max_img_size / config.patch_size) * config.patch_size)
round(config.max_img_size / config.patch_size) * config.patch_size
)
self.patch_generator = ViTPatchGenerator(
config.patch_size,
config.hidden_size,
input_dims=self.img_size,
max_input_dims=max_img_size,
cls_token=True,
register_multiple=config.reg_tokens)
register_multiple=config.reg_tokens,
)
self.encoder = InternVisionEncoder(
config=config,
@@ -463,8 +472,7 @@ class RadioInternVisionModel(nn.Module):
prefix=f"{prefix}.encoder",
)
def _init_img_size(self, patch_size, img_size: Union[int, tuple[int,
int]]):
def _init_img_size(self, patch_size, img_size: Union[int, tuple[int, int]]):
if img_size is None:
return None, None, None
img_size = to_2tuple(img_size)
@@ -509,7 +517,8 @@ class RadioModel(nn.Module):
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=prefix)
prefix=prefix,
)
def forward(
self,
@@ -534,7 +543,7 @@ class RadioModel(nn.Module):
# Skip non-radio weights
continue
sub = name[len("radio_model."):] # drop "radio_model." prefix
sub = name[len("radio_model.") :] # drop "radio_model." prefix
# Skip buffers not used in vLLM
if sub in {"summary_idxs"}:
@@ -553,15 +562,13 @@ class RadioModel(nn.Module):
layer_idx = parts[2]
suffix = ".".join(parts[3:])
# Skip layer-scale entries that vLLM doesn't use
if suffix in {"ls1", "ls2"} or suffix.startswith(
("ls1.", "ls2.")):
if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")):
continue
vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}"
if vllm_key and vllm_key in params_dict:
param = params_dict[vllm_key]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight)
loaded_params.add(vllm_key)
@@ -571,6 +578,6 @@ class RadioModel(nn.Module):
# Remove CLS + REGISTERS tokens
patch_gen = getattr(self.model, "patch_generator", None)
if patch_gen is not None:
all_feat = y[:, patch_gen.num_skip:]
all_feat = y[:, patch_gen.num_skip :]
return all_feat