[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
|
||||
|
||||
class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
target_layer_num = vllm_config.model_config.get_num_layers(
|
||||
vllm_config.parallel_config)
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix="model",
|
||||
start_layer_id=start_layer_id)
|
||||
start_layer_id=target_layer_num)
|
||||
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||
|
||||
@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
|
||||
|
||||
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
target_layer_num = vllm_config.model_config.get_num_layers(
|
||||
vllm_config.parallel_config)
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
start_layer_id=start_layer_id,
|
||||
prefix="model")
|
||||
prefix="model",
|
||||
start_layer_id=target_layer_num)
|
||||
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.lm_head = ParallelLMHead(
|
||||
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
|
||||
scale=logit_scale)
|
||||
self.draft_id_to_target_id = nn.Parameter(
|
||||
torch.zeros((self.config.draft_vocab_size),
|
||||
dtype=torch.long).type(torch.LongTensor),
|
||||
torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -51,10 +51,7 @@ class Medusa(nn.Module):
|
||||
needs to have truncated_vocab_size (=k) as an attribute."""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
if hasattr(vllm_config, 'draft_model_config'):
|
||||
config = vllm_config.draft_model_config.hf_config
|
||||
else:
|
||||
config = vllm_config.model_config.hf_config
|
||||
config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.blocks = nn.ModuleList([
|
||||
|
||||
Reference in New Issue
Block a user