[Model Runner V2] Implement multi-step Eagle with CUDA graph (#29559)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
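For context on the technique named in the title: Eagle-style speculative decoding runs a lightweight draft head on top of the target model's last hidden states, and "multi-step" means that draft head is unrolled for several autoregressive steps per verification pass. Below is a minimal illustrative sketch of such a loop with toy, hypothetical names (ToyDraftHead, propose_draft_tokens); it is not code from this commit.

import torch
import torch.nn as nn


class ToyDraftHead(nn.Module):
    """Toy stand-in for an Eagle draft head: consumes the target model's
    hidden state plus the embedding of the previously accepted token."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.fc = nn.Linear(2 * hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden: torch.Tensor, token_ids: torch.Tensor):
        x = torch.cat([hidden, self.embed(token_ids)], dim=-1)
        hidden = self.fc(x)
        return hidden, self.lm_head(hidden)


@torch.inference_mode()
def propose_draft_tokens(
    draft_head: ToyDraftHead,
    last_hidden: torch.Tensor,   # [num_reqs, hidden_size] from the target model
    last_tokens: torch.Tensor,   # [num_reqs] last accepted token ids
    num_steps: int,
) -> torch.Tensor:
    """Run the drafter for num_steps greedy steps; feeding each step's output
    back in is what makes the Eagle proposal "multi-step"."""
    draft_tokens = []
    hidden, tokens = last_hidden, last_tokens
    for _ in range(num_steps):
        hidden, logits = draft_head(hidden, tokens)
        tokens = logits.argmax(dim=-1)
        draft_tokens.append(tokens)
    return torch.stack(draft_tokens, dim=1)  # [num_reqs, num_steps]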
@@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

         # CUDA graphs.
-        self.cudagraph_manager = CudaGraphManager(
-            vllm_config=self.vllm_config,
-            device=self.device,
-        )
+        self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)

     def get_supported_tasks(self) -> tuple[str]:
         return ("generate",)
@@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.vllm_config,
             self.device,
         )
+        if self.do_spec_decode:
+            # HACK(woosuk)
+            self.speculator.set_attn(
+                self.kv_cache_config,
+                self.attn_metadata_builders,
+                self.block_tables,
+            )
+
         # TODO(woosuk): Support other backends.
         if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
             raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
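The set_attn() hookup above hands the speculator references to the runner's KV-cache config, attention-metadata builders, and block tables, so draft steps can build paged-attention metadata against the same KV cache as the target model. A minimal sketch of that wiring pattern follows; only the attribute names mirror the diff, the class and body are hypothetical.

class ToySpeculator:
    def set_attn(self, kv_cache_config, attn_metadata_builders, block_tables):
        # Keep references rather than copies: the drafter must see the
        # runner's live block tables as new KV blocks are allocated.
        self.kv_cache_config = kv_cache_config
        self.attn_metadata_builders = attn_metadata_builders
        self.block_tables = block_tables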
@@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         logits = self.model.compute_logits(hidden_states)
         self.sampler(logits, sampling_metadata)

-    @torch.inference_mode()
-    def _dummy_speculator_run(
-        self,
-        hidden_states: torch.Tensor,
-        aux_hidden_states: list[torch.Tensor] | None,
-    ) -> None:
-        num_tokens = hidden_states.shape[0]
-        num_reqs = min(num_tokens, self.max_num_reqs)
-        input_batch = InputBatch.make_dummy(
-            num_reqs=num_reqs,
-            num_tokens=num_tokens,
-            input_buffers=self.input_buffers,
-            device=self.device,
-        )
-        sampling_metadata = SamplingMetadata.make_dummy(
-            num_reqs=num_reqs,
-            device=self.device,
-        )
-        num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
-        num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
-        self.propose_draft(
-            input_batch=input_batch,
-            sampling_metadata=sampling_metadata,
-            last_hidden_states=hidden_states,
-            aux_hidden_states=aux_hidden_states,
-            num_sampled=num_sampled,
-            num_rejected=num_rejected,
-        )
-
     @torch.inference_mode()
     def profile_run(self) -> None:
         hidden_states, sample_hidden_states = self._dummy_run(
@@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         self._dummy_sampler_run(sample_hidden_states)
         if self.do_spec_decode:
-            self._dummy_speculator_run(hidden_states, None)
+            num_tokens_across_dp = make_num_tokens_across_dp(
+                self.dp_size, self.max_num_tokens
+            )
+            self.speculator.run_model(
+                self.max_num_tokens,
+                attn_metadata=None,
+                num_tokens_across_dp=num_tokens_across_dp,
+            )
         torch.cuda.synchronize()
         del hidden_states, sample_hidden_states
         gc.collect()
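In profile_run, the dedicated _dummy_speculator_run helper is replaced by running the speculator's model directly at max_num_tokens, so the drafter's activations are accounted for during memory profiling as well. A hedged sketch of that general profiling pattern, with hypothetical helper names (dummy_target_run, dummy_draft_run) standing in for the runner's dummy-run methods:

import gc
import torch


def profile_peak_memory(dummy_target_run, dummy_draft_run, max_num_tokens: int) -> int:
    """Return peak allocated bytes after one worst-case target + draft pass."""
    torch.cuda.reset_peak_memory_stats()
    hidden_states = dummy_target_run(max_num_tokens)   # target model at max token count
    dummy_draft_run(max_num_tokens)                    # drafter activations count too
    torch.cuda.synchronize()
    del hidden_states
    gc.collect()
    return torch.cuda.max_memory_allocated()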
@@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata_builders=self.attn_metadata_builders,
             kv_cache_config=self.kv_cache_config,
         )
+        if self.do_spec_decode:
+            self.speculator.capture_model()

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
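Finally, the speculator captures its own CUDA graphs during the runner's graph-capture phase. The sketch below shows the generic torch.cuda.CUDAGraph capture-and-replay pattern that such a capture_model() typically builds on; it is an assumption for illustration, not this commit's CudaGraphManager implementation.

import torch


def capture_draft_graph(model: torch.nn.Module, static_input: torch.Tensor):
    """Capture one forward pass of `model` into a CUDA graph using static buffers."""
    # Warm up on a side stream so lazy kernel/allocator work happens before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)
    # To run later: copy new data into static_input, call graph.replay(),
    # and read the results from static_output.
    return graph, static_output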