[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin
2025-05-23 03:05:44 +01:00
committed by GitHub
parent 04eb88dc80
commit c6b636f9fb
16 changed files with 59 additions and 135 deletions

View File

@@ -130,13 +130,15 @@ class LlamaModel(nn.Module):
class EagleLlamaForCausalLM(LlamaForCausalLM):
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
-                                start_layer_id=start_layer_id)
+                                start_layer_id=target_layer_num)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,

View File

@@ -175,13 +175,15 @@ class LlamaModel(nn.Module):
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
-    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
self.model = LlamaModel(vllm_config=vllm_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+                                prefix="model",
+                                start_layer_id=target_layer_num)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.lm_head = ParallelLMHead(
@@ -193,8 +195,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
scale=logit_scale)
self.draft_id_to_target_id = nn.Parameter(
-            torch.zeros((self.config.draft_vocab_size),
-                        dtype=torch.long).type(torch.LongTensor),
+            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
requires_grad=False,
)

View File

@@ -51,10 +51,7 @@ class Medusa(nn.Module):
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
-        if hasattr(vllm_config, 'draft_model_config'):
-            config = vllm_config.draft_model_config.hf_config
-        else:
-            config = vllm_config.model_config.hf_config
+        config = vllm_config.speculative_config.draft_model_config.hf_config
super().__init__()
self.config = config
self.blocks = nn.ModuleList([