[Misc][LLaMa4] Compile LLaMa Vision Encoder (#30709)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -369,7 +369,11 @@ def llama_model_invariants(
|
||||
torch._check(positions.size()[0] == input_ids.size()[0])
|
||||
|
||||
|
||||
@support_torch_compile(shape_invariants=llama_model_invariants)
|
||||
@support_torch_compile(
|
||||
# TODO[#32068]: Investigate recompilation
|
||||
# mark_unbacked_dims={"input_ids": 0},
|
||||
shape_invariants=llama_model_invariants
|
||||
)
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -31,9 +31,11 @@ from transformers.models.llama4.image_processing_llama4_fast import (
|
||||
get_best_fit,
|
||||
)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -47,6 +49,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.utils import initialize_model
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
@@ -456,6 +459,9 @@ class Llama4UnfoldConvolution(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile(
|
||||
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_vit
|
||||
)
|
||||
class Llama4VisionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -497,6 +503,7 @@ class Llama4VisionModel(nn.Module):
|
||||
prefix=f"{prefix}.model",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
|
||||
self.vision_adapter = Llama4VisionPixelShuffleMLP(
|
||||
config,
|
||||
quant_config,
|
||||
@@ -762,18 +769,28 @@ class Llama4ForConditionalGeneration(
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
if multimodal_config.get_limit_per_prompt("image"):
|
||||
self.vision_model = Llama4VisionModel(
|
||||
config.vision_config,
|
||||
None,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
with (
|
||||
set_current_vllm_config(vllm_config),
|
||||
set_model_tag("Llama4VisionModel", is_encoder=True),
|
||||
):
|
||||
self.vision_model = Llama4VisionModel(
|
||||
config=config.vision_config,
|
||||
quant_config=None,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = Llama4MultiModalProjector(
|
||||
self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector")
|
||||
config=self.config,
|
||||
quant_config=None,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
else:
|
||||
self.vision_model = None
|
||||
@@ -883,7 +900,10 @@ class Llama4ForConditionalGeneration(
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
return self._process_image_input(image_input)
|
||||
with (
|
||||
set_forward_context(None, self.vllm_config),
|
||||
):
|
||||
return self._process_image_input(image_input)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user