Support llama3 eagle3 head with llama4 verifier (#25961)
Signed-off-by: rahul-tuli <rtuli@redhat.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
@@ -604,6 +604,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
|||||||
self.model.aux_hidden_state_layers = layers
|
self.model.aux_hidden_state_layers = layers
|
||||||
|
|
||||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Override to return default layers for Llama.

    The defaults are derived from ``len(self.model.layers)``: layer 2,
    the middle layer (``num_layers // 2``) and the third-from-last layer
    (``num_layers - 3``).

    Note: The GPU model runner will override this with layers from
    the speculative config if available, providing dynamic configuration.

    Returns:
        A 3-tuple of layer indices.
    """
    num_layers = len(self.model.layers)
    return (2, num_layers // 2, num_layers - 3)
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
|
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
|
||||||
|
from vllm.multimodal.inputs import NestedTensors
|
||||||
|
|
||||||
from .utils import AutoWeightsLoader, maybe_prefix
|
from .utils import AutoWeightsLoader, maybe_prefix
|
||||||
|
|
||||||
@@ -241,7 +242,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[NestedTensors] = None,
    is_multimodal: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return the input embeddings for ``input_ids``.

    Args:
        input_ids: Token ids passed to ``self.model.get_input_embeddings``.
        multimodal_embeddings: Accepted but not used by this
            implementation.
        is_multimodal: Accepted but not used by this implementation.

    Returns:
        Whatever ``self.model.get_input_embeddings(input_ids)`` returns.
    """
    # NOTE(review): the two multimodal parameters are presumably here so the
    # signature matches multimodal verifier models — confirm against callers.
    return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -64,7 +64,12 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import (
|
||||||
|
MultiModalEmbeddings,
|
||||||
|
SupportsEagle3,
|
||||||
|
SupportsMultiModal,
|
||||||
|
SupportsPP,
|
||||||
|
)
|
||||||
from .llama4 import Llama4ForCausalLM
|
from .llama4 import Llama4ForCausalLM
|
||||||
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
||||||
from .vision import run_dp_sharded_vision_model
|
from .vision import run_dp_sharded_vision_model
|
||||||
@@ -717,7 +722,9 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
|
|||||||
info=Mllama4ProcessingInfo,
|
info=Mllama4ProcessingInfo,
|
||||||
dummy_inputs=Mllama4DummyInputsBuilder,
|
dummy_inputs=Mllama4DummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
class Llama4ForConditionalGeneration(
|
||||||
|
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
|
||||||
|
):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
@@ -767,6 +774,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
self.language_model.make_empty_intermediate_tensors
|
self.language_model.make_empty_intermediate_tensors
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Set which layers should output auxiliary hidden states for EAGLE3.

    Args:
        layers: Layer indices, forwarded unchanged to
            ``self.language_model.set_aux_hidden_state_layers``.
    """
    # Delegate to underlying language model (Llama4ForCausalLM)
    assert hasattr(self.language_model, "set_aux_hidden_state_layers"), (
        "language_model does not implement set_aux_hidden_state_layers"
    )
    self.language_model.set_aux_hidden_state_layers(layers)
|
||||||
|
|
||||||
|
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Get the layer indices for auxiliary hidden state outputs.

    Note: The GPU model runner will override this with layers from
    the speculative config if available, providing dynamic configuration.

    Returns:
        Whatever ``self.language_model.get_eagle3_aux_hidden_state_layers()``
        returns.
    """
    # Delegate to underlying language model (Llama4ForCausalLM)
    assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers"), (
        "language_model does not implement get_eagle3_aux_hidden_state_layers"
    )
    return self.language_model.get_eagle3_aux_hidden_state_layers()
|
||||||
|
|
||||||
def _parse_and_validate_image_input(
|
def _parse_and_validate_image_input(
|
||||||
self, **kwargs: object
|
self, **kwargs: object
|
||||||
) -> Optional[Llama4ImagePatchInputs]:
|
) -> Optional[Llama4ImagePatchInputs]:
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
|
|||||||
- draft_vocab_size: Size of the draft model's vocabulary
|
- draft_vocab_size: Size of the draft model's vocabulary
|
||||||
- target_hidden_size: Hidden size of the target model
|
- target_hidden_size: Hidden size of the target model
|
||||||
- norm_before_residual: Whether to apply norm before residual connection
|
- norm_before_residual: Whether to apply norm before residual connection
|
||||||
|
- eagle_aux_hidden_state_layer_ids: List of layer indices from the base
|
||||||
|
model to use as auxiliary inputs for the Eagle3 drafter. These layers
|
||||||
|
provide intermediate hidden states that help the drafter make better
|
||||||
|
predictions. This is the standard field used in Eagle3 checkpoints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
|
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
|
||||||
@@ -28,3 +32,7 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
|
|||||||
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
|
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
|
||||||
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
|
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
|
||||||
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
|
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
|
||||||
|
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
|
||||||
|
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
|
||||||
|
"eagle_aux_hidden_state_layer_ids"
|
||||||
|
]
|
||||||
|
|||||||
@@ -2943,15 +2943,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
logger.info("Loading drafter model...")
|
logger.info("Loading drafter model...")
|
||||||
self.drafter.load_model(self.model)
|
self.drafter.load_model(self.model)
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
if supports_eagle3(self.model):
|
if not supports_eagle3(self.model):
|
||||||
self.model.set_aux_hidden_state_layers(
|
|
||||||
self.model.get_eagle3_aux_hidden_state_layers()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Model does not support EAGLE3 interface but "
|
"Model does not support EAGLE3 interface but "
|
||||||
"aux_hidden_state_outputs was requested"
|
"aux_hidden_state_outputs was requested"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Try to get auxiliary layers from speculative config,
|
||||||
|
# otherwise use model's default layers
|
||||||
|
aux_layers = self._get_eagle3_aux_layers_from_config()
|
||||||
|
if aux_layers:
|
||||||
|
logger.info(
|
||||||
|
"Using auxiliary layers from speculative config: %s",
|
||||||
|
aux_layers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
|
||||||
|
|
||||||
|
self.model.set_aux_hidden_state_layers(aux_layers)
|
||||||
time_after_load = time.perf_counter()
|
time_after_load = time.perf_counter()
|
||||||
self.model_memory_usage = m.consumed_memory
|
self.model_memory_usage = m.consumed_memory
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -3006,6 +3015,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
|
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
    """Extract Eagle3 auxiliary layer indices from speculative config.

    These indices specify which hidden states from the base model should
    be used as auxiliary inputs for the Eagle3 drafter model during
    speculative decoding.

    Returns:
        Tuple of layer indices if the draft model's ``hf_config`` has a
        non-empty list or tuple in ``eagle_aux_hidden_state_layer_ids``,
        None otherwise.
    """
    spec_config = self.speculative_config
    if not (spec_config and spec_config.draft_model_config):
        return None

    hf_config = spec_config.draft_model_config.hf_config
    # A missing attribute is treated the same as an empty or None value.
    layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None)

    # Only a non-empty list/tuple counts; any other value is ignored.
    if layer_ids and isinstance(layer_ids, (list, tuple)):
        return tuple(layer_ids)

    return None
|
||||||
|
|
||||||
def reload_weights(self) -> None:
|
def reload_weights(self) -> None:
|
||||||
assert getattr(self, "model", None) is not None, (
|
assert getattr(self, "model", None) is not None, (
|
||||||
"Cannot reload weights before model is loaded."
|
"Cannot reload weights before model is loaded."
|
||||||
|
|||||||
Reference in New Issue
Block a user