diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py index eb559845f..380836bb3 100644 --- a/vllm/v1/spec_decode/extract_hidden_states.py +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -145,7 +145,10 @@ class ExtractHiddenStatesProposer: # Return the sampled tokens as "draft" tokens # Shape: [batch_size, 1] to match num_speculative_tokens=1 - return sampled_token_ids + # On decode steps with spec tokens, sampled_token_ids may have + # shape [batch_size, 2] (target + spec verification); slice to + # return only the target-sampled column. + return sampled_token_ids[:, :1] def _get_slot_mapping( self,