[Model] vLLM v1 supports Medusa (#17956)

Signed-off-by: lisiqi23 <lisiqi23@xiaomi.com>
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23 <lisiqi23@xiaomi.com>
This commit is contained in:
Sky Lee
2025-05-16 12:05:31 +08:00
committed by GitHub
parent ee659e3b60
commit f4937a51c1
4 changed files with 108 additions and 2 deletions

View File

@@ -47,6 +47,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
@@ -156,6 +157,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.device) # type: ignore
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "medusa":
self.drafter = MedusaProposer(
vllm_config=self.vllm_config,
device=self.device) # type: ignore
else:
raise ValueError("Unknown speculative decoding method: "
f"{self.speculative_config.method}")
@@ -1254,6 +1259,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
if max_gen_len == 1:
hidden_states = sample_hidden_states
else:
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
valid_sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices,
device=sample_hidden_states.device)
hidden_states = sample_hidden_states[indices]
spec_token_ids = self.drafter.propose(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.