diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 1162217b6..08f5d45e2 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
     SupportsPP,
@@ -408,7 +409,7 @@ def init_vision_tower_for_llava(
     dummy_inputs=Mistral3DummyInputsBuilder,
 )
 class Mistral3ForConditionalGeneration(
-    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
+    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
 ):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -432,6 +433,13 @@ class Mistral3ForConditionalGeneration(
 
         raise ValueError("Only image modality is supported")
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.get_language_model().model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.get_language_model().model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()