diff --git a/tests/v1/spec_decode/test_extract_hidden_states.py b/tests/v1/spec_decode/test_extract_hidden_states.py
index af911e91d..6f0ac8cae 100644
--- a/tests/v1/spec_decode/test_extract_hidden_states.py
+++ b/tests/v1/spec_decode/test_extract_hidden_states.py
@@ -252,29 +252,22 @@ def test_propose():
     ]
 
     # Sampled token IDs from target model
-    sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
-
-    # Mock scheduler output
-    mock_scheduler_output = mock.MagicMock()
+    sampled_token_ids = torch.tensor(
+        [42, 60], dtype=torch.int32, device=device
+    ).unsqueeze(-1)
 
     # Call propose
-    with mock.patch(
-        "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
-    ) as mock_has_kv:
-        mock_has_kv.return_value = False
-
-        draft_tokens, kv_connector_output = proposer.propose(
-            sampled_token_ids=sampled_token_ids,
-            target_hidden_states=target_hidden_states,
-            common_attn_metadata=common_attn_metadata,
-            scheduler_output=mock_scheduler_output,
-            slot_mappings=None,
-        )
+    draft_tokens = proposer.propose(
+        sampled_token_ids=sampled_token_ids,
+        target_hidden_states=target_hidden_states,
+        common_attn_metadata=common_attn_metadata,
+        slot_mappings=None,
+    )
 
     # Verify draft tokens match sampled tokens
     # Shape should be [batch_size, 1] for num_speculative_tokens=1
     assert draft_tokens.shape == (batch_size, 1)
-    assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
+    assert torch.equal(draft_tokens, sampled_token_ids)
 
     # Verify the model was called
     model_mock.assert_called_once()
@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers):
         for _ in range(num_hidden_layers)
     ]
 
-    sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
-    mock_scheduler_output = mock.MagicMock()
+    sampled_token_ids = torch.tensor(
+        [42, 60], dtype=torch.int32, device=device
+    ).unsqueeze(-1)
 
-    with mock.patch(
-        "vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
-    ) as mock_has_kv:
-        mock_has_kv.return_value = False
-
-        draft_tokens, _ = proposer.propose(
-            sampled_token_ids=sampled_token_ids,
-            target_hidden_states=target_hidden_states,
-            common_attn_metadata=common_attn_metadata,
-            scheduler_output=mock_scheduler_output,
-            slot_mappings=None,
-        )
+    draft_tokens = proposer.propose(
+        sampled_token_ids=sampled_token_ids,
+        target_hidden_states=target_hidden_states,
+        common_attn_metadata=common_attn_metadata,
+        slot_mappings=None,
+    )
 
     assert draft_tokens.shape == (batch_size, 1)
-    assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
+    assert torch.equal(draft_tokens, sampled_token_ids)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
index 945f8d9fd..fcd1f365a 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
@@ -286,7 +286,9 @@ class ExampleHiddenStatesConnector(KVConnectorBase_V1):
             cached_req = self._active_requests[req_id]
             req_block_ids = self._req_blocks[req_id]
 
-            assert new_block_ids is not None
+            if new_block_ids is None:
+                continue
+
             block_ids = new_block_ids[0]
             req_block_ids.extend(block_ids)
 
diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py
index 38a54f016..dd4e47d45 100644
--- a/vllm/v1/spec_decode/extract_hidden_states.py
+++ b/vllm/v1/spec_decode/extract_hidden_states.py
@@ -3,26 +3,21 @@
 from __future__ import annotations
 
-from contextlib import nullcontext
 from typing import TYPE_CHECKING
 
 import torch
 import torch.nn as nn
 
 from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
-from vllm.distributed.kv_transfer import has_kv_transfer_group
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
-from vllm.v1.outputs import KVConnectorOutput
 from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
-from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
 
 if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.kv_cache_interface import KVCacheConfig
 
 
 PADDING_SLOT_ID = -1
@@ -79,11 +74,10 @@ class ExtractHiddenStatesProposer:
         sampled_token_ids: torch.Tensor,
         target_hidden_states: list[torch.Tensor],
         common_attn_metadata: CommonAttentionMetadata,
-        scheduler_output: SchedulerOutput,
         slot_mappings: dict[str, torch.Tensor]
         | list[dict[str, torch.Tensor]]
         | None = None,
-    ) -> tuple[torch.Tensor, KVConnectorOutput | None]:
+    ) -> torch.Tensor:
         """Propose draft tokens by calling the ExtractHiddenStatesModel model.
 
         The ExtractHiddenStatesModel caches the hidden states in the KV cache
@@ -99,7 +93,6 @@
             target_hidden_states: List of hidden state tensors from target
                 model (one per aux hidden state layer)
             common_attn_metadata: Attention metadata
-            scheduler_output: Scheduler output for KV connector
             slot_mappings: Slot mappings for KV cache (unused, provided for
                 interface compatibility)
 
@@ -136,22 +129,15 @@
         if num_tokens_across_dp is not None:
             num_tokens_across_dp[self.dp_rank] = num_input_tokens
 
-        with (
-            set_forward_context(
-                per_layer_attn_metadata,
-                self.vllm_config,
-                num_tokens=num_input_tokens,
-                num_tokens_across_dp=num_tokens_across_dp,
-                cudagraph_runtime_mode=cudagraph_runtime_mode,
-                slot_mapping=self._get_slot_mapping(
-                    num_input_tokens, common_attn_metadata.slot_mapping
-                ),
+        with set_forward_context(
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            num_tokens_across_dp=num_tokens_across_dp,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
+            slot_mapping=self._get_slot_mapping(
+                num_input_tokens, common_attn_metadata.slot_mapping
             ),
-            (
-                KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
-                if has_kv_transfer_group()
-                else nullcontext()
-            ) as kv_connector_output,
         ):
             self.model(
                 hidden_states=self.hidden_states[:num_input_tokens],
@@ -159,7 +145,7 @@
             )
         # Return the sampled tokens as "draft" tokens
         # Shape: [batch_size, 1] to match num_speculative_tokens=1
-        return sampled_token_ids.unsqueeze(-1), kv_connector_output
+        return sampled_token_ids
 
     def _get_slot_mapping(
         self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index da41fe6a3..98e1dab36 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4328,23 +4328,12 @@ class GPUModelRunner(
             )
             target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
 
-            draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
+            draft_token_ids = self.drafter.propose(
                 sampled_token_ids=sampled_token_ids,
                 target_hidden_states=target_hidden_states,
                 common_attn_metadata=common_attn_metadata,
-                scheduler_output=scheduler_output,
                 slot_mappings=slot_mappings,
             )
-            # Combine KVConnectorOutputs or select the non-empty one
-            if self.kv_connector_output and drafter_kv_connector_output:
-                self.kv_connector_output = KVConnectorOutput.merge(
-                    self.kv_connector_output, drafter_kv_connector_output
-                )
-            else:
-                self.kv_connector_output = (
-                    self.kv_connector_output or drafter_kv_connector_output
-                )
-
             next_token_ids, valid_sampled_tokens_count = (
                 self.drafter.prepare_next_token_ids_padded(
                     common_attn_metadata,
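
Reviewer note: a minimal sketch of the updated propose() contract, matching the tests above. StubProposer and its body are illustrative stand-ins, not part of this diff; the only assumptions taken from the patch are the keyword arguments, the [batch_size, 1] shape of sampled_token_ids, and the single-tensor return value (no KVConnectorOutput tuple).

# Illustrative sketch (hypothetical stub, not ExtractHiddenStatesProposer itself).
import torch


class StubProposer:
    def propose(
        self,
        sampled_token_ids: torch.Tensor,
        target_hidden_states: list[torch.Tensor],
        common_attn_metadata=None,
        slot_mappings=None,
    ) -> torch.Tensor:
        # With num_speculative_tokens=1, the "draft" tokens are the sampled
        # tokens themselves; callers now pass them already shaped
        # [batch_size, 1], so no unsqueeze and no KVConnectorOutput here.
        return sampled_token_ids


proposer = StubProposer()
sampled = torch.tensor([42, 60], dtype=torch.int32).unsqueeze(-1)  # [2, 1]
draft = proposer.propose(
    sampled_token_ids=sampled,
    target_hidden_states=[torch.zeros(2, 8)],
)
assert draft.shape == (2, 1)
assert torch.equal(draft, sampled)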