Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author:    Harry Mellor
Date:      2025-10-05 15:06:22 +01:00
Committed: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

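For anyone reproducing the switch locally, the yapf + isort steps are replaced by ruff's import-sorting rules and its formatter. A minimal sketch of the equivalent invocation follows (it assumes ruff is installed; the helper function and the target path are illustrative and not part of this commit):

# Hypothetical helper that mirrors the yapf + isort -> ruff switch.
import subprocess


def format_with_ruff(paths: list[str]) -> None:
    # Sort imports (isort replacement): ruff's "I" rules with --fix.
    subprocess.run(["ruff", "check", "--select", "I", "--fix", *paths], check=True)
    # Reformat code (yapf replacement): ruff's built-in formatter.
    subprocess.run(["ruff", "format", *paths], check=True)


if __name__ == "__main__":
    # Illustrative target; any file or directory works.
    format_with_ruff(["vllm/model_executor/models/chameleon.py"])

The diff below shows what this produces: parenthesized, trailing-comma imports and call sites instead of yapf's aligned hanging indents.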

@@ -9,8 +9,12 @@ from typing import Annotated, Any, Literal, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
-                          ChameleonVQVAEConfig)
+from transformers import (
+    BatchFeature,
+    ChameleonConfig,
+    ChameleonProcessor,
+    ChameleonVQVAEConfig,
+)

 from vllm.attention import Attention
 from vllm.config import CacheConfig, VllmConfig
@@ -19,33 +23,53 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, row_parallel_weight_loader)
+    default_weight_loader,
+    row_parallel_weight_loader,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
 from vllm.multimodal.parse import MultiModalDataItems
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

-from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
-                         SupportsQuant)
-from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix)
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsMultiModal,
+    SupportsPP,
+    SupportsQuant,
+)
+from .utils import (
+    is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory,
+    make_layers,
+    maybe_prefix,
+)

 logger = init_logger(__name__)
@@ -58,12 +82,12 @@ class ChameleonImagePixelInputs(TensorSchema):
         - h: Height of each image
         - w: Width of each image
     """

     type: Literal["pixel_values"]
     data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


 class ChameleonProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self):
         return self.ctx.get_hf_config(ChameleonConfig)
@@ -78,9 +102,7 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
         return processor.image_seq_length


-class ChameleonDummyInputsBuilder(
-        BaseDummyInputsBuilder[ChameleonProcessingInfo]):
-
+class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
@@ -103,17 +125,16 @@ class ChameleonDummyInputsBuilder(
         image_overrides = mm_options.get("image") if mm_options else None

         return {
-            "image":
-            self._get_dummy_images(width=width,
-                                   height=height,
-                                   num_images=num_images,
-                                   overrides=image_overrides)
+            "image": self._get_dummy_images(
+                width=width,
+                height=height,
+                num_images=num_images,
+                overrides=image_overrides,
+            )
         }


-class ChameleonMultiModalProcessor(
-        BaseMultiModalProcessor[ChameleonProcessingInfo]):
-
+class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]):
     def _call_hf_processor(
         self,
         prompt: str,
@@ -182,29 +203,23 @@ class ChameleonMultiModalProcessor(
 class ChameleonLayerNorm(nn.LayerNorm):
-
     def __init__(self, hidden_size, *args, **kwargs):
         super().__init__(hidden_size, *args, **kwargs)
-        self.normalized_shape = (hidden_size[-1], )
+        self.normalized_shape = (hidden_size[-1],)

-        set_weight_attrs(self.weight,
-                         {"weight_loader": row_parallel_weight_loader})
-        set_weight_attrs(self.bias,
-                         {"weight_loader": row_parallel_weight_loader})
+        set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader})
+        set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader})

     def forward(self, hidden_states):
-        hidden_states = F.layer_norm(hidden_states,
-                                     self.normalized_shape,
-                                     None,
-                                     None,
-                                     eps=1e-5)
+        hidden_states = F.layer_norm(
+            hidden_states, self.normalized_shape, None, None, eps=1e-5
+        )
         hidden_states = hidden_states * self.weight + self.bias
         return hidden_states


 # Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
 class ChameleonMLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -218,14 +233,18 @@ class ChameleonMLP(nn.Module):
             input_size=hidden_size,
             output_sizes=[intermediate_size] * 2,
             bias=bias,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(input_size=intermediate_size,
-                                           output_size=hidden_size,
-                                           bias=bias,
-                                           quant_config=quant_config)
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+        )
         if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
+            )
         self.act_fn = SiluAndMul()

     def forward(self, x):
@@ -237,7 +256,6 @@ class ChameleonMLP(nn.Module):
 # Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
 class ChameleonAttention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -298,16 +316,19 @@ class ChameleonAttention(nn.Module):
             rope_scaling=rope_scaling,
         )

-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn")
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )

-    def _apply_qk_norm(self, q: torch.Tensor,
-                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # reshape for layernorm
         q = q.reshape(-1, self.num_heads, self.head_dim)
         k = k.reshape(-1, self.num_kv_heads, self.head_dim)
@@ -333,7 +354,6 @@ class ChameleonAttention(nn.Module):
 class ChameleonDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: ChameleonConfig,
@@ -346,17 +366,19 @@ class ChameleonDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         if rope_scaling is not None and getattr(
-                config, "original_max_position_embeddings", None):
+            config, "original_max_position_embeddings", None
+        ):
             rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          4096)
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

         self.self_attn = ChameleonAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=getattr(config, "num_key_value_heads",
-                                 config.num_attention_heads),
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
@@ -372,10 +394,10 @@ class ChameleonDecoderLayer(nn.Module):
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     def forward(
         self,
@@ -383,28 +405,24 @@ class ChameleonDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
         )

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual


 class ChameleonSwinDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: ChameleonConfig,
@@ -417,17 +435,19 @@ class ChameleonSwinDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         if rope_scaling is not None and getattr(
-                config, "original_max_position_embeddings", None):
+            config, "original_max_position_embeddings", None
+        ):
             rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          4096)
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)

         self.self_attn = ChameleonAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=getattr(config, "num_key_value_heads",
-                                 config.num_attention_heads),
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
@@ -443,10 +463,10 @@ class ChameleonSwinDecoderLayer(nn.Module):
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     def forward(
         self,
@@ -454,7 +474,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-
         residual = hidden_states
         hidden_states = self.self_attn(
             positions=positions,
@@ -475,7 +494,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
 # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
 class ChameleonVQVAEVectorQuantizer(nn.Module):
-
     def __init__(self, config: ChameleonVQVAEConfig):
         super().__init__()
         self.num_embeddings = config.num_embeddings
@@ -491,55 +509,52 @@ class ChameleonVQVAEVectorQuantizer(nn.Module):
         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
         distances = (
-            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
-            torch.sum(self.embedding.weight**2, dim=1) -
-            2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
-                             self.embedding.weight.transpose(0, 1)))
+            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
+            + torch.sum(self.embedding.weight**2, dim=1)
+            - 2
+            * torch.einsum(
+                "bd,dn->bn",
+                hidden_state_flattened,
+                self.embedding.weight.transpose(0, 1),
+            )
+        )

         min_encoding_indices = torch.argmin(distances, dim=1)
         hidden_state_quant = self.embedding(min_encoding_indices).view(
-            hidden_state.shape)
+            hidden_state.shape
+        )

         # compute loss for embedding
-        loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
-                          2) + self.beta * torch.mean(
-                              (hidden_state_quant - hidden_state.detach())**2)
+        loss = torch.mean(
+            (hidden_state_quant.detach() - hidden_state) ** 2
+        ) + self.beta * torch.mean((hidden_state_quant - hidden_state.detach()) ** 2)

         # preserve gradients
-        hidden_state_quant = hidden_state + (hidden_state_quant -
-                                             hidden_state).detach()
+        hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()

         # reshape back to match original input shape
-        hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
-                                                        2).contiguous()
+        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()

         return hidden_state_quant, loss, min_encoding_indices


 # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
 class ChameleonVQVAEEncoderConvDownsample(nn.Module):
-
     def __init__(self, in_channels: int):
         super().__init__()
-        self.conv = nn.Conv2d(in_channels,
-                              in_channels,
-                              kernel_size=3,
-                              stride=2,
-                              padding=0)
+        self.conv = nn.Conv2d(
+            in_channels, in_channels, kernel_size=3, stride=2, padding=0
+        )

     def forward(self, hidden_states: torch.Tensor):
         # no asymmetric padding in torch conv, must do it ourselves
-        hidden_states = F.pad(hidden_states,
-                              pad=(0, 1, 0, 1),
-                              mode="constant",
-                              value=0)
+        hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
         hidden_states = self.conv(hidden_states)
         return hidden_states


 # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
 class ChameleonVQVAEEncoderResnetBlock(nn.Module):
-
     def __init__(
         self,
         config: ChameleonVQVAEConfig,
@@ -549,42 +564,31 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
     ):
         super().__init__()
         self.in_channels = in_channels
-        self.out_channels = in_channels if out_channels is None \
-            else out_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
         self.use_conv_shortcut = conv_shortcut
-        self.norm1 = torch.nn.GroupNorm(num_groups=32,
-                                        num_channels=in_channels,
-                                        eps=1e-6,
-                                        affine=True)
-        self.conv1 = torch.nn.Conv2d(in_channels,
-                                     out_channels,
-                                     kernel_size=3,
-                                     stride=1,
-                                     padding=1)
-        self.norm2 = torch.nn.GroupNorm(num_groups=32,
-                                        num_channels=out_channels,
-                                        eps=1e-6,
-                                        affine=True)
+        self.norm1 = torch.nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
+        self.conv1 = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.norm2 = torch.nn.GroupNorm(
+            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
+        )
         self.dropout = torch.nn.Dropout(config.dropout)
-        self.conv2 = torch.nn.Conv2d(out_channels,
-                                     out_channels,
-                                     kernel_size=3,
-                                     stride=1,
-                                     padding=1)
+        self.conv2 = torch.nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels,
-                                                     out_channels,
-                                                     kernel_size=3,
-                                                     stride=1,
-                                                     padding=1)
+                self.conv_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
+                )
             else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels,
-                                                    out_channels,
-                                                    kernel_size=1,
-                                                    stride=1,
-                                                    padding=0)
+                self.nin_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
+                )

     def forward(self, hidden_states: torch.Tensor):
         residual = hidden_states
@@ -608,35 +612,25 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
 # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
 class ChameleonVQVAEEncoderAttnBlock(nn.Module):
-
     def __init__(self, in_channels: int):
         super().__init__()
         self.in_channels = in_channels
-        self.norm = torch.nn.GroupNorm(num_groups=32,
-                                       num_channels=in_channels,
-                                       eps=1e-6,
-                                       affine=True)
-        self.q = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
-                                 in_channels,
-                                 kernel_size=1,
-                                 stride=1,
-                                 padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=1,
-                                        stride=1,
-                                        padding=0)
+        self.norm = torch.nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )

     def forward(self, hidden_states: torch.Tensor):
         residual = hidden_states
@@ -647,20 +641,20 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
         # compute attention
         batch_size, channels, height, width = query_states.shape
-        query_states = query_states.reshape(batch_size, channels,
-                                            height * width).permute(0, 2, 1)
+        query_states = query_states.reshape(
+            batch_size, channels, height * width
+        ).permute(0, 2, 1)
         key_states = key_states.reshape(batch_size, channels, height * width)
         attn_weights = torch.bmm(query_states, key_states)
-        attn_weights = attn_weights * (int(channels)**(-0.5))
+        attn_weights = attn_weights * (int(channels) ** (-0.5))
         attn_weights = F.softmax(attn_weights, dim=2)

         # attend to values
-        value_states = value_states.reshape(batch_size, channels,
-                                            height * width)
+        value_states = value_states.reshape(batch_size, channels, height * width)
         attn_weights = attn_weights.permute(0, 2, 1)
-        attn_output = torch.bmm(value_states,
-                                attn_weights).reshape(batch_size, channels,
-                                                      height, width)
+        attn_output = torch.bmm(value_states, attn_weights).reshape(
+            batch_size, channels, height, width
+        )

         attn_output = self.proj_out(attn_output)
         return residual + attn_output
@@ -668,7 +662,6 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
 # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
 class ChameleonVQVAEEncoder(nn.Module):
-
     def __init__(self, config: ChameleonVQVAEConfig):
         super().__init__()
@@ -681,14 +674,12 @@ class ChameleonVQVAEEncoder(nn.Module):
         latent_channels = config.latent_channels
         channel_multiplier = config.channel_multiplier

-        self.conv_in = torch.nn.Conv2d(in_channels,
-                                       base_channels,
-                                       kernel_size=3,
-                                       stride=1,
-                                       padding=1)
+        self.conv_in = torch.nn.Conv2d(
+            in_channels, base_channels, kernel_size=3, stride=1, padding=1
+        )

         curr_res = resolution
-        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
+        in_channel_multiplier = (1,) + tuple(channel_multiplier)
         self.in_channel_multiplier = in_channel_multiplier
         self.down = nn.ModuleList()
         for i_level in range(self.num_resolutions):
@@ -702,11 +693,14 @@ class ChameleonVQVAEEncoder(nn.Module):
                         config=config,
                         in_channels=block_in,
                         out_channels=block_out,
-                    ))
+                    )
+                )
                 block_in = block_out
-                if (config.attn_resolutions is not None
-                        and curr_res in config.attn_resolutions
-                        and config.attn_type == "vanilla"):
+                if (
+                    config.attn_resolutions is not None
+                    and curr_res in config.attn_resolutions
+                    and config.attn_type == "vanilla"
+                ):
                     attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

             down = nn.Module()
@@ -723,18 +717,20 @@ class ChameleonVQVAEEncoder(nn.Module):
             in_channels=block_in,
             out_channels=block_in,
         )
-        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
-            block_in) if config.attn_type == "vanilla" else nn.Identity()
+        self.mid.attn_1 = (
+            ChameleonVQVAEEncoderAttnBlock(block_in)
+            if config.attn_type == "vanilla"
+            else nn.Identity()
+        )
         self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
             config=config,
             in_channels=block_in,
             out_channels=block_in,
         )

-        self.norm_out = torch.nn.GroupNorm(num_groups=32,
-                                           num_channels=block_in,
-                                           eps=1e-6,
-                                           affine=True)
+        self.norm_out = torch.nn.GroupNorm(
+            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
+        )
         self.conv_out = torch.nn.Conv2d(
             block_in,
             2 * latent_channels if double_latent else latent_channels,
@@ -750,15 +746,12 @@ class ChameleonVQVAEEncoder(nn.Module):
         hidden_states = [self.conv_in(pixel_values)]
         for i_level in range(self.num_resolutions):
             for i_block in range(self.num_res_blocks):
-                hidden_state = self.down[i_level].block[i_block](
-                    hidden_states[-1])
+                hidden_state = self.down[i_level].block[i_block](hidden_states[-1])
                 if len(self.down[i_level].attn) > 0:
-                    hidden_state = self.down[i_level].attn[i_block](
-                        hidden_state)
+                    hidden_state = self.down[i_level].attn[i_block](hidden_state)
                 hidden_states.append(hidden_state)
             if i_level != self.num_resolutions - 1:
-                hidden_states.append(self.down[i_level].downsample(
-                    hidden_states[-1]))
+                hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))

         # middle
         last_hidden_state = hidden_states[-1]
@@ -775,15 +768,14 @@ class ChameleonVQVAEEncoder(nn.Module):
 # Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
 class ChameleonVQVAE(nn.Module):
-
     def __init__(self, config: ChameleonVQVAEConfig):
         super().__init__()
         self.encoder = ChameleonVQVAEEncoder(config)
         self.quantize = ChameleonVQVAEVectorQuantizer(config)
-        self.quant_conv = torch.nn.Conv2d(config.latent_channels,
-                                          config.embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
-                                               config.latent_channels, 1)
+        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(
+            config.embed_dim, config.latent_channels, 1
+        )
         self.eval()  # Chameleon's VQ model is frozen

     def encode(
@@ -811,10 +803,9 @@ class ChameleonImageVocabularyMapping:
     @cached_property
     def image_tokens(self):
-        return sorted([
-            val for name, val in self.vocab_map.items()
-            if name.startswith("IMGIMG")
-        ])
+        return sorted(
+            [val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]
+        )

     @cached_property
     def bpe2img(self):
@@ -822,13 +813,10 @@ class ChameleonImageVocabularyMapping:
         def remap(old_name: str) -> str:
             return "".join(
-                img_tkn_chr_mapping.get(c, c)
-                for c in old_name[len("IMGIMG"):-1])
+                img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]
+            )

-        return {
-            tok: int(remap(self.val2name[tok]))
-            for tok in self.image_tokens
-        }
+        return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}

     @cached_property
     def img2bpe(self):
@@ -837,7 +825,8 @@ class ChameleonImageVocabularyMapping:
     @cached_property
     def bpe2img_search_tensors(self):
         return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
-            sorted(self.bpe2img.values()))
+            sorted(self.bpe2img.values())
+        )

     @cached_property
     def img2bpe_mapping_tensor(self):
@@ -853,7 +842,6 @@ class ChameleonImageVocabularyMapping:
 class ChameleonModel(nn.Module):
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -867,25 +855,29 @@ class ChameleonModel(nn.Module):
             self.vocab_size,
             config.hidden_size,
         )

-        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
-            config.vocabulary_map)
-        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
-            else ChameleonSwinDecoderLayer
+        self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
+        decoder_layer = (
+            ChameleonDecoderLayer
+            if not self.config.swin_norm
+            else ChameleonSwinDecoderLayer
+        )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: decoder_layer(config=config,
-                                         cache_config=cache_config,
-                                         quant_config=quant_config,
-                                         prefix=prefix),
+            lambda prefix: decoder_layer(
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.vqmodel = ChameleonVQVAE(config.vq_config)
-        self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"], config.hidden_size
+        )

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -926,10 +918,9 @@ class ChameleonModel(nn.Module):
                 residual,
             )

         if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
+            return IntermediateTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -937,14 +928,16 @@ class ChameleonModel(nn.Module):
 @MULTIMODAL_REGISTRY.register_processor(
     ChameleonMultiModalProcessor,
     info=ChameleonProcessingInfo,
-    dummy_inputs=ChameleonDummyInputsBuilder)
-class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP, SupportsQuant):
+    dummy_inputs=ChameleonDummyInputsBuilder,
+)
+class ChameleonForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
+):
     merge_by_field_config = True

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     @classmethod
@@ -960,8 +953,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multimodal_config = multimodal_config
-        self.model = ChameleonModel(vllm_config=vllm_config,
-                                    prefix=maybe_prefix(prefix, "model"))
+        self.model = ChameleonModel(
+            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+        )
         self.unpadded_vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
@@ -972,13 +966,16 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.lm_head.weight = self.model.embed_tokens.weight
         logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+        self.logits_processor = LogitsProcessor(
+            self.unpadded_vocab_size, config.vocab_size, logit_scale
+        )
         self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
+            self.model.make_empty_intermediate_tensors
+        )

     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
+        self, **kwargs: object
+    ) -> Optional[ChameleonImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         if pixel_values is None:
@@ -987,24 +984,23 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         vq_config: ChameleonVQVAEConfig = self.config.vq_config
         expected_h = expected_w = vq_config.resolution

-        return ChameleonImagePixelInputs(type="pixel_values",
-                                         data=pixel_values,
-                                         resolve_bindings={
-                                             "h": expected_h,
-                                             "w": expected_w
-                                         })
+        return ChameleonImagePixelInputs(
+            type="pixel_values",
+            data=pixel_values,
+            resolve_bindings={"h": expected_h, "w": expected_w},
+        )

     def get_language_model(self) -> torch.nn.Module:
         return self.model

-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []

         assert self.model.vqmodel is not None
-        image_tokens = self.model.get_image_tokens(image_input["data"].to(
-            self.config.torch_dtype))
+        image_tokens = self.model.get_image_tokens(
+            image_input["data"].to(self.config.torch_dtype)
+        )
         vision_embeddings = self.model.get_input_embeddings(image_tokens)
         return vision_embeddings
@@ -1016,14 +1012,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-
         if intermediate_tensors is not None:
             inputs_embeds = None

-        hidden_states = self.model(input_ids,
-                                   positions,
-                                   intermediate_tensors,
-                                   inputs_embeds=inputs_embeds)
+        hidden_states = self.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
         return hidden_states

     def compute_logits(
@@ -1040,8 +1034,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
         return logits

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -1056,8 +1049,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
@@ -1075,8 +1067,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                 # not vqvae for now.
                 use_default_weight_loading = True
             else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
+                for param_name, weight_name, shard_id in stacked_params_mapping:
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -1096,7 +1087,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                 # Remapping the name of FP8 kv-scale.
                 if name.endswith("kv_scale"):
                     remapped_kv_scale_name = name.replace(
-                        ".kv_scale", ".attn.kv_scale")
+                        ".kv_scale", ".attn.kv_scale"
+                    )
                     if remapped_kv_scale_name not in params_dict:
                         logger.warning_once(
                             "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
@@ -1109,15 +1101,15 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                     if is_pp_missing_parameter(name, self):
                         continue
                     param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
                     weight_loader(param, loaded_weight)

             if use_default_weight_loading and name in params_dict:
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)

         return loaded_params