Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -27,30 +27,39 @@ from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
|
||||
from transformers.image_utils import SizeDict
|
||||
from transformers.models.llama4 import Llama4Processor
|
||||
from transformers.models.llama4.image_processing_llama4_fast import (
|
||||
find_supported_resolutions, get_best_fit)
|
||||
find_supported_resolutions,
|
||||
get_best_fit,
|
||||
)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.utils import initialize_model
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@@ -72,9 +81,10 @@ class Llama4ImagePatchInputs(TensorSchema):
|
||||
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
|
||||
flat_data: Annotated[torch.Tensor,
|
||||
TensorShape("total_num_chunks", "num_channels",
|
||||
"image_size", "image_size")]
|
||||
flat_data: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
|
||||
]
|
||||
|
||||
patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")]
|
||||
"""
|
||||
@@ -93,7 +103,6 @@ class Llama4ImagePatchInputs(TensorSchema):
|
||||
|
||||
|
||||
class Llama4VisionMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
@@ -135,7 +144,6 @@ class Llama4VisionMLP(nn.Module):
|
||||
|
||||
|
||||
class Llama4MultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
@@ -165,9 +173,9 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
|
||||
input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)
|
||||
batch_size, height, width, channels = input_tensor.size()
|
||||
|
||||
reshaped_tensor = input_tensor.view(batch_size, height,
|
||||
int(width * shuffle_ratio),
|
||||
int(channels / shuffle_ratio))
|
||||
reshaped_tensor = input_tensor.view(
|
||||
batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
|
||||
)
|
||||
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
reshaped_tensor = reshaped_tensor.view(
|
||||
@@ -178,13 +186,11 @@ def pixel_shuffle(input_tensor, shuffle_ratio):
|
||||
)
|
||||
reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
output_tensor = reshaped_tensor.view(batch_size, -1,
|
||||
reshaped_tensor.shape[-1])
|
||||
output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])
|
||||
return output_tensor
|
||||
|
||||
|
||||
class Llama4VisionPixelShuffleMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
@@ -194,8 +200,9 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
|
||||
self.inner_dim = int(config.projector_input_dim //
|
||||
(self.pixel_shuffle_ratio**2))
|
||||
self.inner_dim = int(
|
||||
config.projector_input_dim // (self.pixel_shuffle_ratio**2)
|
||||
)
|
||||
self.output_dim = config.projector_output_dim
|
||||
self.mlp = Llama4VisionMLP(
|
||||
input_size=config.intermediate_size,
|
||||
@@ -209,13 +216,11 @@ class Llama4VisionPixelShuffleMLP(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
|
||||
encoded_patches = pixel_shuffle(encoded_patches,
|
||||
self.pixel_shuffle_ratio)
|
||||
encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
|
||||
return self.mlp(encoded_patches)
|
||||
|
||||
|
||||
class Llama4VisionAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
@@ -225,8 +230,9 @@ class Llama4VisionAttention(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.tp_size = (1 if use_data_parallel else
|
||||
get_tensor_model_parallel_world_size())
|
||||
self.tp_size = (
|
||||
1 if use_data_parallel else get_tensor_model_parallel_world_size()
|
||||
)
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
@@ -237,8 +243,9 @@ class Llama4VisionAttention(nn.Module):
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
|
||||
self.scaling)
|
||||
self.attn = MultiHeadAttention(
|
||||
self.num_local_heads, self.head_dim, self.scaling
|
||||
)
|
||||
|
||||
if use_data_parallel:
|
||||
self.qkv_proj = ReplicatedLinear(
|
||||
@@ -277,7 +284,7 @@ class Llama4VisionAttention(nn.Module):
|
||||
head_size=self.head_dim,
|
||||
rotary_dim=config.hidden_size // config.num_attention_heads // 2,
|
||||
# number of image patches
|
||||
max_position=(config.image_size // config.patch_size)**2,
|
||||
max_position=(config.image_size // config.patch_size) ** 2,
|
||||
base=config.rope_theta,
|
||||
rope_scaling={"rope_type": "mllama4"},
|
||||
is_neox_style=False,
|
||||
@@ -308,7 +315,6 @@ class Llama4VisionAttention(nn.Module):
|
||||
|
||||
|
||||
class Llama4VisionEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
@@ -357,12 +363,11 @@ class Llama4VisionEncoderLayer(nn.Module):
|
||||
hidden_state = self.mlp(hidden_state)
|
||||
hidden_state = residual + hidden_state
|
||||
|
||||
outputs = (hidden_state, )
|
||||
outputs = (hidden_state,)
|
||||
return outputs
|
||||
|
||||
|
||||
class Llama4VisionEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
@@ -372,14 +377,17 @@ class Llama4VisionEncoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([
|
||||
Llama4VisionEncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
) for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Llama4VisionEncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -387,9 +395,9 @@ class Llama4VisionEncoder(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
hidden_states: Input tensor of shape
|
||||
hidden_states: Input tensor of shape
|
||||
(batch_size, sequence_length, hidden_size).
|
||||
Hidden states from the model embeddings, representing
|
||||
Hidden states from the model embeddings, representing
|
||||
the input tokens.
|
||||
associated vectors than the model's internal embedding
|
||||
lookup matrix.
|
||||
@@ -403,7 +411,6 @@ class Llama4VisionEncoder(nn.Module):
|
||||
|
||||
|
||||
class Llama4UnfoldConvolution(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
@@ -415,8 +422,7 @@ class Llama4UnfoldConvolution(nn.Module):
|
||||
kernel_size = config.patch_size
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
self.unfold = torch.nn.Unfold(kernel_size=kernel_size,
|
||||
stride=config.patch_size)
|
||||
self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
|
||||
self.linear = ColumnParallelLinear(
|
||||
input_size=config.num_channels * kernel_size[0] * kernel_size[1],
|
||||
output_size=config.hidden_size,
|
||||
@@ -435,7 +441,6 @@ class Llama4UnfoldConvolution(nn.Module):
|
||||
|
||||
|
||||
class Llama4VisionModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4VisionConfig,
|
||||
@@ -450,7 +455,7 @@ class Llama4VisionModel(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_channels = config.num_channels
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size)**2 + 1
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
|
||||
self.scale = config.hidden_size**-0.5
|
||||
|
||||
self.patch_embedding = Llama4UnfoldConvolution(
|
||||
@@ -460,10 +465,10 @@ class Llama4VisionModel(nn.Module):
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
self.class_embedding = nn.Parameter(self.scale *
|
||||
torch.randn(self.hidden_size))
|
||||
self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
|
||||
self.positional_embedding_vlm = nn.Parameter(
|
||||
self.scale * torch.randn(self.num_patches, self.hidden_size))
|
||||
self.scale * torch.randn(self.num_patches, self.hidden_size)
|
||||
)
|
||||
|
||||
# layer norms
|
||||
self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
@@ -492,8 +497,9 @@ class Llama4VisionModel(nn.Module):
|
||||
num_tiles, num_patches, hidden_dim = hidden_state.shape
|
||||
|
||||
# Add cls token
|
||||
class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1,
|
||||
hidden_state.shape[-1])
|
||||
class_embedding = self.class_embedding.expand(
|
||||
hidden_state.shape[0], 1, hidden_state.shape[-1]
|
||||
)
|
||||
hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
|
||||
num_patches += 1
|
||||
|
||||
@@ -505,7 +511,8 @@ class Llama4VisionModel(nn.Module):
|
||||
hidden_dim,
|
||||
)
|
||||
positional_embedding = self.positional_embedding_vlm.to(
|
||||
dtype=hidden_state.dtype, device=hidden_state.device)
|
||||
dtype=hidden_state.dtype, device=hidden_state.device
|
||||
)
|
||||
hidden_state = hidden_state + positional_embedding
|
||||
hidden_state = self.layernorm_pre(hidden_state)
|
||||
hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
|
||||
@@ -524,7 +531,6 @@ class Llama4VisionModel(nn.Module):
|
||||
|
||||
|
||||
class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def __init__(self, ctx: InputProcessingContext) -> None:
|
||||
super().__init__(ctx)
|
||||
|
||||
@@ -532,9 +538,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
return self.ctx.get_hf_config(Llama4Config)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
|
||||
return self.ctx.get_hf_processor(Llama4Processor,
|
||||
use_fast=kwargs.pop("use_fast", True),
|
||||
**kwargs)
|
||||
return self.ctx.get_hf_processor(
|
||||
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
# Although vLLM can support more images from an infra capability
|
||||
@@ -546,13 +552,13 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
image_size = vision_config.image_size
|
||||
patch_size = vision_config.patch_size
|
||||
|
||||
assert (
|
||||
image_size %
|
||||
patch_size == 0), f"chunk size {image_size} should be multiple of "
|
||||
assert image_size % patch_size == 0, (
|
||||
f"chunk size {image_size} should be multiple of "
|
||||
)
|
||||
f"patch_size {patch_size}"
|
||||
|
||||
ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2)))
|
||||
return (image_size // patch_size)**2 // ds_ratio
|
||||
return (image_size // patch_size) ** 2 // ds_ratio
|
||||
|
||||
def get_max_num_tiles(self) -> int:
|
||||
image_processor = self.get_hf_processor().image_processor
|
||||
@@ -562,13 +568,10 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
vision_config = self.get_hf_config().vision_config
|
||||
image_size = vision_config.image_size
|
||||
# Result in the max possible feature size (h:w = 16:1)
|
||||
return ImageSize(height=self.get_max_num_tiles() * image_size,
|
||||
width=image_size)
|
||||
return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size)
|
||||
|
||||
|
||||
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
):
|
||||
|
||||
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -592,15 +595,16 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
vision_config = self.info.get_hf_config().vision_config
|
||||
|
||||
if processed_outputs.get("pixel_values") is not None:
|
||||
assert (
|
||||
"images" in mm_data
|
||||
), "images expected to be in mm_data when pixel_values is present"
|
||||
assert "images" in mm_data, (
|
||||
"images expected to be in mm_data when pixel_values is present"
|
||||
)
|
||||
|
||||
images = mm_data["images"]
|
||||
parsed_images = (self._get_data_parser().parse_mm_data({
|
||||
"image":
|
||||
images
|
||||
}).get_items("image", ImageProcessorItems))
|
||||
parsed_images = (
|
||||
self._get_data_parser()
|
||||
.parse_mm_data({"image": images})
|
||||
.get_items("image", ImageProcessorItems)
|
||||
)
|
||||
|
||||
tile_size = vision_config.image_size
|
||||
possible_resolutions = find_supported_resolutions(
|
||||
@@ -612,20 +616,20 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
(image.size[1], image.size[0]),
|
||||
torch.tensor(possible_resolutions),
|
||||
resize_to_max_canvas=image_processor.resize_to_max_canvas,
|
||||
) for image in parsed_images
|
||||
)
|
||||
for image in parsed_images
|
||||
]
|
||||
# TODO tile height/width do not necessarily need to match
|
||||
aspect_ratios = [(image_size[0] // tile_size,
|
||||
image_size[1] // tile_size)
|
||||
for image_size in best_fit_sizes]
|
||||
aspect_ratios = [
|
||||
(image_size[0] // tile_size, image_size[1] // tile_size)
|
||||
for image_size in best_fit_sizes
|
||||
]
|
||||
patches_per_image = [
|
||||
1 if r_h * r_w == 1 else 1 + r_h * r_w
|
||||
for (r_h, r_w) in aspect_ratios
|
||||
1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
|
||||
]
|
||||
|
||||
processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios)
|
||||
processed_outputs["patches_per_image"] = torch.tensor(
|
||||
patches_per_image)
|
||||
processed_outputs["patches_per_image"] = torch.tensor(patches_per_image)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@@ -637,7 +641,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", patches_per_image),
|
||||
"image", patches_per_image
|
||||
),
|
||||
patches_per_image=MultiModalFieldConfig.batched("image"),
|
||||
aspect_ratios=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
@@ -677,7 +682,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
|
||||
|
||||
class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
@@ -694,17 +698,17 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
(target_width,
|
||||
target_height) = self.info.get_image_size_with_most_features()
|
||||
(target_width, target_height) = self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
"image": self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -713,8 +717,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
|
||||
info=Mllama4ProcessingInfo,
|
||||
dummy_inputs=Mllama4DummyInputsBuilder,
|
||||
)
|
||||
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
@@ -747,24 +750,26 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
self.multi_modal_projector = Llama4MultiModalProjector(
|
||||
self.config,
|
||||
None,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"))
|
||||
self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")
|
||||
)
|
||||
else:
|
||||
self.vision_model = None
|
||||
self.multi_modal_projector = None
|
||||
self.language_model = initialize_model(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config,
|
||||
["LlamaForCausalLM"]),
|
||||
vllm_config=vllm_config.with_hf_config(
|
||||
config.text_config, ["LlamaForCausalLM"]
|
||||
),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
model_class=Llama4ForCausalLM,
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[Llama4ImagePatchInputs]:
|
||||
# num_images, 1, num_chunks, channel, image_size, image_size
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if pixel_values is None:
|
||||
@@ -786,8 +791,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
|
||||
|
||||
self, image_input: Llama4ImagePatchInputs
|
||||
) -> MultiModalEmbeddings:
|
||||
assert self.vision_model and self.multi_modal_projector
|
||||
flat_data = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"].tolist()
|
||||
@@ -795,12 +800,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# shard image input
|
||||
if self.use_data_parallel:
|
||||
vision_embeddings_flat = run_dp_sharded_vision_model(
|
||||
flat_data, self.vision_model)
|
||||
flat_data, self.vision_model
|
||||
)
|
||||
else:
|
||||
vision_embeddings_flat = self.vision_model(flat_data)
|
||||
|
||||
vision_embeddings_flat = self.multi_modal_projector(
|
||||
vision_embeddings_flat)
|
||||
vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat)
|
||||
|
||||
return [
|
||||
img.flatten(0, 1)
|
||||
@@ -828,8 +833,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
return self.language_model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return self.language_model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@@ -841,8 +847,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
prefix: str,
|
||||
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[
|
||||
str, torch.Tensor]]]:
|
||||
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
|
||||
weights1, weights2 = tee(weights, 2)
|
||||
|
||||
def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
|
||||
@@ -884,31 +889,33 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
|
||||
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
|
||||
format."""
|
||||
if name.startswith("model.") or name.startswith(
|
||||
"language_model.model."):
|
||||
renamed = name.replace("model.", "language_model.model.",
|
||||
1) if name.startswith("model.") else name
|
||||
if name.startswith("model.") or name.startswith("language_model.model."):
|
||||
renamed = (
|
||||
name.replace("model.", "language_model.model.", 1)
|
||||
if name.startswith("model.")
|
||||
else name
|
||||
)
|
||||
# Handle expert scale parameters with flat naming
|
||||
if "feed_forward.experts." in name and ("_input_scale" in name or
|
||||
"_weight_scale" in name):
|
||||
if "feed_forward.experts." in name and (
|
||||
"_input_scale" in name or "_weight_scale" in name
|
||||
):
|
||||
# Map checkpoint naming to vLLM's expected naming
|
||||
if "down_proj_input_scale" in renamed:
|
||||
return renamed.replace("down_proj_input_scale",
|
||||
"w2_input_scale")
|
||||
return renamed.replace("down_proj_input_scale", "w2_input_scale")
|
||||
elif "down_proj_weight_scale" in renamed:
|
||||
return renamed.replace("down_proj_weight_scale",
|
||||
"w2_weight_scale")
|
||||
return renamed.replace("down_proj_weight_scale", "w2_weight_scale")
|
||||
elif "gate_up_proj_input_scale" in renamed:
|
||||
return renamed.replace("gate_up_proj_input_scale",
|
||||
"w13_input_scale")
|
||||
return renamed.replace(
|
||||
"gate_up_proj_input_scale", "w13_input_scale"
|
||||
)
|
||||
elif "gate_up_proj_weight_scale" in renamed:
|
||||
return renamed.replace("gate_up_proj_weight_scale",
|
||||
"w13_weight_scale")
|
||||
return renamed.replace(
|
||||
"gate_up_proj_weight_scale", "w13_weight_scale"
|
||||
)
|
||||
return renamed
|
||||
|
||||
# Handle attention scale parameters
|
||||
elif "self_attn." in name and (".k_scale" in name
|
||||
or ".v_scale" in name):
|
||||
elif "self_attn." in name and (".k_scale" in name or ".v_scale" in name):
|
||||
if ".k_proj.k_scale" in renamed:
|
||||
return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
|
||||
elif ".v_proj.v_scale" in renamed:
|
||||
@@ -919,8 +926,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return renamed
|
||||
|
||||
elif name.startswith("lm_head.weight"):
|
||||
return name.replace("lm_head.weight",
|
||||
"language_model.lm_head.weight")
|
||||
return name.replace("lm_head.weight", "language_model.lm_head.weight")
|
||||
|
||||
return name
|
||||
|
||||
@@ -943,7 +949,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return language_model_weights, other_weights
|
||||
|
||||
def _handle_expert_scale_broadcasting(
|
||||
self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
|
||||
self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
|
||||
) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
|
||||
"""Handle expert scale parameters that need broadcasting.
|
||||
|
||||
@@ -956,12 +962,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
for name, weight in weights:
|
||||
# Check if this is an expert scale parameter that needs broadcasting
|
||||
if ("feed_forward.experts." in name and "scale" in name
|
||||
and ".shared_expert" not in name):
|
||||
if (
|
||||
"feed_forward.experts." in name
|
||||
and "scale" in name
|
||||
and ".shared_expert" not in name
|
||||
):
|
||||
if name in params_dict:
|
||||
param = params_dict[name]
|
||||
if (hasattr(param, 'data') and param.data.numel() > 1
|
||||
and weight.numel() == 1):
|
||||
if (
|
||||
hasattr(param, "data")
|
||||
and param.data.numel() > 1
|
||||
and weight.numel() == 1
|
||||
):
|
||||
# Broadcast single value to all experts
|
||||
param.data.fill_(weight.item())
|
||||
updated_params.add(name)
|
||||
@@ -973,10 +985,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return regular_weights, expert_scale_weights, updated_params
|
||||
|
||||
def _load_other_weights(self, other_weights: Iterable[tuple[str,
|
||||
torch.Tensor]],
|
||||
params_dict: dict,
|
||||
stacked_params_mapping: list) -> set[str]:
|
||||
def _load_other_weights(
|
||||
self,
|
||||
other_weights: Iterable[tuple[str, torch.Tensor]],
|
||||
params_dict: dict,
|
||||
stacked_params_mapping: list,
|
||||
) -> set[str]:
|
||||
"""Load non-language-model weights with stacking support."""
|
||||
updated_params = set()
|
||||
|
||||
@@ -997,16 +1011,13 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
else:
|
||||
# Use regular weight loading
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
updated_params.add(name)
|
||||
|
||||
return updated_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
@@ -1023,8 +1034,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
updated_params: set[str] = set()
|
||||
|
||||
# Separate and rename weights
|
||||
language_model_weights, other_weights = (
|
||||
self._separate_and_rename_weights(weights))
|
||||
language_model_weights, other_weights = self._separate_and_rename_weights(
|
||||
weights
|
||||
)
|
||||
|
||||
# Skip loading vision model and projector if they're not initialized.
|
||||
if self.vision_model is None and self.multi_modal_projector is None:
|
||||
@@ -1032,8 +1044,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
# Handle expert scale parameters
|
||||
regular_weights, expert_scale_weights, updated_params_from_experts = (
|
||||
self._handle_expert_scale_broadcasting(language_model_weights,
|
||||
params_dict))
|
||||
self._handle_expert_scale_broadcasting(language_model_weights, params_dict)
|
||||
)
|
||||
updated_params.update(updated_params_from_experts)
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
@@ -1042,13 +1054,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
updated_params.update(loaded_language_model_params)
|
||||
|
||||
if expert_scale_weights:
|
||||
loaded_expert_scale_params = loader.load_weights(
|
||||
expert_scale_weights)
|
||||
loaded_expert_scale_params = loader.load_weights(expert_scale_weights)
|
||||
if loaded_expert_scale_params:
|
||||
updated_params.update(loaded_expert_scale_params)
|
||||
|
||||
updated_params.update(
|
||||
self._load_other_weights(other_weights, params_dict,
|
||||
stacked_params_mapping))
|
||||
self._load_other_weights(other_weights, params_dict, stacked_params_mapping)
|
||||
)
|
||||
|
||||
return updated_params
|
||||
|
||||
Reference in New Issue
Block a user