[Spec Decode] Update extract_hidden_states to use deferred kv_connector clear (#37013)
committed by GitHub
parent 43a73f853b
commit 04bf5a35fa
@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
             cached_req = self._active_requests[req_id]
             req_block_ids = self._req_blocks[req_id]
 
-            assert new_block_ids is not None
+            if new_block_ids is None:
+                continue
 
             block_ids = new_block_ids[0]
 
             req_block_ids.extend(block_ids)
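The example connector now tolerates requests that were scheduled without any newly allocated KV blocks, skipping them instead of tripping an assertion. A minimal, self-contained sketch of the pattern; the names `active_requests`, `req_blocks`, and `updates` are illustrative stand-ins, not the connector's real attributes:

active_requests = {"req-0": object(), "req-1": object()}
req_blocks: dict[str, list[int]] = {"req-0": [], "req-1": []}

# new_block_ids is grouped per KV cache group; None means "no new blocks".
updates: dict[str, tuple[list[int], ...] | None] = {
    "req-0": ([3, 4],),  # two new blocks in group 0
    "req-1": None,       # e.g. a step that allocated nothing for this request
}

for req_id, new_block_ids in updates.items():
    cached_req = active_requests[req_id]  # mirrors the lookup in the diff
    req_block_ids = req_blocks[req_id]
    if new_block_ids is None:
        # Previously `assert new_block_ids is not None`; skipping lets
        # scheduling steps that allocate no blocks pass through cleanly.
        continue
    block_ids = new_block_ids[0]
    req_block_ids.extend(block_ids)

assert req_blocks == {"req-0": [3, 4], "req-1": []}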
@@ -3,26 +3,21 @@
 
 from __future__ import annotations
 
-from contextlib import nullcontext
 from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
 
 from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
-from vllm.distributed.kv_transfer import has_kv_transfer_group
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
-from vllm.v1.outputs import KVConnectorOutput
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
 
 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.kv_cache_interface import KVCacheConfig
 
 PADDING_SLOT_ID = -1
@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer:
         sampled_token_ids: torch.Tensor,
         target_hidden_states: list[torch.Tensor],
         common_attn_metadata: CommonAttentionMetadata,
-        scheduler_output: SchedulerOutput,
         slot_mappings: dict[str, torch.Tensor]
         | list[dict[str, torch.Tensor]]
         | None = None,
-    ) -> tuple[torch.Tensor, KVConnectorOutput | None]:
+    ) -> torch.Tensor:
         """Propose draft tokens by calling the ExtractHiddenStatesModel model.
 
         The ExtractHiddenStatesModel caches the hidden states in the KV cache
@@ -99,7 +93,6 @@
             target_hidden_states: List of hidden state tensors from target model
                 (one per aux hidden state layer)
             common_attn_metadata: Attention metadata
-            scheduler_output: Scheduler output for KV connector
             slot_mappings: Slot mappings for KV cache (unused, provided for
                 interface compatibility)
 
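For reference, a minimal sketch of the narrowed call contract: the drafter no longer receives scheduler_output and returns only the draft token tensor. `propose_sketch` is a hypothetical stand-in, not the real `ExtractHiddenStatesProposer.propose`, which additionally runs the hidden-states model under set_forward_context:

import torch

def propose_sketch(
    sampled_token_ids: torch.Tensor,
    target_hidden_states: list[torch.Tensor],
    slot_mappings: dict[str, torch.Tensor]
    | list[dict[str, torch.Tensor]]
    | None = None,
) -> torch.Tensor:
    # Only the return contract is modeled here: the sampled tokens pass
    # straight through as the single-step "draft" tokens.
    return sampled_token_ids

tokens = torch.tensor([[1], [2]])
drafts = propose_sketch(tokens, target_hidden_states=[torch.zeros(2, 8)])
assert drafts is tokens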
@@ -136,22 +129,15 @@ class ExtractHiddenStatesProposer:
         if num_tokens_across_dp is not None:
             num_tokens_across_dp[self.dp_rank] = num_input_tokens
 
-        with (
-            set_forward_context(
-                per_layer_attn_metadata,
-                self.vllm_config,
-                num_tokens=num_input_tokens,
-                num_tokens_across_dp=num_tokens_across_dp,
-                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                slot_mapping=self._get_slot_mapping(
-                    num_input_tokens, common_attn_metadata.slot_mapping
-                ),
-            ),
-            (
-                KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
-                if has_kv_transfer_group()
-                else nullcontext()
-            ) as kv_connector_output,
+        with set_forward_context(
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
+            slot_mapping=self._get_slot_mapping(
+                num_input_tokens, common_attn_metadata.slot_mapping
+            ),
         ):
             self.model(
                 hidden_states=self.hidden_states[:num_input_tokens],
@@ -159,7 +145,7 @@ class ExtractHiddenStatesProposer:
 
         # Return the sampled tokens as "draft" tokens
        # Shape: [batch_size, 1] to match num_speculative_tokens=1
-        return sampled_token_ids.unsqueeze(-1), kv_connector_output
+        return sampled_token_ids
 
     def _get_slot_mapping(
         self,
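The deleted branch used the conditional-context-manager idiom: enter the mixin's KV-connector capture when a transfer group exists, otherwise a no-op nullcontext(). A small sketch of that idiom under assumed names; `has_connector` and `capture_connector_output` stand in for vLLM's has_kv_transfer_group() and KVConnectorModelRunnerMixin._get_kv_connector_output:

from contextlib import contextmanager, nullcontext

has_connector = False  # assumed flag; vLLM queries the transfer group

@contextmanager
def capture_connector_output():
    output = {"finished": set()}  # placeholder for a KVConnectorOutput
    yield output                  # body (the forward pass) runs here
    output["finished"].add("req-0")  # finalized on exit

with (
    capture_connector_output() if has_connector else nullcontext()
) as kv_connector_output:
    pass  # forward pass would run here

# With nullcontext(), the `as` target is None, so every caller had to
# branch on it; deferring connector handling to the model runner removes
# this plumbing from the drafter entirely.
assert kv_connector_output is None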
@@ -4328,23 +4328,12 @@ class GPUModelRunner(
             )
             target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
 
-            draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
+            draft_token_ids = self.drafter.propose(
                 sampled_token_ids=sampled_token_ids,
                 target_hidden_states=target_hidden_states,
                 common_attn_metadata=common_attn_metadata,
-                scheduler_output=scheduler_output,
                 slot_mappings=slot_mappings,
             )
-            # Combine KVConnectorOutputs or select the non-empty one
-            if self.kv_connector_output and drafter_kv_connector_output:
-                self.kv_connector_output = KVConnectorOutput.merge(
-                    self.kv_connector_output, drafter_kv_connector_output
-                )
-            else:
-                self.kv_connector_output = (
-                    self.kv_connector_output or drafter_kv_connector_output
-                )
-
             next_token_ids, valid_sampled_tokens_count = (
                 self.drafter.prepare_next_token_ids_padded(
                     common_attn_metadata,
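This deleted merge-or-select logic is what the deferred kv_connector clear makes unnecessary: with the drafter no longer producing its own KVConnectorOutput, the runner keeps a single output and nothing needs combining. A hedged sketch of the removed pattern, with `ConnectorOutput` as an illustrative stand-in for vllm.v1.outputs.KVConnectorOutput and simplified merge semantics:

from dataclasses import dataclass, field

@dataclass
class ConnectorOutput:
    finished: set[str] = field(default_factory=set)

    def merge(self, other: "ConnectorOutput") -> "ConnectorOutput":
        # Assumed semantics: union the finished-request sets.
        return ConnectorOutput(self.finished | other.finished)

def combine(a: ConnectorOutput | None, b: ConnectorOutput | None):
    # The removed runner logic: merge when both sides produced output,
    # otherwise select whichever side is non-empty.
    if a and b:
        return a.merge(b)
    return a or b

assert combine(ConnectorOutput({"x"}), ConnectorOutput({"y"})).finished == {"x", "y"}
assert combine(None, ConnectorOutput({"y"})).finished == {"y"}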