[Misc] Tidy up some spec decode logic in GPUModelRunner (#31591)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -139,6 +139,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
self.req_states.max_model_len = max_model_len
|
||||
|
||||
def get_supported_tasks(self) -> tuple[str]:
|
||||
return ("generate",)
|
||||
|
||||
|
||||
@@ -452,6 +452,11 @@ class GPUModelRunner(
|
||||
self.num_spec_tokens = 0
|
||||
if self.speculative_config:
|
||||
self.num_spec_tokens = self.speculative_config.num_speculative_tokens
|
||||
draft_config = self.speculative_config.draft_model_config
|
||||
if draft_config is not None and draft_config.max_model_len is not None:
|
||||
self.effective_drafter_max_model_len = draft_config.max_model_len
|
||||
else:
|
||||
self.effective_drafter_max_model_len = self.max_model_len
|
||||
|
||||
# Request states.
|
||||
self.requests: dict[str, CachedRequestState] = {}
|
||||
@@ -674,6 +679,13 @@ class GPUModelRunner(
|
||||
self.kv_connector_output: KVConnectorOutput | None = None
|
||||
self.layerwise_nvtx_hooks_registered = False
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
if self.speculative_config:
|
||||
draft_config = self.speculative_config.draft_model_config
|
||||
if draft_config is None or draft_config.max_model_len is None:
|
||||
self.effective_drafter_max_model_len = self.max_model_len
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
if self.mm_budget:
|
||||
self.mm_budget.reset_cache()
|
||||
@@ -3399,54 +3411,41 @@ class GPUModelRunner(
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output)
|
||||
|
||||
spec_config = self.speculative_config
|
||||
use_padded_batch_for_eagle = (
|
||||
spec_config is not None
|
||||
and spec_config.use_eagle()
|
||||
and not spec_config.disable_padded_drafter_batch
|
||||
)
|
||||
effective_drafter_max_model_len = self.max_model_len
|
||||
if effective_drafter_max_model_len is None:
|
||||
effective_drafter_max_model_len = self.model_config.max_model_len
|
||||
if (
|
||||
spec_config is not None
|
||||
and spec_config.draft_model_config is not None
|
||||
and spec_config.draft_model_config.max_model_len is not None
|
||||
):
|
||||
effective_drafter_max_model_len = (
|
||||
spec_config.draft_model_config.max_model_len
|
||||
propose_drafts_after_bookkeeping = False
|
||||
if spec_config is not None:
|
||||
input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
|
||||
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
|
||||
<= self.effective_drafter_max_model_len
|
||||
)
|
||||
input_fits_in_drafter = spec_decode_common_attn_metadata and (
|
||||
spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
|
||||
<= effective_drafter_max_model_len
|
||||
)
|
||||
if use_padded_batch_for_eagle:
|
||||
assert self.speculative_config is not None
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
if input_fits_in_drafter:
|
||||
if spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch:
|
||||
# EAGLE speculative decoding can use the GPU sampled tokens
|
||||
# as inputs, and does not need to wait for bookkeeping to finish.
|
||||
propose_draft_token_ids(sampled_token_ids)
|
||||
elif self.valid_sampled_token_count_event is not None:
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
spec_decode_common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
self.discard_request_mask.gpu,
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
if input_fits_in_drafter:
|
||||
propose_draft_token_ids(sampled_token_ids)
|
||||
elif self.valid_sampled_token_count_event is not None:
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
next_token_ids, valid_sampled_tokens_count = (
|
||||
self.drafter.prepare_next_token_ids_padded(
|
||||
spec_decode_common_attn_metadata,
|
||||
sampled_token_ids,
|
||||
self.requests,
|
||||
self.input_batch,
|
||||
self.discard_request_mask.gpu,
|
||||
)
|
||||
)
|
||||
)
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
# Since we couldn't run the drafter,
|
||||
# just use zeros for the draft tokens.
|
||||
self._draft_token_ids = torch.zeros(
|
||||
1, device=self.device, dtype=torch.int32
|
||||
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
# Since we couldn't run the drafter,
|
||||
# just use zeros for the draft tokens.
|
||||
self._draft_token_ids = torch.zeros(
|
||||
1, device=self.device, dtype=torch.int32
|
||||
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
|
||||
else:
|
||||
propose_drafts_after_bookkeeping = input_fits_in_drafter
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
|
||||
(
|
||||
@@ -3466,17 +3465,14 @@ class GPUModelRunner(
|
||||
spec_decode_metadata,
|
||||
)
|
||||
|
||||
if (
|
||||
self.speculative_config
|
||||
and not use_padded_batch_for_eagle
|
||||
and input_fits_in_drafter
|
||||
):
|
||||
if propose_drafts_after_bookkeeping:
|
||||
# ngram and other speculative decoding methods use the sampled
|
||||
# tokens on the CPU, so they are run after bookkeeping.
|
||||
propose_draft_token_ids(valid_sampled_token_ids)
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: eplb"):
|
||||
self.eplb_step()
|
||||
|
||||
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
|
||||
output = ModelRunnerOutput(
|
||||
req_ids=req_ids_output_copy,
|
||||
@@ -3494,6 +3490,7 @@ class GPUModelRunner(
|
||||
|
||||
if not self.use_async_scheduling:
|
||||
return output
|
||||
|
||||
with record_function_or_nullcontext(
|
||||
"gpu_model_runner: AsyncGPUModelRunnerOutput"
|
||||
):
|
||||
|
||||
@@ -390,7 +390,7 @@ class Worker(WorkerBase):
|
||||
"""
|
||||
self.model_config.max_model_len = max_model_len
|
||||
if self.model_runner is not None:
|
||||
self.model_runner.max_model_len = max_model_len
|
||||
self.model_runner.update_max_model_len(max_model_len)
|
||||
logger.debug("Updated max_model_len to %d", max_model_len)
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
|
||||
Reference in New Issue
Block a user