Support llama3 eagle3 head with llama4 verifier (#25961)

Signed-off-by: rahul-tuli <rtuli@redhat.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
Rahul Tuli
2025-10-06 23:26:08 +05:30
committed by GitHub
parent 20db99cc69
commit 05f6846ede
5 changed files with 83 additions and 8 deletions

View File

@@ -64,7 +64,12 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
)
from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
from .vision import run_dp_sharded_vision_model
@@ -717,7 +722,9 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -767,6 +774,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.make_empty_intermediate_tensors
)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "set_aux_hidden_state_layers")
self.language_model.set_aux_hidden_state_layers(layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
"""Get the layer indices for auxiliary hidden state outputs.
Note: The GPU model runner will override this with layers from
the speculative config if available, providing dynamic configuration.
"""
# Delegate to underlying language model (Llama4ForCausalLM)
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
return self.language_model.get_eagle3_aux_hidden_state_layers()
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[Llama4ImagePatchInputs]: