[Model Runner V2] Implement Single-step Eagle 1 (#29300)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -45,6 +45,7 @@ from vllm.v1.worker.gpu.input_batch import (
    prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
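The spec_decode imports are the new wiring here: init_speculator constructs the drafter, and rejection_sample verifies its drafts on the following step. A rough sketch of the factory shape, assuming a hypothetical EagleSpeculator class (neither the class name nor the constructor signature is shown in this diff):

import torch


class EagleSpeculator:
    """Hypothetical stand-in for the drafter that init_speculator builds."""

    def __init__(self, vllm_config, device: torch.device):
        self.vllm_config = vllm_config
        self.device = device


def init_speculator(vllm_config, device: torch.device):
    # Dispatch on the configured draft method; only an EAGLE-style
    # drafter is sketched here.
    method = vllm_config.speculative_config.method
    if method == "eagle":
        return EagleSpeculator(vllm_config, device)
    raise NotImplementedError(f"Unsupported speculative method: {method}")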
@@ -97,16 +98,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if self.use_async_scheduling:
            self.input_prep_event = torch.cuda.Event()
            self.structured_outputs_event = torch.cuda.Event()
            self.spec_decode_event = torch.cuda.Event()
        else:
            self.input_prep_event = None
            self.structured_outputs_event = None
            self.spec_decode_event = None

        if self.speculative_config is not None:
            self.do_spec_decode = True
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
            self.speculator = init_speculator(self.vllm_config, self.device)
        else:
            self.do_spec_decode = False
            self.num_speculative_steps = 0
            self.speculator = None

        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
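The three CUDA events exist only under async scheduling: each stage that writes into pinned host staging buffers wraps itself in an event-guarded barrier so the CPU never overwrites a buffer the GPU is still copying from. A minimal sketch of the async_barrier helper used with spec_decode_event further down in this diff; the real helper lives elsewhere in the vllm.v1.worker.gpu package, so this shape is an assumption:

from contextlib import contextmanager

import torch


@contextmanager
def async_barrier(event: torch.cuda.Event | None):
    # Wait for the GPU work recorded on the previous visit to finish,
    # so the CPU can safely refill the pinned staging buffers...
    if event is not None:
        event.synchronize()
    try:
        yield
    finally:
        # ...then record a fresh event after this visit's host-to-device
        # copies are enqueued, for the next step to wait on.
        if event is not None:
            event.record()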
@@ -153,6 +158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                self.vllm_config,
                self.device,
            )
            if self.do_spec_decode:
                self.speculator.load_model(self.model)
            time_after_load = time.perf_counter()

        self.model_memory_usage = m.consumed_memory
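The drafter is loaded only after the target model so it can attach to it; EAGLE-style drafters conventionally share the target's embedding and LM head instead of loading their own copies, which keeps the extra weight footprint small. A toy illustration of that sharing pattern (all attribute names here are illustrative, not the vLLM API):

import torch.nn as nn


class TinyDrafter(nn.Module):
    """Toy EAGLE-like draft model; real drafters add a small transformer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fuses the previous token's embedding with the target's hidden state.
        self.fc = nn.Linear(2 * hidden_size, hidden_size)
        self.embed_tokens: nn.Module | None = None  # borrowed, not owned
        self.lm_head: nn.Module | None = None       # borrowed, not owned

    def load_model(self, target_model: nn.Module) -> None:
        # Assumes the target exposes embed_tokens / lm_head attributes.
        self.embed_tokens = target_model.embed_tokens
        self.lm_head = target_model.lm_head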
@@ -285,6 +292,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        logits = self.model.compute_logits(hidden_states)
        self.sampler(logits, sampling_metadata)

    @torch.inference_mode()
    def _dummy_speculator_run(
        self,
        hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
    ) -> None:
        num_tokens = hidden_states.shape[0]
        num_reqs = min(num_tokens, self.max_num_reqs)
        input_batch = InputBatch.make_dummy(
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            input_buffers=self.input_buffers,
            device=self.device,
        )
        sampling_metadata = SamplingMetadata.make_dummy(
            num_reqs=num_reqs,
            device=self.device,
        )
        num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
        self.propose_draft(
            input_batch=input_batch,
            sampling_metadata=sampling_metadata,
            last_hidden_states=hidden_states,
            aux_hidden_states=aux_hidden_states,
            num_sampled=num_sampled,
        )
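_dummy_speculator_run drives the drafter once with worst-case-shaped fake inputs so its activations are counted during memory profiling; num_sampled is all ones because each dummy request is treated as having exactly one freshly sampled token. A hedged sketch of what a make_dummy-style constructor does (field names are assumptions, not the actual InputBatch layout):

import torch


def make_dummy_batch(num_reqs: int, num_tokens: int, device: torch.device):
    # Split the token budget evenly across requests; the last request
    # absorbs any remainder.
    starts = torch.arange(num_reqs + 1, dtype=torch.int32, device=device)
    starts = starts * (num_tokens // num_reqs)
    starts[-1] = num_tokens
    return {
        "input_ids": torch.zeros(num_tokens, dtype=torch.int64, device=device),
        "positions": torch.zeros(num_tokens, dtype=torch.int64, device=device),
        "query_start_loc": starts,  # request i owns tokens [starts[i], starts[i+1])
    }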
    @torch.inference_mode()
    def profile_run(self) -> None:
        hidden_states, sample_hidden_states = self._dummy_run(
@@ -292,6 +326,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            skip_attn=True,
        )
        self._dummy_sampler_run(sample_hidden_states)
        if self.do_spec_decode:
            self._dummy_speculator_run(hidden_states, None)
        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()
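profile_run now folds the drafter into the same measurement pass: the dummy speculator run happens before torch.cuda.synchronize(), so its peak activations are included in the profiled memory, and the explicit del plus gc.collect() frees the dummy tensors before the KV-cache budget is derived. The general measure-then-free pattern, independent of vLLM:

import gc

import torch

torch.cuda.reset_peak_memory_stats()
x = torch.randn(4096, 4096, device="cuda")
y = x @ x                      # stand-in for the dummy forward passes
torch.cuda.synchronize()       # ensure all profiled kernels have run
peak = torch.cuda.max_memory_allocated()
del x, y
gc.collect()                   # release dummy tensors before sizing caches
print(f"peak memory during profiling: {peak / 2**20:.1f} MiB")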
@@ -727,6 +763,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.req_states.prefill_len.np[idx_mapping_np],
        )

    @torch.inference_mode()
    def propose_draft(
        self,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
        last_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        num_sampled: torch.Tensor,
    ) -> torch.Tensor:
        num_reqs = input_batch.num_reqs
        idx_mapping_np = input_batch.idx_mapping_np
        with async_barrier(self.spec_decode_event):
            self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
                self.req_states.prefill_token_ids[
                    idx_mapping_np,
                    self.req_states.num_computed_prefill_tokens[idx_mapping_np],
                ]
            )
            next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
                num_reqs
            )

        assert self.speculator is not None
        draft_tokens = self.speculator.propose(
            input_batch,
            sampling_metadata,
            last_hidden_states,
            aux_hidden_states,
            num_sampled,
            self.req_states.last_sampled_tokens,
            next_prefill_tokens,
        )
        self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
        return draft_tokens
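The heart of propose_draft's input prep is a two-axis NumPy gather: for every scheduled request it reads the next not-yet-computed prompt token, which the drafter needs when a request is still mid-prefill and its "next token" comes from the prompt rather than from sampling. The same gather in isolation:

import numpy as np

# prefill_token_ids: [max_num_reqs, max_prompt_len] prompt token table.
prefill_token_ids = np.array([
    [11, 12, 13, 14],
    [21, 22, 23, 24],
    [31, 32, 33, 34],
])
num_computed = np.array([2, 0, 3])  # per-slot prefill progress
idx_mapping = np.array([0, 2])      # request slots scheduled this step

# For each scheduled request, pick the first prompt token not yet fed
# to the model.
next_prefill_tokens = prefill_token_ids[idx_mapping, num_computed[idx_mapping]]
print(next_prefill_tokens)  # [13 34]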
    def get_cudagraph_and_dp_padding(
        self,
        scheduler_output: SchedulerOutput,
@@ -913,6 +984,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
        )
        if self.do_spec_decode:
            _ = self.propose_draft(
                input_batch,
                sampling_metadata,
                hidden_states,
                None,  # aux_hidden_states
                num_sampled_tokens,
            )

        if self.use_async_scheduling:
            return async_output
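With single-step Eagle the drafter runs immediately after sampling in execute_model, and the drafts stored in req_states.draft_tokens are verified by rejection_sample on the next step. For intuition, a hedged sketch of the greedy special case of that verification (the real rejection_sample also handles stochastic acceptance and recovery tokens):

import torch


def greedy_verify(draft_tokens: torch.Tensor, target_tokens: torch.Tensor):
    # draft_tokens: [k] ids proposed on the previous step.
    # target_tokens: [k + 1] ids the target model produced at the same
    # positions, plus one bonus token after the last draft position.
    k = draft_tokens.shape[0]
    matches = (draft_tokens == target_tokens[:k]).int()
    # Accept the longest matching prefix of the draft...
    num_accepted = int(torch.cumprod(matches, dim=0).sum())
    # ...then append the target's token at the first disagreement
    # (or the bonus token if every draft matched).
    return target_tokens[: num_accepted + 1]


print(greedy_verify(torch.tensor([5, 7, 9]), torch.tensor([5, 7, 2, 4])))
# tensor([5, 7, 2]): two drafts accepted, target's correction appended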