[Model] vLLM v1 supports Medusa (#17956)
Signed-off-by: lisiqi23 <lisiqi23@xiaomi.com>
Signed-off-by: skylee-01 <497627264@qq.com>
Co-authored-by: lisiqi23 <lisiqi23@xiaomi.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||
@@ -156,6 +157,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.device) # type: ignore
|
||||
if self.speculative_config.method == "eagle3":
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
elif self.speculative_config.method == "medusa":
|
||||
self.drafter = MedusaProposer(
|
||||
vllm_config=self.vllm_config,
|
||||
device=self.device) # type: ignore
|
||||
else:
|
||||
raise ValueError("Unknown speculative decoding method: "
|
||||
f"{self.speculative_config.method}")
|
||||
@@ -1254,6 +1259,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert isinstance(self.drafter, NgramProposer)
|
||||
spec_token_ids = self.generate_draft_token_ids(
|
||||
valid_sampled_token_ids, sampling_metadata)
|
||||
elif self.speculative_config.method == "medusa":
|
||||
assert isinstance(self.drafter, MedusaProposer)
|
||||
if max_gen_len == 1:
|
||||
hidden_states = sample_hidden_states
|
||||
else:
|
||||
indices = []
|
||||
offset = 0
|
||||
for num_draft, tokens in zip(
|
||||
spec_decode_metadata.num_draft_tokens,
|
||||
valid_sampled_token_ids):
|
||||
indices.append(offset + len(tokens) - 1)
|
||||
offset += num_draft + 1
|
||||
|
||||
indices = torch.tensor(indices,
|
||||
device=sample_hidden_states.device)
|
||||
hidden_states = sample_hidden_states[indices]
|
||||
|
||||
spec_token_ids = self.drafter.propose(
|
||||
target_hidden_states=hidden_states,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
elif self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# TODO(woosuk): Refactor the loop.
|
||||
|
||||
Reference in New Issue
Block a user