Enable Eagle3 speculative decoding for Pixtral (LlavaForConditionalGeneration) (#32542)
Signed-off-by: gopalsarda <gopal.sarda@servicenow.com>
This commit is contained in:
@@ -53,6 +53,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|||||||
from .clip import CLIPVisionModel
|
from .clip import CLIPVisionModel
|
||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
|
SupportsEagle3,
|
||||||
SupportsLoRA,
|
SupportsLoRA,
|
||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
@@ -503,7 +504,7 @@ def init_vision_tower_for_llava(
|
|||||||
dummy_inputs=LlavaDummyInputsBuilder,
|
dummy_inputs=LlavaDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class LlavaForConditionalGeneration(
|
class LlavaForConditionalGeneration(
|
||||||
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
|
nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
|
||||||
):
|
):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
@@ -527,6 +528,13 @@ class LlavaForConditionalGeneration(
|
|||||||
|
|
||||||
raise ValueError("Only image modality is supported")
|
raise ValueError("Only image modality is supported")
|
||||||
|
|
||||||
|
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||||
|
self.get_language_model().model.aux_hidden_state_layers = layers
|
||||||
|
|
||||||
|
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||||
|
num_layers = len(self.get_language_model().model.layers)
|
||||||
|
return (2, num_layers // 2, num_layers - 3)
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user