[Model Runner V2] Implement multi-step Eagle with CUDA graph (#29559)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-11-27 00:09:41 -08:00
committed by GitHub
parent 43c5792592
commit da3222f371
4 changed files with 526 additions and 70 deletions

View File

@@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
vllm_config=self.vllm_config,
device=self.device,
)
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
def get_supported_tasks(self) -> tuple[str]:
return ("generate",)
@@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
self.device,
)
if self.do_spec_decode:
# HACK(woosuk)
self.speculator.set_attn(
self.kv_cache_config,
self.attn_metadata_builders,
self.block_tables,
)
# TODO(woosuk): Support other backends.
if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
@@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata)
@torch.inference_mode()
def _dummy_speculator_run(
self,
hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
) -> None:
num_tokens = hidden_states.shape[0]
num_reqs = min(num_tokens, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
self.propose_draft(
input_batch=input_batch,
sampling_metadata=sampling_metadata,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=num_sampled,
num_rejected=num_rejected,
)
@torch.inference_mode()
def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run(
@@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode:
self._dummy_speculator_run(hidden_states, None)
num_tokens_across_dp = make_num_tokens_across_dp(
self.dp_size, self.max_num_tokens
)
self.speculator.run_model(
self.max_num_tokens,
attn_metadata=None,
num_tokens_across_dp=num_tokens_across_dp,
)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
@@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
)
if self.do_spec_decode:
self.speculator.capture_model()
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]