Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit d6953beb91 (parent 17edd8a807)
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00, committed by GitHub

1508 changed files with 115244 additions and 94146 deletions

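The commit swaps vLLM's formatting toolchain: yapf (code formatting) and isort (import sorting) are dropped in favor of ruff, which handles both. A minimal sketch of the kind of pyproject.toml change such a conversion involves (illustrative only: the option names below are real ruff settings, but the exact values vLLM adopted are not shown in this excerpt):

    [tool.ruff]
    line-length = 88          # assumed value, matching the black-compatible default

    [tool.ruff.lint]
    extend-select = ["I"]     # the "I" rules replace isort's import sorting

    [tool.ruff.format]        # ruff's formatter replaces yapf
    docstring-code-format = true

With a configuration like this, `ruff format` and `ruff check --fix` replace the separate `yapf` and `isort` invocations, which is why the hunks below are almost entirely mechanical: re-wrapped call sites, trailing commas, and exploded import blocks.
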
@@ -27,30 +27,39 @@ from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
 from transformers.image_utils import SizeDict
 from transformers.models.llama4 import Llama4Processor
 from transformers.models.llama4.image_processing_llama4_fast import (
-    find_supported_resolutions, get_best_fit)
+    find_supported_resolutions,
+    get_best_fit,
+)
 from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.utils import initialize_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
-                                   MultiModalDataItems)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo,
-                                        InputProcessingContext,
-                                        PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    InputProcessingContext,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -72,9 +81,10 @@ class Llama4ImagePatchInputs(TensorSchema):
     type: Literal["pixel_values"] = "pixel_values"
-    flat_data: Annotated[torch.Tensor,
-                         TensorShape("total_num_chunks", "num_channels",
-                                     "image_size", "image_size")]
+    flat_data: Annotated[
+        torch.Tensor,
+        TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
+    ]
     patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
     """
@@ -93,7 +103,6 @@ class Llama4ImagePatchInputs(TensorSchema):
 class Llama4VisionMLP(nn.Module):
-
     def __init__(
         self,
         input_size: int,
@@ -135,7 +144,6 @@ class Llama4VisionMLP(nn.Module):
 class Llama4MultiModalProjector(nn.Module):
-
     def __init__(
         self,
         config,
@@ -165,9 +173,9 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
     input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
     batch_size, height, width, channels = input_tensor.size()
-    reshaped_tensor = input_tensor.view(batch_size, height,
-                                        int(width * shuffle_ratio),
-                                        int(channels / shuffle_ratio))
+    reshaped_tensor = input_tensor.view(
+        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+    )
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
     reshaped_tensor = reshaped_tensor.view(
@@ -178,13 +186,11 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
     )
     reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
-    output_tensor = reshaped_tensor.view(batch_size, -1,
-                                         reshaped_tensor.shape[-1])
+    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
     return output_tensor


 class Llama4VisionPixelShuffleMLP(nn.Module):
-
     def __init__(
         self,
         config,
@@ -194,8 +200,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
     ):
         super().__init__()
         self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
-        self.inner_dim = int(config.projector_input_dim //
-                             (self.pixel_shuffle_ratio**2))
+        self.inner_dim = int(
+            config.projector_input_dim // (self.pixel_shuffle_ratio**2)
+        )
         self.output_dim = config.projector_output_dim
         self.mlp = Llama4VisionMLP(
             input_size=config.intermediate_size,
@@ -209,13 +216,11 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
         )

     def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
-        encoded_patches = pixel_shuffle(encoded_patches,
-                                        self.pixel_shuffle_ratio)
+        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
         return self.mlp(encoded_patches)


 class Llama4VisionAttention(nn.Module):
-
     def __init__(
         self,
         config: Llama4VisionConfig,
@@ -225,8 +230,9 @@ class Llama4VisionAttention(nn.Module):
     ):
         super().__init__()
         self.config = config
-        self.tp_size = (1 if use_data_parallel else
-                        get_tensor_model_parallel_world_size())
+        self.tp_size = (
+            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        )
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = config.hidden_size // self.num_heads
@@ -237,8 +243,9 @@ class Llama4VisionAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.scaling = self.head_dim**-0.5

-        self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
-                                       self.scaling)
+        self.attn = MultiHeadAttention(
+            self.num_local_heads, self.head_dim, self.scaling
+        )
         if use_data_parallel:
             self.qkv_proj = ReplicatedLinear(
@@ -277,7 +284,7 @@ class Llama4VisionAttention(nn.Module):
             head_size=self.head_dim,
             rotary_dim=config.hidden_size // config.num_attention_heads // 2,
             # number of image patches
-            max_position=(config.image_size // config.patch_size)**2,
+            max_position=(config.image_size // config.patch_size) ** 2,
             base=config.rope_theta,
             rope_scaling={"rope_type": "mllama4"},
             is_neox_style=False,
@@ -308,7 +315,6 @@ class Llama4VisionAttention(nn.Module):
 class Llama4VisionEncoderLayer(nn.Module):
-
     def __init__(
         self,
         config: Llama4VisionConfig,
@@ -357,12 +363,11 @@ class Llama4VisionEncoderLayer(nn.Module):
         hidden_state = self.mlp(hidden_state)
         hidden_state = residual + hidden_state

-        outputs = (hidden_state, )
+        outputs = (hidden_state,)
         return outputs


 class Llama4VisionEncoder(nn.Module):
-
     def __init__(
         self,
         config: Llama4VisionConfig,
@@ -372,14 +377,17 @@ class Llama4VisionEncoder(nn.Module):
     ):
         super().__init__()
         self.config = config
-        self.layers = nn.ModuleList([
-            Llama4VisionEncoderLayer(
-                config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.layers.{layer_idx}",
-                use_data_parallel=use_data_parallel,
-            ) for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Llama4VisionEncoderLayer(
+                    config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                    use_data_parallel=use_data_parallel,
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )

     def forward(
         self,
@@ -387,9 +395,9 @@ class Llama4VisionEncoder(nn.Module):
     ) -> torch.Tensor:
         r"""
         Args:
-            hidden_states: Input tensor of shape 
+            hidden_states: Input tensor of shape
                 (batch_size, sequence_length, hidden_size).
-                Hidden states from the model embeddings, representing 
+                Hidden states from the model embeddings, representing
                 the input tokens.
                 associated vectors than the model's internal embedding
                 lookup matrix.
@@ -403,7 +411,6 @@ class Llama4VisionEncoder(nn.Module):
 class Llama4UnfoldConvolution(nn.Module):
-
     def __init__(
         self,
         config: Llama4VisionConfig,
@@ -415,8 +422,7 @@ class Llama4UnfoldConvolution(nn.Module):
         kernel_size = config.patch_size
         if isinstance(kernel_size, int):
             kernel_size = (kernel_size, kernel_size)
-        self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
-                                      stride=config.patch_size)
+        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
         self.linear = ColumnParallelLinear(
             input_size=config.num_channels * kernel_size[0] * kernel_size[1],
             output_size=config.hidden_size,
@@ -435,7 +441,6 @@ class Llama4UnfoldConvolution(nn.Module):
 class Llama4VisionModel(nn.Module):
-
     def __init__(
         self,
         config: Llama4VisionConfig,
@@ -450,7 +455,7 @@ class Llama4VisionModel(nn.Module):
         self.hidden_size = config.hidden_size
         self.num_channels = config.num_channels

-        self.num_patches = (self.image_size // self.patch_size)**2 + 1
+        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
         self.scale = config.hidden_size**-0.5

         self.patch_embedding = Llama4UnfoldConvolution(
@@ -460,10 +465,10 @@ class Llama4VisionModel(nn.Module):
             use_data_parallel=use_data_parallel,
         )

-        self.class_embedding = nn.Parameter(self.scale *
-                                            torch.randn(self.hidden_size))
+        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
         self.positional_embedding_vlm = nn.Parameter(
-            self.scale * torch.randn(self.num_patches, self.hidden_size))
+            self.scale * torch.randn(self.num_patches, self.hidden_size)
+        )

         # layer norms
         self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
@@ -492,8 +497,9 @@ class Llama4VisionModel(nn.Module):
         num_tiles, num_patches, hidden_dim = hidden_state.shape

         # Add cls token
-        class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1,
-                                                      hidden_state.shape[-1])
+        class_embedding = self.class_embedding.expand(
+            hidden_state.shape[0], 1, hidden_state.shape[-1]
+        )
         hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
         num_patches += 1
@@ -505,7 +511,8 @@ class Llama4VisionModel(nn.Module):
             hidden_dim,
         )
         positional_embedding = self.positional_embedding_vlm.to(
-            dtype=hidden_state.dtype, device=hidden_state.device)
+            dtype=hidden_state.dtype, device=hidden_state.device
+        )
         hidden_state = hidden_state + positional_embedding
         hidden_state = self.layernorm_pre(hidden_state)
         hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
@@ -524,7 +531,6 @@ class Llama4VisionModel(nn.Module):
 class Mllama4ProcessingInfo(BaseProcessingInfo):
-
     def __init__(self, ctx: InputProcessingContext) -> None:
         super().__init__(ctx)
@@ -532,9 +538,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
         return self.ctx.get_hf_config(Llama4Config)

     def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
-        return self.ctx.get_hf_processor(Llama4Processor,
-                                         use_fast=kwargs.pop("use_fast", True),
-                                         **kwargs)
+        return self.ctx.get_hf_processor(
+            Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
+        )

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         # Although vLLM can support more images from an infra capability
@@ -546,13 +552,13 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
         image_size = vision_config.image_size
         patch_size = vision_config.patch_size
-        assert (
-            image_size %
-            patch_size == 0), f"chunk size {image_size} should be multiple of "
+        assert image_size % patch_size == 0, (
+            f"chunk size {image_size} should be multiple of "
+        )
         f"patch_size {patch_size}"
         ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size)**2 // ds_ratio
+        return (image_size // patch_size) ** 2 // ds_ratio

     def get_max_num_tiles(self) -> int:
         image_processor = self.get_hf_processor().image_processor
@@ -562,13 +568,10 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
         vision_config = self.get_hf_config().vision_config
         image_size = vision_config.image_size
         # Result in the max possible feature size (h:w = 16:1)
-        return ImageSize(height=self.get_max_num_tiles() * image_size,
-                         width=image_size)
+        return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size)


-class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
-                                 ):
-
+class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]):
     def _call_hf_processor(
         self,
         prompt: str,
@@ -592,15 +595,16 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
         vision_config = self.info.get_hf_config().vision_config

         if processed_outputs.get("pixel_values") is not None:
-            assert (
-                "images" in mm_data
-            ), "images expected to be in mm_data when pixel_values is present"
+            assert "images" in mm_data, (
+                "images expected to be in mm_data when pixel_values is present"
+            )

             images = mm_data["images"]
-            parsed_images = (self._get_data_parser().parse_mm_data({
-                "image":
-                images
-            }).get_items("image", ImageProcessorItems))
+            parsed_images = (
+                self._get_data_parser()
+                .parse_mm_data({"image": images})
+                .get_items("image", ImageProcessorItems)
+            )

             tile_size = vision_config.image_size
             possible_resolutions = find_supported_resolutions(
@@ -612,20 +616,20 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
                     (image.size[1], image.size[0]),
                     torch.tensor(possible_resolutions),
                     resize_to_max_canvas=image_processor.resize_to_max_canvas,
-                ) for image in parsed_images
+                )
+                for image in parsed_images
             ]
             # TODO tile height/width do not necessarily need to match
-            aspect_ratios = [(image_size[0] // tile_size,
-                              image_size[1] // tile_size)
-                             for image_size in best_fit_sizes]
+            aspect_ratios = [
+                (image_size[0] // tile_size, image_size[1] // tile_size)
+                for image_size in best_fit_sizes
+            ]
             patches_per_image = [
-                1 if r_h * r_w == 1 else 1 + r_h * r_w
-                for (r_h, r_w) in aspect_ratios
+                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
             ]

             processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
-            processed_outputs["patches_per_image"] = torch.tensor(
-                patches_per_image)
+            processed_outputs["patches_per_image"] = torch.tensor(patches_per_image)

         return processed_outputs
@@ -637,7 +641,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
         patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))

         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", patches_per_image),
+                "image", patches_per_image
+            ),
             patches_per_image=MultiModalFieldConfig.batched("image"),
             aspect_ratios=MultiModalFieldConfig.batched("image"),
         )
@@ -677,7 +682,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
 class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
-
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
@@ -694,17 +698,17 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)

-        (target_width,
-         target_height) = self.info.get_image_size_with_most_features()
+        (target_width, target_height) = self.info.get_image_size_with_most_features()

         image_overrides = mm_options.get("image") if mm_options else None

         return {
-            "image":
-            self._get_dummy_images(width=target_width,
-                                   height=target_height,
-                                   num_images=num_images,
-                                   overrides=image_overrides)
+            "image": self._get_dummy_images(
+                width=target_width,
+                height=target_height,
+                num_images=num_images,
+                overrides=image_overrides,
+            )
         }
@@ -713,8 +717,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
     info=Mllama4ProcessingInfo,
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                     SupportsPP):
+class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -747,24 +750,26 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
                 use_data_parallel=self.use_data_parallel,
             )
             self.multi_modal_projector = Llama4MultiModalProjector(
-                self.config,
-                None,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+                self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")
+            )
         else:
             self.vision_model = None
             self.multi_modal_projector = None
         self.language_model = initialize_model(
-            vllm_config=vllm_config.with_hf_config(config.text_config,
-                                                   ["LlamaForCausalLM"]),
+            vllm_config=vllm_config.with_hf_config(
+                config.text_config, ["LlamaForCausalLM"]
+            ),
             prefix=maybe_prefix(prefix, "language_model"),
             model_class=Llama4ForCausalLM,
         )
         self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors)
+            self.language_model.make_empty_intermediate_tensors
+        )

     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
+        self, **kwargs: object
+    ) -> Optional[Llama4ImagePatchInputs]:
         # num_images, 1, num_chunks, channel, image_size, image_size
         pixel_values = kwargs.pop("pixel_values", None)
         if pixel_values is None:
@@ -786,8 +791,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         )

     def _process_image_input(
-            self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
+        self, image_input: Llama4ImagePatchInputs
+    ) -> MultiModalEmbeddings:
         assert self.vision_model and self.multi_modal_projector
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
@@ -795,12 +800,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         # shard image input
         if self.use_data_parallel:
             vision_embeddings_flat = run_dp_sharded_vision_model(
-                flat_data, self.vision_model)
+                flat_data, self.vision_model
+            )
         else:
             vision_embeddings_flat = self.vision_model(flat_data)

-        vision_embeddings_flat = self.multi_modal_projector(
-            vision_embeddings_flat)
+        vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat)

         return [
             img.flatten(0, 1)
@@ -828,8 +833,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None

-        return self.language_model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
+        return self.language_model(
+            input_ids, positions, intermediate_tensors, inputs_embeds
+        )

     def compute_logits(
         self,
@@ -841,8 +847,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self,
         weights: Iterable[tuple[str, torch.Tensor]],
         prefix: str,
-    ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[
-            str, torch.Tensor]]]:
+    ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
         weights1, weights2 = tee(weights, 2)

         def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
@@ -884,31 +889,33 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
         """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
         format."""
-        if name.startswith("model.") or name.startswith(
-                "language_model.model."):
-            renamed = name.replace("model.", "language_model.model.",
-                                   1) if name.startswith("model.") else name
+        if name.startswith("model.") or name.startswith("language_model.model."):
+            renamed = (
+                name.replace("model.", "language_model.model.", 1)
+                if name.startswith("model.")
+                else name
+            )

             # Handle expert scale parameters with flat naming
-            if "feed_forward.experts." in name and ("_input_scale" in name or
-                                                    "_weight_scale" in name):
+            if "feed_forward.experts." in name and (
+                "_input_scale" in name or "_weight_scale" in name
+            ):
                 # Map checkpoint naming to vLLM's expected naming
                 if "down_proj_input_scale" in renamed:
-                    return renamed.replace("down_proj_input_scale",
-                                           "w2_input_scale")
+                    return renamed.replace("down_proj_input_scale", "w2_input_scale")
                 elif "down_proj_weight_scale" in renamed:
-                    return renamed.replace("down_proj_weight_scale",
-                                           "w2_weight_scale")
+                    return renamed.replace("down_proj_weight_scale", "w2_weight_scale")
                 elif "gate_up_proj_input_scale" in renamed:
-                    return renamed.replace("gate_up_proj_input_scale",
-                                           "w13_input_scale")
+                    return renamed.replace(
+                        "gate_up_proj_input_scale", "w13_input_scale"
+                    )
                 elif "gate_up_proj_weight_scale" in renamed:
-                    return renamed.replace("gate_up_proj_weight_scale",
-                                           "w13_weight_scale")
+                    return renamed.replace(
+                        "gate_up_proj_weight_scale", "w13_weight_scale"
+                    )
                 return renamed

         # Handle attention scale parameters
-        elif "self_attn." in name and (".k_scale" in name
-                                       or ".v_scale" in name):
+        elif "self_attn." in name and (".k_scale" in name or ".v_scale" in name):
             if ".k_proj.k_scale" in renamed:
                 return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
             elif ".v_proj.v_scale" in renamed:
@@ -919,8 +926,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return renamed

         elif name.startswith("lm_head.weight"):
-            return name.replace("lm_head.weight",
-                                "language_model.lm_head.weight")
+            return name.replace("lm_head.weight", "language_model.lm_head.weight")

         return name
@@ -943,7 +949,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         return language_model_weights, other_weights

     def _handle_expert_scale_broadcasting(
-            self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
+        self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
     ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
         """Handle expert scale parameters that need broadcasting.
@@ -956,12 +962,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         for name, weight in weights:
             # Check if this is an expert scale parameter that needs broadcasting
-            if ("feed_forward.experts." in name and "scale" in name
-                    and ".shared_expert" not in name):
+            if (
+                "feed_forward.experts." in name
+                and "scale" in name
+                and ".shared_expert" not in name
+            ):
                 if name in params_dict:
                     param = params_dict[name]
-                    if (hasattr(param, 'data') and param.data.numel() > 1
-                            and weight.numel() == 1):
+                    if (
+                        hasattr(param, "data")
+                        and param.data.numel() > 1
+                        and weight.numel() == 1
+                    ):
                         # Broadcast single value to all experts
                         param.data.fill_(weight.item())
                         updated_params.add(name)
@@ -973,10 +985,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         return regular_weights, expert_scale_weights, updated_params

-    def _load_other_weights(self, other_weights: Iterable[tuple[str,
-                                                                torch.Tensor]],
-                            params_dict: dict,
-                            stacked_params_mapping: list) -> set[str]:
+    def _load_other_weights(
+        self,
+        other_weights: Iterable[tuple[str, torch.Tensor]],
+        params_dict: dict,
+        stacked_params_mapping: list,
+    ) -> set[str]:
         """Load non-language-model weights with stacking support."""
         updated_params = set()
@@ -997,16 +1011,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
             else:
                 # Use regular weight loading
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             updated_params.add(name)
         return updated_params

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -1023,8 +1034,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         updated_params: set[str] = set()

         # Separate and rename weights
-        language_model_weights, other_weights = (
-            self._separate_and_rename_weights(weights))
+        language_model_weights, other_weights = self._separate_and_rename_weights(
+            weights
+        )

         # Skip loading vision model and projector if they're not initialized.
         if self.vision_model is None and self.multi_modal_projector is None:
@@ -1032,8 +1044,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         # Handle expert scale parameters
         regular_weights, expert_scale_weights, updated_params_from_experts = (
-            self._handle_expert_scale_broadcasting(language_model_weights,
-                                                   params_dict))
+            self._handle_expert_scale_broadcasting(language_model_weights, params_dict)
+        )
         updated_params.update(updated_params_from_experts)

         loader = AutoWeightsLoader(self)
@@ -1042,13 +1054,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         updated_params.update(loaded_language_model_params)

         if expert_scale_weights:
-            loaded_expert_scale_params = loader.load_weights(
-                expert_scale_weights)
+            loaded_expert_scale_params = loader.load_weights(expert_scale_weights)
             if loaded_expert_scale_params:
                 updated_params.update(loaded_expert_scale_params)

         updated_params.update(
-            self._load_other_weights(other_weights, params_dict,
-                                     stacked_params_mapping))
+            self._load_other_weights(other_weights, params_dict, stacked_params_mapping)
+        )
         return updated_params