[Model] vLLM v1 supports Medusa (#17956)
Signed-off-by: lisiqi23 <lisiqi23@xiaomi.com>
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23 <lisiqi23@xiaomi.com>
This commit is contained in:
@@ -1324,19 +1324,22 @@ class EngineArgs:
|
|||||||
# Only Ngram speculative decoding so far.
|
# Only Ngram speculative decoding so far.
|
||||||
is_ngram_enabled = False
|
is_ngram_enabled = False
|
||||||
is_eagle_enabled = False
|
is_eagle_enabled = False
|
||||||
|
is_medusa_enabled = False
|
||||||
if self.speculative_config is not None:
|
if self.speculative_config is not None:
|
||||||
# This is supported but experimental (handled below).
|
# This is supported but experimental (handled below).
|
||||||
speculative_method = self.speculative_config.get("method")
|
speculative_method = self.speculative_config.get("method")
|
||||||
if speculative_method:
|
if speculative_method:
|
||||||
if speculative_method in ("ngram", "[ngram]"):
|
if speculative_method in ("ngram", "[ngram]"):
|
||||||
is_ngram_enabled = True
|
is_ngram_enabled = True
|
||||||
|
elif speculative_method == "medusa":
|
||||||
|
is_medusa_enabled = True
|
||||||
elif speculative_method in ("eagle", "eagle3"):
|
elif speculative_method in ("eagle", "eagle3"):
|
||||||
is_eagle_enabled = True
|
is_eagle_enabled = True
|
||||||
else:
|
else:
|
||||||
speculative_model = self.speculative_config.get("model")
|
speculative_model = self.speculative_config.get("model")
|
||||||
if speculative_model in ("ngram", "[ngram]"):
|
if speculative_model in ("ngram", "[ngram]"):
|
||||||
is_ngram_enabled = True
|
is_ngram_enabled = True
|
||||||
if not (is_ngram_enabled or is_eagle_enabled):
|
if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled):
|
||||||
# Other speculative decoding methods are not supported yet.
|
# Other speculative decoding methods are not supported yet.
|
||||||
_raise_or_fallback(feature_name="Speculative Decoding",
|
_raise_or_fallback(feature_name="Speculative Decoding",
|
||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
|
|||||||
@@ -51,7 +51,10 @@ class Medusa(nn.Module):
|
|||||||
needs to have truncated_vocab_size (=k) as an attribute."""
|
needs to have truncated_vocab_size (=k) as an attribute."""
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||||
config = vllm_config.model_config.hf_config
|
if hasattr(vllm_config, 'draft_model_config'):
|
||||||
|
config = vllm_config.draft_model_config.hf_config
|
||||||
|
else:
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
|
|||||||
74
vllm/v1/spec_decode/medusa.py
Normal file
74
vllm/v1/spec_decode/medusa.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
|
from vllm.model_executor.models.medusa import Medusa
|
||||||
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
# Initialize logger
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MedusaProposer:
    """Drafts speculative tokens using a Medusa draft model.

    Runs the Medusa heads on the target model's sampled hidden states and
    greedily picks one draft token per head for each request.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        # Keep the full config: load_model() reads load_config and
        # speculative_config from it, dummy_run() needs it for the
        # forward context.
        self.vllm_config = vllm_config
        self.device = device
        # Upper bound on batched tokens per step; used to size the
        # dummy-run input buffer.
        self.max_num_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens)
        # Hidden size expected by the draft (Medusa) model's heads.
        self.hidden_size = vllm_config.speculative_config.\
            draft_model_config.get_hidden_size(
        )
        self.dtype = vllm_config.model_config.dtype

    def propose(
        self,
        target_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        """Generate draft token ids from the target model's hidden states.

        Args:
            target_hidden_states: Hidden states of the last sampled token
                per request (assumes shape [num_reqs, hidden_size] — TODO
                confirm against caller).
            sampling_metadata: Currently unused; drafting is greedy
                (argmax) regardless of the requests' sampling params.

        Returns:
            One row per request, each row holding one greedily-chosen
            token id per Medusa head.

        Note: the original code annotated the return as ``torch.Tensor``,
        but it returns a plain nested list — the annotation is fixed here;
        the runtime behavior is unchanged.
        """
        # The model returns one hidden-state block per Medusa head;
        # compute_logits maps each block to [num_reqs, vocab] logits.
        blocks = self.model(target_hidden_states)
        logits = self.model.compute_logits(blocks, None)

        # Greedy-pick per head, then transpose from [head][request] to
        # [request][head] so each request gets its own draft sequence.
        draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
        return [list(row) for row in zip(*draft_tokens)]

    def load_model(self, target_model: nn.Module) -> None:
        """Instantiate the Medusa draft model and load its weights.

        Args:
            target_model: The target model (unused here; kept to match the
                proposer interface used by the model runner).
        """
        # Get model loader and config
        loader = get_model_loader(self.vllm_config.load_config)
        draft_config = self.vllm_config.speculative_config.draft_model_config

        # Load model with proper dtype and config.
        # NOTE: speculative_config is deliberately passed as vllm_config —
        # Medusa.__init__ detects this via hasattr(..., 'draft_model_config')
        # and reads the draft model's hf_config from it.
        with set_default_torch_dtype(draft_config.dtype), \
                set_current_vllm_config(self.vllm_config):
            self.model = Medusa(
                vllm_config=self.vllm_config.speculative_config).to(
                    self.device)

        # Load model weights
        weights = loader.get_all_weights(draft_config, self.model)
        self.model.load_weights(weights)

    @torch.inference_mode()
    def dummy_run(self, num_tokens: int) -> None:
        """Run a throwaway forward pass (e.g. for warmup/capture).

        Args:
            num_tokens: Token count to record in the forward context; the
                input buffer itself is always sized to max_num_tokens.
        """
        hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size),
                                    dtype=self.dtype,
                                    device=self.device)
        # No attention metadata is needed for Medusa's MLP-style heads.
        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            self.model(hidden_states)
|
||||||
@@ -47,6 +47,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
|||||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||||
from vllm.v1.sample.sampler import Sampler
|
from vllm.v1.sample.sampler import Sampler
|
||||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||||
|
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||||
@@ -156,6 +157,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.device) # type: ignore
|
self.device) # type: ignore
|
||||||
if self.speculative_config.method == "eagle3":
|
if self.speculative_config.method == "eagle3":
|
||||||
self.use_aux_hidden_state_outputs = True
|
self.use_aux_hidden_state_outputs = True
|
||||||
|
elif self.speculative_config.method == "medusa":
|
||||||
|
self.drafter = MedusaProposer(
|
||||||
|
vllm_config=self.vllm_config,
|
||||||
|
device=self.device) # type: ignore
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown speculative decoding method: "
|
raise ValueError("Unknown speculative decoding method: "
|
||||||
f"{self.speculative_config.method}")
|
f"{self.speculative_config.method}")
|
||||||
@@ -1254,6 +1259,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
assert isinstance(self.drafter, NgramProposer)
|
assert isinstance(self.drafter, NgramProposer)
|
||||||
spec_token_ids = self.generate_draft_token_ids(
|
spec_token_ids = self.generate_draft_token_ids(
|
||||||
valid_sampled_token_ids, sampling_metadata)
|
valid_sampled_token_ids, sampling_metadata)
|
||||||
|
elif self.speculative_config.method == "medusa":
|
||||||
|
assert isinstance(self.drafter, MedusaProposer)
|
||||||
|
if max_gen_len == 1:
|
||||||
|
hidden_states = sample_hidden_states
|
||||||
|
else:
|
||||||
|
indices = []
|
||||||
|
offset = 0
|
||||||
|
for num_draft, tokens in zip(
|
||||||
|
spec_decode_metadata.num_draft_tokens,
|
||||||
|
valid_sampled_token_ids):
|
||||||
|
indices.append(offset + len(tokens) - 1)
|
||||||
|
offset += num_draft + 1
|
||||||
|
|
||||||
|
indices = torch.tensor(indices,
|
||||||
|
device=sample_hidden_states.device)
|
||||||
|
hidden_states = sample_hidden_states[indices]
|
||||||
|
|
||||||
|
spec_token_ids = self.drafter.propose(
|
||||||
|
target_hidden_states=hidden_states,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
)
|
||||||
elif self.speculative_config.use_eagle():
|
elif self.speculative_config.use_eagle():
|
||||||
assert isinstance(self.drafter, EagleProposer)
|
assert isinstance(self.drafter, EagleProposer)
|
||||||
# TODO(woosuk): Refactor the loop.
|
# TODO(woosuk): Refactor the loop.
|
||||||
|
|||||||
Reference in New Issue
Block a user