[Refactor] Consolidate SupportsEagle (#36063)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent 54a6db827f · commit 8b346309a5 · committed by GitHub
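In outline, the refactor replaces each decoder's hand-rolled `aux_hidden_state_layers` bookkeeping with a shared `EagleModelMixin`, and moves the `set_aux_hidden_state_layers` / default-layer logic into the `SupportsEagle3` protocol. A condensed, self-contained sketch of the resulting pattern (the toy decoder below is an illustration only; the mixin and helper names mirror the diff that follows):

```python
# Minimal sketch of the consolidated EAGLE pattern; not verbatim vLLM code.
import torch
import torch.nn as nn


class EagleModelMixin:
    # Layer indices whose hidden state (plus residual, when present) is captured
    # as an auxiliary output for EAGLE-3 drafting.
    aux_hidden_state_layers: tuple[int, ...] = ()

    def _set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.aux_hidden_state_layers = layers

    def _maybe_add_hidden_state(
        self,
        aux_hidden_states: list[torch.Tensor],
        layer_idx: int,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> list[torch.Tensor]:
        if layer_idx in self.aux_hidden_state_layers:
            value = hidden_states + residual if residual is not None else hidden_states
            aux_hidden_states.append(value)
        return aux_hidden_states


class ToyDecoder(nn.Module, EagleModelMixin):
    """Stand-in for LlamaModel/Qwen2Model/...: every layer boundary goes through the helper."""

    def __init__(self, hidden_size: int = 8, num_layers: int = 4) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )

    def forward(self, hidden_states: torch.Tensor):
        residual = None
        aux = self._maybe_add_hidden_state([], 0, hidden_states, residual)
        for idx, layer in enumerate(self.layers):
            residual = hidden_states
            hidden_states = layer(hidden_states)
            self._maybe_add_hidden_state(aux, idx + 1, hidden_states, residual)
        return hidden_states, aux


model = ToyDecoder()
model._set_aux_hidden_state_layers((2,))  # normally driven by SupportsEagle3 / the GPU runner
out, aux_states = model(torch.randn(1, 8))
assert len(aux_states) == 1  # the post-layer-2 hidden state was captured
```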
@@ -13,15 +13,15 @@ import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
+from vllm.model_executor.models.interfaces import EagleModelMixin
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.sequence import IntermediateTensors


-class PredictableLlamaModel(nn.Module):
+class PredictableLlamaModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        self.aux_hidden_state_layers = tuple[int, ...]()

         # Create minimal embed_tokens for embedding
         from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from vllm.model_executor.models.interfaces import (
+    EagleModelMixin,
     SupportsEagle3,
     SupportsLoRA,
     SupportsPP,
@@ -384,7 +385,7 @@ class AfmoeDecoderLayer(nn.Module):
         "inputs_embeds": 0,
     }
 )
-class AfmoeModel(nn.Module):
+class AfmoeModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

@@ -421,8 +422,6 @@ class AfmoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()

-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -453,15 +452,14 @@ class AfmoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(
-                    hidden_states + residual if residual is not None else hidden_states
-                )
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -691,13 +689,6 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,
@@ -60,7 +60,13 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backend import AttentionType

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -313,7 +319,7 @@ class ApertusDecoderLayer(nn.Module):


 @support_torch_compile
-class ApertusModel(nn.Module):
+class ApertusModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -357,8 +363,6 @@ class ApertusModel(nn.Module):
         else:
             self.norm = PPMissingLayer()

-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -384,13 +388,14 @@ class ApertusModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -472,7 +477,9 @@ class ApertusModel(nn.Module):
         return loaded_params


-class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class ApertusForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

     # LoRA specific attributes
@@ -520,13 +527,6 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             self.model.make_empty_intermediate_tensors
         )

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _init_model(
         self,
         vllm_config: VllmConfig,
@@ -32,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import (
 )
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -170,7 +176,7 @@ class ArceeDecoderLayer(nn.Module):


 @support_torch_compile
-class ArceeModel(nn.Module):
+class ArceeModel(nn.Module, EagleModelMixin):
     """The transformer model backbone for Arcee (embedding layer + stacked
     decoder blocks + final norm)."""

@@ -218,10 +224,6 @@ class ArceeModel(nn.Module):
         else:
             self.norm = PPMissingLayer()

-        # For optional capturing of intermediate hidden states
-        # (not used by default)
-        self.aux_hidden_state_layers: tuple[int, ...] = tuple()
-
         # Prepare factory for empty intermediate tensors
         # (for pipeline scheduling)
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
@@ -253,15 +255,14 @@ class ArceeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states: list[torch.Tensor] = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(
-                    hidden_states + residual
-                )  # capture pre-layer hidden state if needed
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             # Send intermediate results to the next pipeline stage
@@ -348,7 +349,9 @@ class ArceeModel(nn.Module):
         return loaded_params


-class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class ArceeForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     """Arcee Model for causal language modeling, integrated with vLLM
     runtime."""

@@ -47,7 +47,13 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backend import AttentionType

-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
@@ -256,7 +262,7 @@ class TransformerBlock(torch.nn.Module):


 @support_torch_compile
-class GptOssModel(nn.Module):
+class GptOssModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -285,7 +291,6 @@ class GptOssModel(nn.Module):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], self.config.hidden_size
         )
-        self.aux_hidden_state_layers = tuple[int, ...]()

     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embedding(input_ids)
@@ -309,12 +314,13 @@ class GptOssModel(nn.Module):
             x = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state(
+            [], self.start_layer, x, residual
+        )
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            if i in self.aux_hidden_state_layers:
-                aux_hidden_states.append(x if residual is None else x + residual)
             x, residual = layer(x, positions, residual)
+            self._maybe_add_hidden_state(aux_hidden_states, i + 1, x, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": x, "residual": residual})
         x, _ = self.norm(x, residual)
@@ -1141,7 +1147,9 @@ class GptOssModel(nn.Module):
         )


-class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
+class GptOssForCausalLM(
+    nn.Module, SupportsPP, SupportsEagle, SupportsEagle3, SupportsLoRA
+):
     is_3d_moe_weight: bool = True
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

@@ -1197,13 +1205,6 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
             self.model.make_empty_intermediate_tensors
         )

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

@@ -66,7 +66,14 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backend import AttentionType

-from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    MixtureOfExperts,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -586,7 +593,7 @@ class HunYuanDecoderLayer(nn.Module):
         "inputs_embeds": 0,
     }
 )
-class HunYuanModel(nn.Module):
+class HunYuanModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

@@ -629,7 +636,6 @@ class HunYuanModel(nn.Module):
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
             self.norm = PPMissingLayer()
-        self.aux_hidden_state_layers = tuple[int, ...]()

     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -654,13 +660,10 @@ class HunYuanModel(nn.Module):

         cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for i, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if i in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
-
             hidden_states, residual, kv_states = layer(
                 positions,
                 hidden_states,
@@ -673,6 +676,10 @@ class HunYuanModel(nn.Module):
             else:
                 prev_kv_states = None

+            self._maybe_add_hidden_state(
+                aux_hidden_states, i + 1, hidden_states, residual
+            )
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}
@@ -904,7 +911,9 @@ class HunYuanModel(nn.Module):
         return loaded_params


-class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class HunyuanV1ModelBase(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -943,13 +952,6 @@ class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
         else:
             self.lm_head = PPMissingLayer()

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,
@@ -86,6 +86,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
@@ -801,6 +802,7 @@ class HunYuanVLForConditionalGeneration(
     SupportsPP,
     SupportsQuant,
     SupportsXDRoPE,
+    SupportsEagle,
     SupportsEagle3,
 ):
     # To ensure correct weight loading and mapping.
@@ -988,13 +990,6 @@ class HunYuanVLForConditionalGeneration(
             multimodal_embeddings += tuple(image_embeddings)
         return multimodal_embeddings

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.language_model.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.language_model.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,
@@ -1273,6 +1273,25 @@ def supports_any_eagle(
     return supports_eagle(model) or supports_eagle3(model)


+class EagleModelMixin:
+    aux_hidden_state_layers: tuple[int, ...] = ()
+
+    def _set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.aux_hidden_state_layers = layers
+
+    def _maybe_add_hidden_state(
+        self,
+        aux_hidden_states: list[torch.Tensor],
+        layer_idx: int,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> list[torch.Tensor]:
+        if layer_idx in self.aux_hidden_state_layers:
+            value = hidden_states + residual if residual is not None else hidden_states
+            aux_hidden_states.append(value)
+        return aux_hidden_states
+
+
 @runtime_checkable
 class SupportsEagle(SupportsEagleBase, Protocol):
     """The interface required for models that support
@@ -1320,24 +1339,48 @@ class SupportsEagle3(SupportsEagleBase, Protocol):

     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         """
-        Set which layers should output auxiliary
-        hidden states for EAGLE-3.
+        Set which layers should output auxiliary hidden states for EAGLE-3.

         Args:
             layers: Tuple of layer indices that should output auxiliary
                 hidden states.
         """
-        ...
+        parent_ref = self
+        if hasattr(self, "get_language_model"):
+            parent_ref = self.get_language_model()
+        elif hasattr(self, "language_model"):
+            parent_ref = self.language_model
+        assert hasattr(parent_ref, "model"), (
+            "Model instance must have 'model' attribute to set number of layers"
+        )
+        assert isinstance(parent_ref.model, EagleModelMixin), (
+            "Model instance must inherit from EagleModelMixin to set auxiliary layers"
+        )
+        parent_ref.model._set_aux_hidden_state_layers(layers)

-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
         """
-        Get the layer indices that should output auxiliary hidden states
-        for EAGLE-3.
+        Get the default layer indices that should output auxiliary hidden states
+        for EAGLE-3 for this model. Models can override this method to provide
+        different default layers based on their architecture, but it is encouraged
+        to instead include the layer specification in the model's config if possible.

         Returns:
             Tuple of layer indices for auxiliary hidden state outputs.
         """
-        ...
+        parent_ref = self
+        if hasattr(self, "get_language_model"):
+            parent_ref = self.get_language_model()
+        elif hasattr(self, "language_model"):
+            parent_ref = self.language_model
+        assert hasattr(parent_ref, "model"), (
+            "Model instance must have 'model' attribute to get number of layers"
+        )
+        assert hasattr(parent_ref.model, "layers"), (
+            "Model instance must have 'layers' attribute to get number of layers"
+        )
+        num_layers = len(parent_ref.model.layers)
+        return (2, num_layers // 2, num_layers - 3)


 @overload
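The protocol defaults above locate the decoder that actually owns `layers` by following `get_language_model()` or a `language_model` attribute before delegating to the mixin. A compact, hypothetical illustration of that resolution (the `Fake*` wrapper classes are stand-ins for this sketch, not vLLM types):

```python
class FakeMixin:
    # Stand-in for the EagleModelMixin defined in the hunk above.
    aux_hidden_state_layers: tuple[int, ...] = ()

    def _set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.aux_hidden_state_layers = layers


class FakeDecoder(FakeMixin):
    def __init__(self, num_layers: int) -> None:
        self.layers = list(range(num_layers))  # only len() matters here


class FakeCausalLM:
    def __init__(self) -> None:
        self.model = FakeDecoder(num_layers=32)


class FakeVLWrapper:
    def __init__(self) -> None:
        self.language_model = FakeCausalLM()


wrapper = FakeVLWrapper()
# What the new SupportsEagle3 defaults do, spelled out:
parent_ref = getattr(wrapper, "language_model", wrapper)
num_layers = len(parent_ref.model.layers)
default_layers = (2, num_layers // 2, num_layers - 3)  # -> (2, 16, 29)
parent_ref.model._set_aux_hidden_state_layers(default_layers)
assert parent_ref.model.aux_hidden_state_layers == (2, 16, 29)
```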
@@ -61,6 +61,7 @@ from vllm.v1.attention.backend import AttentionType

 from .adapters import as_embedding_model, as_seq_cls_model
 from .interfaces import (
+    EagleModelMixin,
     SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
@@ -351,7 +352,7 @@ def llama_model_invariants(
     # mark_unbacked_dims={"input_ids": 0},
     shape_invariants=llama_model_invariants
 )
-class LlamaModel(nn.Module):
+class LlamaModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -389,8 +390,6 @@ class LlamaModel(nn.Module):
         else:
             self.norm = PPMissingLayer()

-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -417,15 +416,16 @@ class LlamaModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(
                 positions, hidden_states, residual, **extra_layer_kwargs
             )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -556,18 +556,6 @@ class LlamaForCausalLM(
             self.model.make_empty_intermediate_tensors
         )

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        """Override to return default layers for Llama
-
-        Note: The GPU model runner will override this with layers from
-        the speculative config if available, providing dynamic configuration.
-        """
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _init_model(
         self,
         vllm_config: VllmConfig,
@@ -55,6 +55,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
@@ -503,7 +504,12 @@ def init_vision_tower_for_llava(
     dummy_inputs=LlavaDummyInputsBuilder,
 )
 class LlavaForConditionalGeneration(
-    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
+    nn.Module,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -527,13 +533,6 @@ class LlavaForConditionalGeneration(

         raise ValueError("Only image modality is supported")

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.get_language_model().model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.get_language_model().model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()

@@ -682,13 +682,6 @@ class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
             self.model.make_empty_intermediate_tensors
         )

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

@@ -63,7 +63,13 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     is_pp_missing_parameter,
@@ -391,7 +397,7 @@ class MiniCPMDecoderLayer(nn.Module):


 @support_torch_compile
-class MiniCPMModel(nn.Module):
+class MiniCPMModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

@@ -413,8 +419,6 @@ class MiniCPMModel(nn.Module):
         self._init_layers(prefix, config, cache_config, quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], self.config.hidden_size
         )
@@ -455,19 +459,18 @@ class MiniCPMModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(
-                    hidden_states + residual if residual is not None else hidden_states
-                )
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -550,7 +553,9 @@ class MiniCPMModel(nn.Module):
         return loaded_params


-class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class MiniCPMForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -611,13 +616,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,
@@ -44,6 +44,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMultiModal,
@@ -409,7 +410,12 @@ def init_vision_tower_for_llava(
     dummy_inputs=Mistral3DummyInputsBuilder,
 )
 class Mistral3ForConditionalGeneration(
-    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
+    nn.Module,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsPP,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -433,13 +439,6 @@ class Mistral3ForConditionalGeneration(

         raise ValueError("Only image modality is supported")

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.get_language_model().model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.get_language_model().model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()

@@ -798,20 +798,16 @@ class Llama4ForConditionalGeneration(
         self.num_moe_layers = len(self.moe_layers)

     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        """Set which layers should output auxiliary hidden states for EAGLE3."""
         # Delegate to underlying language model (Llama4ForCausalLM)
         assert hasattr(self.language_model, "set_aux_hidden_state_layers")
         self.language_model.set_aux_hidden_state_layers(layers)

-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        """Get the layer indices for auxiliary hidden state outputs.
-
-        Note: The GPU model runner will override this with layers from
-        the speculative config if available, providing dynamic configuration.
-        """
+    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
         # Delegate to underlying language model (Llama4ForCausalLM)
-        assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
-        return self.language_model.get_eagle3_aux_hidden_state_layers()
+        assert hasattr(
+            self.language_model, "get_eagle3_default_aux_hidden_state_layers"
+        )
+        return self.language_model.get_eagle3_default_aux_hidden_state_layers()

     def set_eplb_state(
         self,
@@ -62,7 +62,13 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import is_interleaved, set_default_rope_theta
 from vllm.v1.attention.backend import AttentionType

-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -349,7 +355,7 @@ def qwen_2_model_invariants(
     },
     shape_invariants=qwen_2_model_invariants,
 )
-class Qwen2Model(nn.Module):
+class Qwen2Model(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -410,8 +416,6 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer()

-        self.aux_hidden_state_layers = tuple[int, ...]()
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -433,13 +437,14 @@ class Qwen2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -519,7 +524,9 @@ class Qwen2Model(nn.Module):
         return loaded_params


-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class Qwen2ForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -566,13 +573,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def forward(
         self,
         input_ids: torch.Tensor | None,
@@ -89,6 +89,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum

 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMRoPE,
@@ -1000,6 +1001,7 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsLoRA,
     SupportsPP,
     SupportsQuant,
+    SupportsEagle,
     SupportsEagle3,
     SupportsMultiModalPruning,
     SupportsMRoPE,
@@ -1143,13 +1145,6 @@ class Qwen2_5_VLForConditionalGeneration(
             self.language_model.make_empty_intermediate_tensors
         )

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.language_model.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.language_model.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Qwen2_5_VLImageInputs | None:
@@ -48,7 +48,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import set_default_rope_theta
 from vllm.v1.attention.backend import AttentionType

-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2MLP as Qwen3MLP
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix
@@ -258,7 +258,9 @@ class Qwen3Model(Qwen2Model):
         )


-class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class Qwen3ForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -307,13 +309,6 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
             self.model.make_empty_intermediate_tensors
         )

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

@@ -65,7 +65,14 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors

-from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import (
+    EagleModelMixin,
+    MixtureOfExperts,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -427,7 +434,7 @@ class Qwen3MoeDecoderLayer(nn.Module):


 @support_torch_compile
-class Qwen3MoeModel(nn.Module):
+class Qwen3MoeModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -461,8 +468,6 @@ class Qwen3MoeModel(nn.Module):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
-        # Track layers for auxiliary hidden state outputs (EAGLE3)
-        self.aux_hidden_state_layers: tuple[int, ...] = ()

     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -485,18 +490,17 @@ class Qwen3MoeModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state(
+            [], self.start_layer, hidden_states, residual
+        )
         for layer_idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer),
             start=self.start_layer,
         ):
-            # Collect auxiliary hidden states if specified
-            if layer_idx in self.aux_hidden_state_layers:
-                aux_hidden_state = (
-                    hidden_states + residual if residual is not None else hidden_states
-                )
-                aux_hidden_states.append(aux_hidden_state)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, layer_idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -666,7 +670,7 @@ class Qwen3MoeModel(nn.Module):


 class Qwen3MoeForCausalLM(
-    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
+    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle, SupportsEagle3, MixtureOfExperts
 ):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -751,13 +755,6 @@ class Qwen3MoeForCausalLM(
             moe.n_redundant_experts = self.num_redundant_experts
             moe.experts.update_expert_map()

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)

@@ -101,6 +101,7 @@ from vllm.utils.math_utils import round_up

 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
     SupportsMRoPE,
@@ -1275,13 +1276,10 @@ class Qwen3LLMModel(Qwen3Model):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for layer_idx, layer in islice(
             enumerate(self.layers), self.start_layer, self.end_layer
         ):
-            if layer_idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
-
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
@@ -1295,6 +1293,9 @@ class Qwen3LLMModel(Qwen3Model):
                     hidden_states
                     + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                 )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, layer_idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -1351,6 +1352,7 @@ class Qwen3VLForConditionalGeneration(
     SupportsLoRA,
     SupportsPP,
     SupportsMRoPE,
+    SupportsEagle,
     SupportsEagle3,
     SupportsMultiModalPruning,
 ):
@@ -1449,13 +1451,6 @@ class Qwen3VLForConditionalGeneration(
             self.language_model.make_empty_intermediate_tensors
         )

-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.language_model.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        num_layers = len(self.language_model.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
-
     def _get_deepstack_input_embeds(
         self,
         num_tokens: int,
@@ -102,19 +102,17 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state(
+            [], self.start_layer, hidden_states, residual
+        )
         for layer_idx, layer in islice(
             enumerate(self.layers), self.start_layer, self.end_layer
         ):
-            if layer_idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
-
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
-
             if deepstack_input_embeds is not None and layer_idx in range(
                 0, len(deepstack_input_embeds)
             ):
@@ -123,6 +121,10 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                     + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                 )

+            self._maybe_add_hidden_state(
+                aux_hidden_states, layer_idx + 1, hidden_states, residual
+            )
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}
@@ -31,7 +31,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.interfaces import (
+    EagleModelMixin,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsPP,
+)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -274,7 +279,7 @@ class StepDecoderLayer(nn.Module):
         return loaded_params


-class StepDecoderModel(nn.Module):
+class StepDecoderModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -303,9 +308,6 @@ class StepDecoderModel(nn.Module):
         else:
             self.norm = PPMissingLayer()

-        self.aux_hidden_state_layers: tuple[int, ...] = getattr(
-            config, "aux_hidden_state_layers", ()
-        )
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"],
             config.hidden_size,
@@ -333,14 +335,12 @@ class StepDecoderModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
-            if idx in self.aux_hidden_state_layers:
-                if residual is None:
-                    aux_hidden_states.append(hidden_states)
-                else:
-                    aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(positions, hidden_states, residual)
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )

         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
@@ -353,7 +353,7 @@ class StepDecoderModel(nn.Module):
         return hidden_states


-class Step1ForCausalLM(nn.Module, SupportsPP):
+class Step1ForCausalLM(nn.Module, SupportsPP, SupportsEagle, SupportsEagle3):
     packed_modules_mapping = STEP_PACKED_MODULES_MAPPING

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -618,6 +618,6 @@ class Base(
         # Ensure that the capture hooks are installed before dynamo traces the model
         maybe_install_capturing_hooks(self.model)

-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+    def get_eagle3_default_aux_hidden_state_layers(self) -> tuple[int, ...]:
         num_layers = self.text_config.num_hidden_layers
         return (2, num_layers // 2, num_layers - 3)
@@ -27,7 +27,7 @@ def set_eagle3_aux_hidden_state_layers(
     if aux_layers:
         logger.info("Using Eagle3 auxiliary layers from config: %s", aux_layers)
     else:
-        aux_layers = eagle3_model.get_eagle3_aux_hidden_state_layers()
+        aux_layers = eagle3_model.get_eagle3_default_aux_hidden_state_layers()
         logger.info("Using Eagle3 auxiliary layers from model: %s", aux_layers)
     eagle3_model.set_aux_hidden_state_layers(aux_layers)

@@ -4556,7 +4556,9 @@ class GPUModelRunner(
                     aux_layers,
                 )
             else:
-                aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
+                aux_layers = (
+                    self.model.get_eagle3_default_aux_hidden_state_layers()
+                )

             self.model.set_aux_hidden_state_layers(aux_layers)
         time_after_load = time.perf_counter()
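Taken together with the runner change above, auxiliary-layer selection now prefers layers pinned in the speculative/EAGLE-3 config and only then falls back to the model's default. A schematic restatement (the helper name and the plain-tuple config argument are assumptions for illustration):

```python
def resolve_aux_layers(config_layers: tuple[int, ...] | None, model) -> tuple[int, ...]:
    # Config-specified layers win; otherwise use the model's architecture default.
    if config_layers:
        return tuple(config_layers)
    return model.get_eagle3_default_aux_hidden_state_layers()
```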