[Spec Decode] Update extract_hidden_states to use deferred kv_connector clear (#37013)

This commit is contained in:
Fynn Schmitt-Ulms
2026-03-16 09:53:45 -04:00
committed by GitHub
parent 43a73f853b
commit 04bf5a35fa
4 changed files with 34 additions and 69 deletions

View File

@@ -252,29 +252,22 @@ def test_propose():
]
# Sampled token IDs from target model
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
# Mock scheduler output
mock_scheduler_output = mock.MagicMock()
sampled_token_ids = torch.tensor(
[42, 60], dtype=torch.int32, device=device
).unsqueeze(-1)
# Call propose
with mock.patch(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
) as mock_has_kv:
mock_has_kv.return_value = False
draft_tokens, kv_connector_output = proposer.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
scheduler_output=mock_scheduler_output,
slot_mappings=None,
)
draft_tokens = proposer.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
slot_mappings=None,
)
# Verify draft tokens match sampled tokens
# Shape should be [batch_size, 1] for num_speculative_tokens=1
assert draft_tokens.shape == (batch_size, 1)
assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
assert torch.equal(draft_tokens, sampled_token_ids)
# Verify the model was called
model_mock.assert_called_once()
@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers):
for _ in range(num_hidden_layers)
]
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
mock_scheduler_output = mock.MagicMock()
sampled_token_ids = torch.tensor(
[42, 60], dtype=torch.int32, device=device
).unsqueeze(-1)
with mock.patch(
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
) as mock_has_kv:
mock_has_kv.return_value = False
draft_tokens, _ = proposer.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
scheduler_output=mock_scheduler_output,
slot_mappings=None,
)
draft_tokens = proposer.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
slot_mappings=None,
)
assert draft_tokens.shape == (batch_size, 1)
assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
assert torch.equal(draft_tokens, sampled_token_ids)