[Spec Decode] Update extract_hidden_states to use deferred kv_connector clear (#37013)

This commit is contained in:
Fynn Schmitt-Ulms
2026-03-16 09:53:45 -04:00
committed by GitHub
parent 43a73f853b
commit 04bf5a35fa
4 changed files with 34 additions and 69 deletions

View File

@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
cached_req = self._active_requests[req_id]
req_block_ids = self._req_blocks[req_id]
assert new_block_ids is not None
if new_block_ids is None:
continue
block_ids = new_block_ids[0]
req_block_ids.extend(block_ids)

View File

@@ -3,26 +3,21 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
PADDING_SLOT_ID = -1
@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer:
sampled_token_ids: torch.Tensor,
target_hidden_states: list[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
scheduler_output: SchedulerOutput,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
) -> torch.Tensor:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
@@ -99,7 +93,6 @@ class ExtractHiddenStatesProposer:
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
@@ -136,22 +129,15 @@ class ExtractHiddenStatesProposer:
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
with (
set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
(
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
if has_kv_transfer_group()
else nullcontext()
) as kv_connector_output,
):
self.model(
hidden_states=self.hidden_states[:num_input_tokens],
@@ -159,7 +145,7 @@ class ExtractHiddenStatesProposer:
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return sampled_token_ids.unsqueeze(-1), kv_connector_output
return sampled_token_ids
def _get_slot_mapping(
self,

View File

@@ -4328,23 +4328,12 @@ class GPUModelRunner(
)
target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
draft_token_ids = self.drafter.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
scheduler_output=scheduler_output,
slot_mappings=slot_mappings,
)
# Combine KVConnectorOutputs or select the non-empty one
if self.kv_connector_output and drafter_kv_connector_output:
self.kv_connector_output = KVConnectorOutput.merge(
self.kv_connector_output, drafter_kv_connector_output
)
else:
self.kv_connector_output = (
self.kv_connector_output or drafter_kv_connector_output
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,