[Bugfix] Fix EAGLE vocab embedding construction for Llama 70B (#19033)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
commit 3465b87ef8
parent c8134bea15
committed by GitHub
@@ -55,13 +55,11 @@ class LlamaModel(nn.Module):
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
-        # if PP disabled then draft will share embed with target
-        if get_pp_group().world_size > 1:
-            self.embed_tokens = VocabParallelEmbedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                prefix=maybe_prefix(prefix, "embed_tokens"),
-            )
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
 
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
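With this hunk, the EAGLE draft LlamaModel always constructs its own VocabParallelEmbedding rather than doing so only when pipeline parallelism is enabled; the removed path left the draft sharing the target's embedding when PP was disabled, which is the construction this bugfix replaces for Llama 70B. Below is a minimal sketch of the resulting constructor shape, not the vLLM implementation: it assumes a plain torch.nn.Embedding as a stand-in for VocabParallelEmbedding and a hypothetical DraftConfig with toy sizes.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class DraftConfig:
    # Toy sizes for illustration only; real Llama 70B values are much larger.
    vocab_size: int = 32000
    hidden_size: int = 64


class DraftModelSketch(nn.Module):
    """Sketch of the post-fix constructor: the draft always owns its embedding."""

    def __init__(self, config: DraftConfig) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        # Unconditionally construct the draft's vocab embedding instead of
        # relying on sharing the target model's embedding when PP is disabled.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)


if __name__ == "__main__":
    model = DraftModelSketch(DraftConfig())
    print(model(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 64])
```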
@@ -164,4 +162,4 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
             if "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
-        return loader.load_weights(model_weights.items())
+        loader.load_weights(model_weights.items())
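The second hunk touches EagleLlamaForCausalLM.load_weights, which prefixes every draft-checkpoint weight name except lm_head with "model." before handing the weights to the loader; the changed line also stops returning the loader's result. A minimal sketch of that remapping step alone, using a hypothetical helper name and fake tensors rather than vLLM's AutoWeightsLoader:

```python
from typing import Iterable

import torch


def remap_draft_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """Prefix every draft weight except lm_head with 'model.' (as in the hunk above)."""
    model_weights: dict[str, torch.Tensor] = {}
    for name, loaded_weight in weights:
        if "lm_head" not in name:
            name = "model." + name
        model_weights[name] = loaded_weight
    return model_weights


if __name__ == "__main__":
    fake_checkpoint = [
        ("embed_tokens.weight", torch.zeros(4, 2)),
        ("lm_head.weight", torch.zeros(4, 2)),
    ]
    print(list(remap_draft_weights(fake_checkpoint)))
    # ['model.embed_tokens.weight', 'lm_head.weight']
```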