[Bugfix] Fix EAGLE vocab embedding construction for Llama 70B (#19033)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
commit 3465b87ef8
parent c8134bea15
committed by GitHub
@@ -10,7 +10,6 @@ from transformers import LlamaConfig

 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -95,13 +94,11 @@ class LlamaModel(nn.Module):
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size

-        # if PP disabled then draft will share embed with target
-        if get_pp_group().world_size > 1:
-            self.embed_tokens = VocabParallelEmbedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                prefix=maybe_prefix(prefix, "embed_tokens"),
-            )
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )

         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
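Before this change the draft model only constructed its own embedding when pipeline parallelism was enabled, otherwise relying on sharing the target model's embedding (per the deleted comment); afterwards the draft always owns a VocabParallelEmbedding. A minimal, self-contained toy sketch of that structural difference, using plain torch.nn stand-ins rather than vLLM classes:

import torch.nn as nn

class ToyDraft(nn.Module):
    # Toy stand-in for the EAGLE draft model (illustration only, not vLLM code).
    def __init__(self, pp_world_size: int, always_build_embedding: bool):
        super().__init__()
        if always_build_embedding or pp_world_size > 1:
            # New behaviour: the draft always owns an embedding table.
            self.embed_tokens = nn.Embedding(num_embeddings=32, embedding_dim=8)

# Old behaviour with a single pipeline stage: no embed_tokens attribute,
# so the draft had to reuse the target model's embedding.
old = ToyDraft(pp_world_size=1, always_build_embedding=False)
new = ToyDraft(pp_world_size=1, always_build_embedding=True)
print(hasattr(old, "embed_tokens"))  # False
print(hasattr(new, "embed_tokens"))  # True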
@@ -240,6 +237,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         model_weights = {}
         includes_draft_id_mapping = False
+        includes_embed_tokens = False
         for name, loaded_weight in weights:
             if "t2d" in name:
                 continue
@@ -248,12 +246,18 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
                 includes_draft_id_mapping = True
             elif "lm_head" not in name:
                 name = "model." + name
+            if "embed_tokens" in name:
+                includes_embed_tokens = True
             model_weights[name] = loaded_weight

+        skip_substrs = []
+        if not includes_draft_id_mapping:
+            skip_substrs.append("draft_id_to_target_id")
+        if not includes_embed_tokens:
+            skip_substrs.append("embed_tokens")
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
-            skip_substrs=["draft_id_to_target_id"] \
-                if not includes_draft_id_mapping else None,
+            skip_substrs=skip_substrs,
         )
         loader.load_weights(model_weights.items())
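Taken together, the two load_weights hunks make the loader skip only those weight names the checkpoint genuinely does not provide: if the EAGLE3 checkpoint ships its own embed_tokens (presumably the Llama 70B case this commit targets), it is loaded into the now always-constructed draft embedding; otherwise "embed_tokens" is added to the skip list. A small, runnable sketch of just that skip-list bookkeeping in plain Python (build_skip_substrs is a hypothetical helper name, not part of vLLM):

def build_skip_substrs(checkpoint_weight_names: list[str]) -> list[str]:
    # Mirror of the bookkeeping in load_weights above: a substring is only
    # skipped when no checkpoint weight name contains it.
    includes_draft_id_mapping = any(
        "draft_id_to_target_id" in name for name in checkpoint_weight_names)
    includes_embed_tokens = any(
        "embed_tokens" in name for name in checkpoint_weight_names)
    skip_substrs = []
    if not includes_draft_id_mapping:
        skip_substrs.append("draft_id_to_target_id")
    if not includes_embed_tokens:
        skip_substrs.append("embed_tokens")
    return skip_substrs

# Checkpoint that ships a draft embedding: only the id mapping is skipped.
print(build_skip_substrs(["model.embed_tokens.weight", "model.layers.0.mlp.weight"]))
# Checkpoint without one: both substrings end up in the skip list.
print(build_skip_substrs(["model.layers.0.mlp.weight"]))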