[Pixtral] Enable Pixtral language model support Eagle3 (#37182)
Signed-off-by: remi <remi@mistral.ai>
This commit is contained in:
@@ -66,9 +66,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
supports_eagle3,
|
||||
)
|
||||
from .module_mapping import MultiModelKeys
|
||||
from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix
|
||||
@@ -262,7 +264,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
|
||||
dummy_inputs=PixtralDummyInputsBuilder,
|
||||
)
|
||||
class PixtralForConditionalGeneration(
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
nn.Module, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP
|
||||
):
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
@@ -390,6 +392,21 @@ class PixtralForConditionalGeneration(
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def _require_language_model_eagle3(self) -> None:
|
||||
if not supports_eagle3(self.language_model):
|
||||
raise RuntimeError(
|
||||
f"EAGLE-3 speculative decoding requires the language model to "
|
||||
f"support EAGLE-3, but {type(self.language_model).__name__} does not."
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self._require_language_model_eagle3()
|
||||
self.language_model.set_aux_hidden_state_layers(layers)
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
self._require_language_model_eagle3()
|
||||
return self.language_model.get_eagle3_aux_hidden_state_layers()
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
|
||||
return weight[0].startswith(("vision_encoder", "vision_tower"))
|
||||
|
||||
Reference in New Issue
Block a user