Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,8 +9,12 @@ from typing import Annotated, Any, Literal, Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
|
||||
ChameleonVQVAEConfig)
|
||||
from transformers import (
|
||||
BatchFeature,
|
||||
ChameleonConfig,
|
||||
ChameleonProcessor,
|
||||
ChameleonVQVAEConfig,
|
||||
)
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -19,33 +23,53 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, row_parallel_weight_loader)
|
||||
default_weight_loader,
|
||||
row_parallel_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant)
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
)
|
||||
from .utils import (
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -58,12 +82,12 @@ class ChameleonImagePixelInputs(TensorSchema):
|
||||
- h: Height of each image
|
||||
- w: Width of each image
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
|
||||
|
||||
class ChameleonProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(ChameleonConfig)
|
||||
|
||||
@@ -78,9 +102,7 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
|
||||
return processor.image_seq_length
|
||||
|
||||
|
||||
class ChameleonDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[ChameleonProcessingInfo]):
|
||||
|
||||
class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
@@ -103,17 +125,16 @@ class ChameleonDummyInputsBuilder(
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=width,
|
||||
height=height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
"image": self._get_dummy_images(
|
||||
width=width,
|
||||
height=height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class ChameleonMultiModalProcessor(
|
||||
BaseMultiModalProcessor[ChameleonProcessingInfo]):
|
||||
|
||||
class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -182,29 +203,23 @@ class ChameleonMultiModalProcessor(
|
||||
|
||||
|
||||
class ChameleonLayerNorm(nn.LayerNorm):
|
||||
|
||||
def __init__(self, hidden_size, *args, **kwargs):
|
||||
super().__init__(hidden_size, *args, **kwargs)
|
||||
self.normalized_shape = (hidden_size[-1], )
|
||||
self.normalized_shape = (hidden_size[-1],)
|
||||
|
||||
set_weight_attrs(self.weight,
|
||||
{"weight_loader": row_parallel_weight_loader})
|
||||
set_weight_attrs(self.bias,
|
||||
{"weight_loader": row_parallel_weight_loader})
|
||||
set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader})
|
||||
set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader})
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = F.layer_norm(hidden_states,
|
||||
self.normalized_shape,
|
||||
None,
|
||||
None,
|
||||
eps=1e-5)
|
||||
hidden_states = F.layer_norm(
|
||||
hidden_states, self.normalized_shape, None, None, eps=1e-5
|
||||
)
|
||||
hidden_states = hidden_states * self.weight + self.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
|
||||
class ChameleonMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -218,14 +233,18 @@ class ChameleonMLP(nn.Module):
|
||||
input_size=hidden_size,
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
@@ -237,7 +256,6 @@ class ChameleonMLP(nn.Module):
|
||||
|
||||
# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
|
||||
class ChameleonAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -298,16 +316,19 @@ class ChameleonAttention(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor,
|
||||
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _apply_qk_norm(
|
||||
self, q: torch.Tensor, k: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# reshape for layernorm
|
||||
q = q.reshape(-1, self.num_heads, self.head_dim)
|
||||
k = k.reshape(-1, self.num_kv_heads, self.head_dim)
|
||||
@@ -333,7 +354,6 @@ class ChameleonAttention(nn.Module):
|
||||
|
||||
|
||||
class ChameleonDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChameleonConfig,
|
||||
@@ -346,17 +366,19 @@ class ChameleonDecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
4096)
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
|
||||
|
||||
self.self_attn = ChameleonAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
@@ -372,10 +394,10 @@ class ChameleonDecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -383,28 +405,24 @@ class ChameleonDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class ChameleonSwinDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChameleonConfig,
|
||||
@@ -417,17 +435,19 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
4096)
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
|
||||
|
||||
self.self_attn = ChameleonAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
@@ -443,10 +463,10 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -454,7 +474,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
@@ -475,7 +494,6 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
|
||||
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
|
||||
class ChameleonVQVAEVectorQuantizer(nn.Module):
|
||||
|
||||
def __init__(self, config: ChameleonVQVAEConfig):
|
||||
super().__init__()
|
||||
self.num_embeddings = config.num_embeddings
|
||||
@@ -491,55 +509,52 @@ class ChameleonVQVAEVectorQuantizer(nn.Module):
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
distances = (
|
||||
torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
|
||||
torch.sum(self.embedding.weight**2, dim=1) -
|
||||
2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
|
||||
self.embedding.weight.transpose(0, 1)))
|
||||
torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
|
||||
+ torch.sum(self.embedding.weight**2, dim=1)
|
||||
- 2
|
||||
* torch.einsum(
|
||||
"bd,dn->bn",
|
||||
hidden_state_flattened,
|
||||
self.embedding.weight.transpose(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
min_encoding_indices = torch.argmin(distances, dim=1)
|
||||
hidden_state_quant = self.embedding(min_encoding_indices).view(
|
||||
hidden_state.shape)
|
||||
hidden_state.shape
|
||||
)
|
||||
|
||||
# compute loss for embedding
|
||||
loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
|
||||
2) + self.beta * torch.mean(
|
||||
(hidden_state_quant - hidden_state.detach())**2)
|
||||
loss = torch.mean(
|
||||
(hidden_state_quant.detach() - hidden_state) ** 2
|
||||
) + self.beta * torch.mean((hidden_state_quant - hidden_state.detach()) ** 2)
|
||||
|
||||
# preserve gradients
|
||||
hidden_state_quant = hidden_state + (hidden_state_quant -
|
||||
hidden_state).detach()
|
||||
hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return hidden_state_quant, loss, min_encoding_indices
|
||||
|
||||
|
||||
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
|
||||
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
hidden_states = F.pad(hidden_states,
|
||||
pad=(0, 1, 0, 1),
|
||||
mode="constant",
|
||||
value=0)
|
||||
hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
|
||||
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChameleonVQVAEConfig,
|
||||
@@ -549,42 +564,31 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None \
|
||||
else out_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=32,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=32,
|
||||
num_channels=out_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
self.norm1 = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.norm2 = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
|
||||
)
|
||||
self.dropout = torch.nn.Dropout(config.dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
residual = hidden_states
|
||||
@@ -608,35 +612,25 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
|
||||
|
||||
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
|
||||
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32,
|
||||
num_channels=in_channels,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.norm = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
residual = hidden_states
|
||||
@@ -647,20 +641,20 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
|
||||
|
||||
# compute attention
|
||||
batch_size, channels, height, width = query_states.shape
|
||||
query_states = query_states.reshape(batch_size, channels,
|
||||
height * width).permute(0, 2, 1)
|
||||
query_states = query_states.reshape(
|
||||
batch_size, channels, height * width
|
||||
).permute(0, 2, 1)
|
||||
key_states = key_states.reshape(batch_size, channels, height * width)
|
||||
attn_weights = torch.bmm(query_states, key_states)
|
||||
attn_weights = attn_weights * (int(channels)**(-0.5))
|
||||
attn_weights = attn_weights * (int(channels) ** (-0.5))
|
||||
attn_weights = F.softmax(attn_weights, dim=2)
|
||||
|
||||
# attend to values
|
||||
value_states = value_states.reshape(batch_size, channels,
|
||||
height * width)
|
||||
value_states = value_states.reshape(batch_size, channels, height * width)
|
||||
attn_weights = attn_weights.permute(0, 2, 1)
|
||||
attn_output = torch.bmm(value_states,
|
||||
attn_weights).reshape(batch_size, channels,
|
||||
height, width)
|
||||
attn_output = torch.bmm(value_states, attn_weights).reshape(
|
||||
batch_size, channels, height, width
|
||||
)
|
||||
|
||||
attn_output = self.proj_out(attn_output)
|
||||
return residual + attn_output
|
||||
@@ -668,7 +662,6 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
|
||||
|
||||
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
|
||||
class ChameleonVQVAEEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config: ChameleonVQVAEConfig):
|
||||
super().__init__()
|
||||
|
||||
@@ -681,14 +674,12 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
latent_channels = config.latent_channels
|
||||
channel_multiplier = config.channel_multiplier
|
||||
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
base_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, base_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
curr_res = resolution
|
||||
in_channel_multiplier = (1, ) + tuple(channel_multiplier)
|
||||
in_channel_multiplier = (1,) + tuple(channel_multiplier)
|
||||
self.in_channel_multiplier = in_channel_multiplier
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
@@ -702,11 +693,14 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
config=config,
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
))
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if (config.attn_resolutions is not None
|
||||
and curr_res in config.attn_resolutions
|
||||
and config.attn_type == "vanilla"):
|
||||
if (
|
||||
config.attn_resolutions is not None
|
||||
and curr_res in config.attn_resolutions
|
||||
and config.attn_type == "vanilla"
|
||||
):
|
||||
attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
|
||||
|
||||
down = nn.Module()
|
||||
@@ -723,18 +717,20 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
)
|
||||
self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
|
||||
block_in) if config.attn_type == "vanilla" else nn.Identity()
|
||||
self.mid.attn_1 = (
|
||||
ChameleonVQVAEEncoderAttnBlock(block_in)
|
||||
if config.attn_type == "vanilla"
|
||||
else nn.Identity()
|
||||
)
|
||||
self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
|
||||
config=config,
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
)
|
||||
|
||||
self.norm_out = torch.nn.GroupNorm(num_groups=32,
|
||||
num_channels=block_in,
|
||||
eps=1e-6,
|
||||
affine=True)
|
||||
self.norm_out = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
||||
)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * latent_channels if double_latent else latent_channels,
|
||||
@@ -750,15 +746,12 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
hidden_states = [self.conv_in(pixel_values)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
hidden_state = self.down[i_level].block[i_block](
|
||||
hidden_states[-1])
|
||||
hidden_state = self.down[i_level].block[i_block](hidden_states[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
hidden_state = self.down[i_level].attn[i_block](
|
||||
hidden_state)
|
||||
hidden_state = self.down[i_level].attn[i_block](hidden_state)
|
||||
hidden_states.append(hidden_state)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hidden_states.append(self.down[i_level].downsample(
|
||||
hidden_states[-1]))
|
||||
hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
|
||||
|
||||
# middle
|
||||
last_hidden_state = hidden_states[-1]
|
||||
@@ -775,15 +768,14 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
|
||||
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
|
||||
class ChameleonVQVAE(nn.Module):
|
||||
|
||||
def __init__(self, config: ChameleonVQVAEConfig):
|
||||
super().__init__()
|
||||
self.encoder = ChameleonVQVAEEncoder(config)
|
||||
self.quantize = ChameleonVQVAEVectorQuantizer(config)
|
||||
self.quant_conv = torch.nn.Conv2d(config.latent_channels,
|
||||
config.embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
|
||||
config.latent_channels, 1)
|
||||
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(
|
||||
config.embed_dim, config.latent_channels, 1
|
||||
)
|
||||
self.eval() # Chameleon's VQ model is frozen
|
||||
|
||||
def encode(
|
||||
@@ -811,10 +803,9 @@ class ChameleonImageVocabularyMapping:
|
||||
|
||||
@cached_property
|
||||
def image_tokens(self):
|
||||
return sorted([
|
||||
val for name, val in self.vocab_map.items()
|
||||
if name.startswith("IMGIMG")
|
||||
])
|
||||
return sorted(
|
||||
[val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def bpe2img(self):
|
||||
@@ -822,13 +813,10 @@ class ChameleonImageVocabularyMapping:
|
||||
|
||||
def remap(old_name: str) -> str:
|
||||
return "".join(
|
||||
img_tkn_chr_mapping.get(c, c)
|
||||
for c in old_name[len("IMGIMG"):-1])
|
||||
img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]
|
||||
)
|
||||
|
||||
return {
|
||||
tok: int(remap(self.val2name[tok]))
|
||||
for tok in self.image_tokens
|
||||
}
|
||||
return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}
|
||||
|
||||
@cached_property
|
||||
def img2bpe(self):
|
||||
@@ -837,7 +825,8 @@ class ChameleonImageVocabularyMapping:
|
||||
@cached_property
|
||||
def bpe2img_search_tensors(self):
|
||||
return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
|
||||
sorted(self.bpe2img.values()))
|
||||
sorted(self.bpe2img.values())
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def img2bpe_mapping_tensor(self):
|
||||
@@ -853,7 +842,6 @@ class ChameleonImageVocabularyMapping:
|
||||
|
||||
|
||||
class ChameleonModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -867,25 +855,29 @@ class ChameleonModel(nn.Module):
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.vocabulary_mapping = ChameleonImageVocabularyMapping(
|
||||
config.vocabulary_map)
|
||||
decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
|
||||
self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
|
||||
decoder_layer = (
|
||||
ChameleonDecoderLayer
|
||||
if not self.config.swin_norm
|
||||
else ChameleonSwinDecoderLayer
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: decoder_layer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
lambda prefix: decoder_layer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.vqmodel = ChameleonVQVAE(config.vq_config)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
@@ -926,10 +918,9 @@ class ChameleonModel(nn.Module):
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
@@ -937,14 +928,16 @@ class ChameleonModel(nn.Module):
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
ChameleonMultiModalProcessor,
|
||||
info=ChameleonProcessingInfo,
|
||||
dummy_inputs=ChameleonDummyInputsBuilder)
|
||||
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP, SupportsQuant):
|
||||
dummy_inputs=ChameleonDummyInputsBuilder,
|
||||
)
|
||||
class ChameleonForConditionalGeneration(
|
||||
nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
|
||||
):
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -960,8 +953,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.model = ChameleonModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = ChameleonModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
@@ -972,13 +966,16 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size, logit_scale)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size, logit_scale
|
||||
)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[ChameleonImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
|
||||
if pixel_values is None:
|
||||
@@ -987,24 +984,23 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
vq_config: ChameleonVQVAEConfig = self.config.vq_config
|
||||
expected_h = expected_w = vq_config.resolution
|
||||
|
||||
return ChameleonImagePixelInputs(type="pixel_values",
|
||||
data=pixel_values,
|
||||
resolve_bindings={
|
||||
"h": expected_h,
|
||||
"w": expected_w
|
||||
})
|
||||
return ChameleonImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=pixel_values,
|
||||
resolve_bindings={"h": expected_h, "w": expected_w},
|
||||
)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
assert self.model.vqmodel is not None
|
||||
image_tokens = self.model.get_image_tokens(image_input["data"].to(
|
||||
self.config.torch_dtype))
|
||||
image_tokens = self.model.get_image_tokens(
|
||||
image_input["data"].to(self.config.torch_dtype)
|
||||
)
|
||||
vision_embeddings = self.model.get_input_embeddings(image_tokens)
|
||||
return vision_embeddings
|
||||
|
||||
@@ -1016,14 +1012,12 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
@@ -1040,8 +1034,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -1056,8 +1049,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
@@ -1075,8 +1067,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# not vqvae for now.
|
||||
use_default_weight_loading = True
|
||||
else:
|
||||
for (param_name, weight_name,
|
||||
shard_id) in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
@@ -1096,7 +1087,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if name.endswith("kv_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
".kv_scale", ".attn.kv_scale")
|
||||
".kv_scale", ".attn.kv_scale"
|
||||
)
|
||||
if remapped_kv_scale_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
|
||||
@@ -1109,15 +1101,15 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if use_default_weight_loading and name in params_dict:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
Reference in New Issue
Block a user