Upstream Llama4 Support to Main (#16113)

Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Lu Fang
Date: 2025-04-07 08:06:27 -07:00
Committed by: GitHub
Parent: 8017c8db7f
Commit: 55dcce91df
43 changed files with 2436 additions and 155 deletions

vllm/model_executor/models/llama.py

@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -65,6 +65,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@@ -292,7 +294,7 @@ class LlamaModel(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -466,10 +468,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm"
"norm": "model.norm",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@@ -478,7 +484,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
@@ -513,8 +520,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return LlamaModel(vllm_config=vllm_config, prefix=prefix)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
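
The change above threads a new `layer_type` argument through `LlamaForCausalLM.__init__`, `_init_model`, and `LlamaModel`, so subclasses can reuse the whole Llama stack with a different decoder layer. A minimal sketch of how a subclass might use the hook (the `MyDecoderLayer` class is hypothetical and shown only to illustrate the plumbing):

from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM

class MyDecoderLayer(LlamaDecoderLayer):
    # Hypothetical decoder layer; customize attention / MLP construction here.
    pass

class MyForCausalLM(LlamaForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The new keyword selects which decoder layer LlamaModel instantiates
        # for every transformer block.
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=MyDecoderLayer)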

vllm/model_executor/models/llama4.py

@@ -0,0 +1,531 @@
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import Llama4TextConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter)
class Llama4MoE(nn.Module):
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32))
def __init__(self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.router")
self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
apply_router_weight_on_input=True,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts")
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size_moe,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False, # We need to do scatter before reduce
)
def forward(self, hidden_states):
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
experts_out = routed_out + shared_out
if self.tp_size > 1:
experts_out = tensor_model_parallel_all_reduce(experts_out)
return experts_out
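
Unlike the usual softmax-over-all-experts routing, the `custom_routing_function` above takes the top-k router logits first and applies a sigmoid to them; with `apply_router_weight_on_input=True`, those scores scale the expert inputs rather than the outputs. A standalone sketch of the routing math on toy tensors (the numbers are illustrative):

import torch

# 2 tokens, 4 experts, top-1 routing.
gating_output = torch.tensor([[0.1, 2.0, -1.0, 0.5],
                              [1.5, 0.0, 0.2, -0.3]])
hidden_states = torch.randn(2, 8)

router_scores, router_indices = torch.topk(gating_output, k=1, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
# router_indices -> [[1], [0]]: token 0 goes to expert 1, token 1 to expert 0.
# router_scores  -> sigmoid of the winning logits, roughly [[0.88], [0.82]].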
class Llama4Attention(nn.Module):
def __init__(self,
config: Llama4TextConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
self.no_rope_layers = config.no_rope_layers
self.nope = self.no_rope_layers[self.layer_idx] == 0
self.use_qk_norm = config.use_qk_norm and not self.nope
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is at least the TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
# TODO: attn_temperature_tuning should be a bool in huggingface
self.attn_temperature_tuning = self.nope and \
config.attn_temperature_tuning > 0
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm(
hidden_size=self.q_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
) if not self.nope else None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=None,
use_irope=not self.nope,
prefix=f"{prefix}.attn",
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
return attn_scale.unsqueeze(-1)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None:
q = self.q_norm(q.float()).to(q.dtype)
if self.k_norm is not None:
k = self.k_norm(k.float()).to(k.dtype)
# Apply temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE
# layers. The inference-time scaling function is designed to leave short
# contexts unaffected while helping at very long context lengths.
# It must be applied after rotary / QK norm and before attention.
if self.attn_temperature_tuning and self.nope:
attn_scale = self._get_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
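
`_get_attn_scale` implements the temperature-tuning factor log(floor((pos + 1) / floor_scale) + 1) * attn_scale + 1, which stays at 1.0 for short contexts and grows slowly with position. A quick numeric check using the defaults read above (floor_scale = 8192, attn_scale = 0.1):

import torch

floor_scale, attn_scale = 8192.0, 0.1
positions = torch.tensor([0.0, 1000.0, 8191.0, 65536.0, 1_000_000.0])
scale = torch.log(torch.floor((positions + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0
# positions 0 and 1000 -> floor = 0   -> scale = 1.000 (short context unchanged)
# position  8191       -> floor = 1   -> scale ~= 1.069
# position  65536      -> floor = 8   -> scale ~= 1.220
# position  1_000_000  -> floor = 122 -> scale ~= 1.481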
class Llama4DecoderLayer(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.self_attn = Llama4Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=False,
bias_o_proj=False,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
is_moe_layer = (self.layer_idx +
1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
else:
self.feed_forward = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size_mlp,
hidden_act="silu",
quant_config=quant_config,
bias=False,
prefix=f"{prefix}.feed_forward",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
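
`is_moe_layer` above interleaves dense and MoE feed-forward blocks. For example, with `interleave_moe_layer_step = 2` (an assumed value, purely for illustration), the 0-based layer indices 1, 3, 5, … get `Llama4MoE` while the rest keep the dense `LlamaMLP`:

interleave_moe_layer_step = 2  # assumed value, for illustration only
num_layers = 8
moe_layers = [i for i in range(num_layers)
              if (i + 1) % interleave_moe_layer_step == 0]
# moe_layers == [1, 3, 5, 7]; layers 0, 2, 4 and 6 use the dense LlamaMLP.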
@support_torch_compile
class Llama4Model(LlamaModel):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_moe_expert_weights(
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter],
loaded_params: Set[str],
expert_params_mapping: List[Tuple[str, str, int, str]],
fused: bool = True,
) -> bool:
expert_param_loaded = False
if "experts.gate_up_proj" in name:
loaded_weight = loaded_weight.chunk(2, dim=-1)
for (param_name, weight_name, expert_id,
shard_id) in expert_params_mapping:
new_loaded_weight = loaded_weight
if fused:
e_str, _, proj_str, _ = weight_name.split('.')
weight_name = f"{e_str}.{proj_str}"
param_name = f"{param_name}weight"
if weight_name not in name:
continue
full_param_name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[full_param_name]
weight_loader = param.weight_loader
if fused:
if "w13" in full_param_name:
shard_idx = 0 if shard_id == "w1" else 1
new_loaded_weight = new_loaded_weight[shard_idx]
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
layer_idx = extract_layer_index(name)
# EP mapping
expert_map = self.layers[
layer_idx].feed_forward.experts.expert_map
if expert_map is not None:
local_expert_indices = (expert_map != -1) \
.nonzero() \
.flatten() \
.to(new_loaded_weight.device)
new_loaded_weight = new_loaded_weight[local_expert_indices]
expert_id = local_expert_indices[0].item()
else:
# TODO: add EP support for non-fused weights
pass
weight_loader(param,
new_loaded_weight,
full_param_name,
shard_id=shard_id,
expert_id=expert_id)
loaded_params.add(full_param_name)
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
fused_experts_params = False
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts)
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj",
num_experts=1)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True
expert_params_mapping = expert_params_mapping_fused
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name or "experts" in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
moe_loaded = self.load_moe_expert_weights(
name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params)
if not moe_loaded:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Llama4ForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Update temperature tuning config from generation config
gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config)
vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get("attn_temperature_tuning", False)
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=Llama4DecoderLayer)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
return loader.load_weights(weights)
def permute_qk_weight_for_rotary(
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
modules = name.split(".")
# rotary embeds should be sliced
if ("wk" in modules or "k_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads)
elif ("wq" in modules or "q_proj" in modules) \
and modules[-1] == "weight":
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads)
return name, loaded_weight
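
The `permute` helper above only reorders the rows of the Q/K projection weights: per head, it converts interleaved (real, imaginary) rotary pairs into the half-split layout that vLLM's neox-style rotary embedding expects. A shape-level sketch with small, made-up dimensions:

import torch

# Made-up dims: 2 heads, head_dim 4, hidden_size 6.
n_heads, head_dim, hidden_size = 2, 4, 6
attn_in, attn_out = n_heads * head_dim, hidden_size
w = torch.arange(attn_in * attn_out, dtype=torch.float32).view(attn_in, attn_out)

permuted = (w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
             .transpose(1, 2)
             .reshape(attn_in, attn_out))
# Within each head, rows [0, 1, 2, 3] come out in order [0, 2, 1, 3]:
# interleaved pairs become contiguous halves. No values are changed.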

vllm/model_executor/models/mllama4.py

@@ -0,0 +1,895 @@
# SPDX-License-Identifier: Apache-2.0
#
# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Iterable, Mapping
from functools import cached_property
from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict
from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import (
find_supported_resolutions, get_best_fit)
from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
class Llama4ImagePatchInputs(TypedDict):
type: Literal["pixel_values"]
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_chunks, num_channels, image_size, image_size)`
"""
patches_per_image: torch.Tensor
"""
The total number of patches for each image in the batch.
This is used to split the embeddings, whose first two dimensions are
flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
"""
A list of aspect ratios corresponding to the number of tiles
in each dimension that each image in the batch corresponds to.
Shape:
`(batch_size, ratio)` where ratio is a pair `(ratio_h, ratio_w)`
"""
class Llama4VisionMLP(nn.Module):
def __init__(self,
input_size: int,
intermediate_size: int,
output_size: int,
bias: bool,
output_activation: bool,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.fc1 = ColumnParallelLinear(
input_size=input_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
input_size=intermediate_size,
output_size=output_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.activation_fn = nn.GELU()
self.output_activation = output_activation
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
if self.output_activation:
return self.activation_fn(hidden_states)
return hidden_states
class Llama4MultiModalProjector(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.linear_1 = ColumnParallelLinear(
input_size=config.vision_config.vision_output_dim,
output_size=config.text_config.hidden_size,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.linear_1",
)
def forward(self, image_features):
hidden_states, _ = self.linear_1(image_features)
return hidden_states
def pixel_shuffle(input_tensor, shuffle_ratio):
# input_tensor: [batch_size, num_patches, channels]
batch_size, num_patches, channels = input_tensor.shape
patch_size = int(math.sqrt(num_patches))
input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
batch_size, height, width, channels = input_tensor.size()
reshaped_tensor = input_tensor.view(batch_size, height,
int(width * shuffle_ratio),
int(channels / shuffle_ratio))
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
reshaped_tensor = reshaped_tensor.view(batch_size,
int(height * shuffle_ratio),
int(width * shuffle_ratio),
int(channels / (shuffle_ratio**2)))
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
output_tensor = reshaped_tensor.view(batch_size, -1,
reshaped_tensor.shape[-1])
return output_tensor
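
`pixel_shuffle` trades spatial resolution for channel depth: with a shuffle ratio of 0.5 (assumed here; the real value comes from `config.pixel_shuffle_ratio`), each 2x2 block of patch embeddings is folded into a single token with 4x the channels, so the patch count drops by a factor of 4. A shape-only check using the function defined above (sizes are illustrative):

import torch

x = torch.randn(1, 16, 8)          # 4x4 patch grid, 8 channels per patch
out = pixel_shuffle(x, shuffle_ratio=0.5)
# x:   (1, 16, 8)
# out: (1, 4, 32) -> a 2x2 grid where each token carries four patches' channels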
class Llama4VisionPixelShuffleMLP(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
self.inner_dim = int(config.projector_input_dim //
(self.pixel_shuffle_ratio**2))
self.output_dim = config.projector_output_dim
self.mlp = Llama4VisionMLP(
input_size=config.intermediate_size,
intermediate_size=config.projector_input_dim,
output_size=config.projector_output_dim,
bias=config.multi_modal_projector_bias,
output_activation=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
encoded_patches = pixel_shuffle(encoded_patches,
self.pixel_shuffle_ratio)
return self.mlp(encoded_patches)
class Llama4VisionAttention(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
):
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
assert self.num_heads % self.tp_size == 0
self.num_local_heads = self.num_heads // self.tp_size
self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**-0.5
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
self.scaling)
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=True,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
# number of image patches
max_position=(config.image_size // config.patch_size)**2,
base=config.rope_theta,
rope_scaling={"rope_type": "mllama4"},
is_neox_style=False,
dtype=torch.complex64, # important
)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
input_shape = hidden_states.shape[:-1]
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim)
k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim)
q, k = self.rotary_emb(q, k)
q = q.view(q.shape[0], q.shape[1], -1)
k = k.view(k.shape[0], k.shape[1], -1)
attn_output = self.attn(q, k, v)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output, _ = self.o_proj(attn_output)
return attn_output
class Llama4VisionEncoderLayer(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.intermediate_size = config.intermediate_size
self.self_attn = Llama4VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Llama4VisionMLP(input_size=config.hidden_size,
intermediate_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
output_activation=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(config.hidden_size)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
def forward(
self,
hidden_state: torch.Tensor,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state)
hidden_state = residual + hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
hidden_state = residual + hidden_state
outputs = (hidden_state, )
return outputs
class Llama4VisionEncoder(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Llama4VisionEncoderLayer(
config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
) for layer_idx in range(config.num_hidden_layers)
])
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
`(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation. This is useful if you
want more control over how to convert `input_ids` indices into
associated vectors than the model's internal embedding
lookup matrix.
"""
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states)
hidden_states = layer_outputs[0]
return hidden_states
class Llama4UnfoldConvolution(nn.Module):
def __init__(self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
kernel_size = config.patch_size
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
stride=config.patch_size)
self.linear = ColumnParallelLinear(config.num_channels *
kernel_size[0] * kernel_size[1],
config.hidden_size,
bias=False,
quant_config=quant_config,
gather_output=True,
prefix=f"{prefix}.linear")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states, _ = self.linear(hidden_states)
return hidden_states
class Llama4VisionModel(nn.Module):
def __init__(
self,
config: Llama4VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.image_size = config.image_size
self.patch_size = config.patch_size
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
self.num_patches = (self.image_size // self.patch_size)**2 + 1
self.scale = config.hidden_size**-0.5
self.patch_embedding = Llama4UnfoldConvolution(
config,
quant_config=quant_config,
prefix=f"{prefix}.patch_embedding")
self.class_embedding = nn.Parameter(self.scale *
torch.randn(self.hidden_size))
self.positional_embedding_vlm = nn.Parameter(
self.scale * torch.randn(self.num_patches, self.hidden_size))
# layer norms
self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5)
# encoders
self.model = Llama4VisionEncoder(config,
quant_config=quant_config,
prefix=f"{prefix}.model")
self.vision_adapter = Llama4VisionPixelShuffleMLP(
config, quant_config, prefix=f"{prefix}.vision_adapter")
def forward(
self,
images_flattened: torch.Tensor,
) -> torch.Tensor:
# Patch embedding
hidden_state = self.patch_embedding(images_flattened)
num_tiles, num_patches, hidden_dim = hidden_state.shape
# Add cls token
class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1,
hidden_state.shape[-1])
hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
num_patches += 1
# Position embeddings
hidden_state = hidden_state.reshape(
num_tiles,
1,
num_patches,
hidden_dim,
)
positional_embedding = self.positional_embedding_vlm.to(
dtype=hidden_state.dtype, device=hidden_state.device)
hidden_state = hidden_state + positional_embedding
hidden_state = self.layernorm_pre(hidden_state)
hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
# Apply encoder
hidden_state = self.model(hidden_state)
hidden_state = self.layernorm_post(hidden_state)
# Remove CLS token output
hidden_state = hidden_state[:, :-1, :]
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings
hidden_state = self.vision_adapter(hidden_state)
return hidden_state
class Mllama4ProcessingInfo(BaseProcessingInfo):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(ctx)
def get_hf_config(self) -> Llama4Config:
return self.ctx.get_hf_config(Llama4Config)
def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
return self.ctx.get_hf_processor(Llama4Processor,
use_fast=True,
**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 10}
@staticmethod
def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
image_size = vision_config.image_size
patch_size = vision_config.patch_size
assert image_size % patch_size == 0, (
f"chunk size {image_size} should be a multiple of "
f"patch_size {patch_size}")
ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
return (image_size // patch_size)**2 // ds_ratio
def get_max_num_tiles(self) -> int:
image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
# image_start + local tiles * (patches + 1 x separator) +
# 1 global tile * (image x 1 + patches) + image_end
token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2
return {"image": mm_max_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size
# Results in the max possible feature size (h:w = 16:1)
return ImageSize(height=self.get_max_num_tiles() * image_size,
width=image_size)
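
To make the token accounting in `get_patch_per_chunk` and `get_mm_max_tokens_per_item` concrete, here is the arithmetic with assumed configuration values (image_size = 336, patch_size = 14, pixel_shuffle_ratio = 0.5, max tiles = 16; these are illustrative, not read from a checkpoint):

image_size, patch_size = 336, 14      # assumed vision_config values
pixel_shuffle_ratio = 0.5
max_num_tiles = 16                    # assumed image_processor.max_patches

ds_ratio = int(round(1.0 / (pixel_shuffle_ratio ** 2)))          # 4
patch_per_chunk = (image_size // patch_size) ** 2 // ds_ratio    # 576 // 4 = 144
token_per_chunk = patch_per_chunk + 1                            # + 1 separator token
mm_max_tokens = (max_num_tiles + 1) * token_per_chunk + 2        # + image_start / image_end
# (16 + 1) * 145 + 2 = 2467 image-related tokens per image at most, under these values.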
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if mm_data is None:
return tokenizer(prompt, add_special_tokens=False) # exclude bos
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
processor = self.info.get_hf_processor(**mm_kwargs)
image_processor = processor.image_processor
vision_config = self.info.get_hf_config().vision_config
if processed_outputs.get("pixel_values") is not None:
assert "images" in mm_data, \
"images expected to be in mm_data when pixel_values is present"
images = mm_data["images"]
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images
}).get_items("image", ImageProcessorItems))
tile_size = vision_config.image_size
possible_resolutions = find_supported_resolutions(
max_num_chunks=self.info.get_max_num_tiles(),
patch_size=SizeDict(height=tile_size, width=tile_size),
)
best_fit_sizes = [
get_best_fit(
(image.size[1], image.size[0]),
torch.tensor(possible_resolutions),
resize_to_max_canvas=image_processor.resize_to_max_canvas)
for image in parsed_images
]
# TODO tile height/width do not necessarily need to match
aspect_ratios = [(image_size[0] // tile_size,
image_size[1] // tile_size)
for image_size in best_fit_sizes]
patches_per_image = [
1 if r_h * r_w == 1 else 1 + r_h * r_w
for (r_h, r_w) in aspect_ratios
]
# embed_is_patch should have one feature per image-related token:
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
# -> False
# <|patch|> -> True
# embed_is_patch has no entries corresponding to non-image-related
# tokens.
patch_id = tokenizer.get_vocab()[processor.img_patch_token]
num_patches_per_chunk = self.info.get_patch_per_chunk(
vision_config)
expanded_image_tokens_list = [
processor._prompt_split_image(aspect_ratio,
num_patches_per_chunk)
for aspect_ratio in aspect_ratios
]
expanded_image_token_ids = [
tokenizer.encode(image_tokens, add_special_tokens=False)
for image_tokens in expanded_image_tokens_list
]
embed_is_patch = [
torch.tensor(tokens) == patch_id
for tokens in expanded_image_token_ids
]
processed_outputs["aspect_ratios"] = aspect_ratios
processed_outputs["patches_per_image"] = torch.tensor(
patches_per_image)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
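
As a concrete illustration of the bookkeeping above: an image whose best-fit canvas is 1 tile tall by 2 tiles wide gets `aspect_ratios = (1, 2)` and `patches_per_image = 1 + 1 * 2 = 3` chunks (two local tiles plus one global tile), and `embed_is_patch` marks which of the expanded image tokens are `<|patch|>` tokens as opposed to structural tokens such as `<|image_start|>` or the tile separators. A toy version of the mask construction (the token ids are made up):

import torch

patch_id = 7  # hypothetical id of <|patch|>
expanded_image_token_ids = [3, 7, 7, 4, 7, 7, 5]  # start, patches, separator, patches, end
embed_is_patch = torch.tensor(expanded_image_token_ids) == patch_id
# tensor([False, True, True, False, True, True, False])
# Only the True positions receive vision embeddings when features are scattered.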
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image),
patches_per_image=MultiModalFieldConfig.batched("image"),
aspect_ratios=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> List[PromptUpdate]:
assert (
mm_items.get_count("image", strict=False) == 0
or "aspect_ratios" in out_mm_kwargs
), "Transformers expect to include aspect_ratios in out_mm_kwargs"
config = self.info.get_hf_config()
vision_config = config.vision_config
num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token
def get_replacement(item_idx: int):
aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
return hf_processor._prompt_split_image(
aspect_ratio=aspect_ratio,
num_patches_per_chunk=num_patches_per_chunk)
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement,
)
]
class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
(target_width,
target_height) = self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
image_token = self.info.get_hf_processor().fake_image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
@MULTIMODAL_REGISTRY.register_processor(
Mllama4MultiModalProcessor,
info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_model = Llama4VisionModel(config.vision_config,
None,
prefix=maybe_prefix(
prefix, "vision_model"))
self.multi_modal_projector = Llama4MultiModalProjector(
self.config,
None,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
self.language_model = _initialize_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
model_class=Llama4ForCausalLM,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
# num_images, 1, num_chunks, channel, image_size, image_size
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
# num_images x num_chunks, channel, image_size, image_size
# TODO: confirm handling for variable lengths
flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
aspect_ratios = kwargs.pop("aspect_ratios", None)
if not isinstance(aspect_ratios, (torch.Tensor, list)):
raise ValueError("Incorrect type of aspect_ratios. "
f"Got type: {type(aspect_ratios)}")
return Llama4ImagePatchInputs(
type="pixel_values",
flat_data=flat_pixel_values,
patches_per_image=patches_per_image,
embed_is_patch=embed_is_patch,
aspect_ratios=aspect_ratios,
)
def _process_image_input(
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
# num_images x [num_chunks, num_patches, hidden_dim]
image_features = self._process_image_input(image_input)
# num_images x [num_chunks x num_patches, hidden_dim]
image_features_flat = [img.flatten(0, 1) for img in image_features]
# num_images x [1, input_len] -> num_images x [input_len]
embed_is_patch_flat = [
is_patch.flatten(0, 1)
for is_patch in image_input["embed_is_patch"]
]
return scatter_patch_features(
image_features_flat,
embed_is_patch_flat,
)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
multimodal_embeddings = torch.cat(multimodal_embeddings)
mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, select_patch_features(mm_embeddings),
self.config.image_token_index)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated by the model runner;
# this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
return self.language_model(input_ids, positions, intermediate_tensors,
inputs_embeds)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def separate_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
prefix: str,
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
str, torch.Tensor]]]:
weights1, weights2 = tee(weights, 2)
def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
for name, data in weights1:
if name.startswith(prefix):
yield (name, data)
def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
for name, data in weights2:
if not name.startswith(prefix):
yield (name, data)
return get_prefix_weights(), get_other_weights()
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
params_dict = dict(self.named_parameters())
updated_params: Set[str] = set()
# language_model is a Llama4ForCausalLM instance; we load its weights
# with llama4's load_weights routine.
language_model_weights, other_weights = self.separate_weights(
weights, prefix="language_model.model.")
loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights(
language_model_weights)
assert loaded_language_model_params is not None
updated_params.update(loaded_language_model_params)
for name, loaded_weight in other_weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
updated_params.add(name)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params
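
`separate_weights` relies on `itertools.tee` so a single pass over the checkpoint iterator can be split into a language-model stream and a vision/projector stream without materializing everything in memory. A small standalone sketch of the same pattern (names are illustrative):

from itertools import tee

weights = iter([
    ("language_model.model.layers.0.self_attn.qkv_proj.weight", "..."),
    ("vision_model.patch_embedding.linear.weight", "..."),
])
prefix = "language_model.model."

w1, w2 = tee(weights, 2)
lm_weights = ((n, t) for n, t in w1 if n.startswith(prefix))
other_weights = ((n, t) for n, t in w2 if not n.startswith(prefix))
# Each stream is consumed independently; tee buffers only what one side has
# read ahead of the other.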

vllm/model_executor/models/registry.py

@@ -196,6 +196,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}

vllm/model_executor/models/telechat2.py

@@ -19,9 +19,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Set, Tuple, Type
from typing import Iterable, Set, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -124,7 +125,7 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
layer_type: type[nn.Module] = LlamaDecoderLayer):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
def load_weights(self, weights: Iterable[Tuple[str,

vllm/model_executor/models/teleflm.py

@@ -22,9 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Type
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -39,7 +38,7 @@ class TeleFLMModel(LlamaModel):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer,
layer_type: type[nn.Module] = LlamaDecoderLayer,
):
super().__init__(vllm_config=vllm_config,
prefix=prefix,