Enable Eagle3 speculative decoding for Mistral3ForConditionalGeneration (#33939)
Signed-off-by: Akintunde Oladipo <akintunde.oladipo@servicenow.com> Signed-off-by: TundeAtSN <akintunde.oladipo@servicenow.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
@@ -408,7 +409,7 @@ def init_vision_tower_for_llava(
|
||||
dummy_inputs=Mistral3DummyInputsBuilder,
|
||||
)
|
||||
class Mistral3ForConditionalGeneration(
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
@@ -432,6 +433,13 @@ class Mistral3ForConditionalGeneration(
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Record which decoder layers should emit auxiliary hidden states.

    Args:
        layers: Indices of the language-model decoder layers whose hidden
            states should be captured (consumed by the Eagle3 draft model).
    """
    # Store the layer indices on the inner text model, where the decoder
    # forward pass reads them.
    language_model = self.get_language_model()
    language_model.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Return the default decoder-layer indices used for Eagle3 aux states.

    Picks one early layer (2), one middle layer, and one late layer
    (third from the end) of the language model's decoder stack.

    Returns:
        A 3-tuple of layer indices.
    """
    total_layers = len(self.get_language_model().model.layers)
    middle = total_layers // 2
    return (2, middle, total_layers - 3)
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user