From aa84e43ccb540dfbbd723f5b315ef7eefd732641 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Delacourt?= <54138269+Flechman@users.noreply.github.com>
Date: Fri, 20 Mar 2026 16:50:15 +0100
Subject: [PATCH] [Pixtral] Enable Pixtral language model support Eagle3 (#37182)

Signed-off-by: remi
---
 Review notes (below the "---" separator, so not part of the commit message):
 - PixtralForConditionalGeneration now lists SupportsEagle3 among its bases
   and forwards set_aux_hidden_state_layers() and
   get_eagle3_aux_hidden_state_layers() to self.language_model.
 - Both forwarders first call _require_language_model_eagle3(), which raises
   RuntimeError when supports_eagle3(self.language_model) is false.

 vllm/model_executor/models/pixtral.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 8b1455359..eaf5843a3 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -66,9 +66,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
     SupportsPP,
+    supports_eagle3,
 )
 from .module_mapping import MultiModelKeys
 from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix
@@ -262,7 +264,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
     dummy_inputs=PixtralDummyInputsBuilder,
 )
 class PixtralForConditionalGeneration(
-    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
+    nn.Module, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP
 ):
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
@@ -390,6 +392,21 @@ class PixtralForConditionalGeneration(
     ) -> torch.Tensor | None:
         return self.language_model.compute_logits(hidden_states)
 
+    def _require_language_model_eagle3(self) -> None:
+        if not supports_eagle3(self.language_model):
+            raise RuntimeError(
+                f"EAGLE-3 speculative decoding requires the language model to "
+                f"support EAGLE-3, but {type(self.language_model).__name__} does not."
+            )
+
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self._require_language_model_eagle3()
+        self.language_model.set_aux_hidden_state_layers(layers)
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        self._require_language_model_eagle3()
+        return self.language_model.get_eagle3_aux_hidden_state_layers()
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith(("vision_encoder", "vision_tower"))