[Spec Decode] Update extract_hidden_states to use deferred kv_connector clear (#37013)
This commit is contained in:
committed by
GitHub
parent
43a73f853b
commit
04bf5a35fa
@@ -252,29 +252,22 @@ def test_propose():
|
||||
]
|
||||
|
||||
# Sampled token IDs from target model
|
||||
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
|
||||
|
||||
# Mock scheduler output
|
||||
mock_scheduler_output = mock.MagicMock()
|
||||
sampled_token_ids = torch.tensor(
|
||||
[42, 60], dtype=torch.int32, device=device
|
||||
).unsqueeze(-1)
|
||||
|
||||
# Call propose
|
||||
with mock.patch(
|
||||
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
|
||||
) as mock_has_kv:
|
||||
mock_has_kv.return_value = False
|
||||
|
||||
draft_tokens, kv_connector_output = proposer.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
target_hidden_states=target_hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
scheduler_output=mock_scheduler_output,
|
||||
slot_mappings=None,
|
||||
)
|
||||
draft_tokens = proposer.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
target_hidden_states=target_hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
slot_mappings=None,
|
||||
)
|
||||
|
||||
# Verify draft tokens match sampled tokens
|
||||
# Shape should be [batch_size, 1] for num_speculative_tokens=1
|
||||
assert draft_tokens.shape == (batch_size, 1)
|
||||
assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
|
||||
assert torch.equal(draft_tokens, sampled_token_ids)
|
||||
|
||||
# Verify the model was called
|
||||
model_mock.assert_called_once()
|
||||
@@ -326,21 +319,16 @@ def test_propose_different_layer_counts(num_hidden_layers):
|
||||
for _ in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
sampled_token_ids = torch.tensor([42, 60], dtype=torch.int32, device=device)
|
||||
mock_scheduler_output = mock.MagicMock()
|
||||
sampled_token_ids = torch.tensor(
|
||||
[42, 60], dtype=torch.int32, device=device
|
||||
).unsqueeze(-1)
|
||||
|
||||
with mock.patch(
|
||||
"vllm.v1.spec_decode.extract_hidden_states.has_kv_transfer_group"
|
||||
) as mock_has_kv:
|
||||
mock_has_kv.return_value = False
|
||||
|
||||
draft_tokens, _ = proposer.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
target_hidden_states=target_hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
scheduler_output=mock_scheduler_output,
|
||||
slot_mappings=None,
|
||||
)
|
||||
draft_tokens = proposer.propose(
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
target_hidden_states=target_hidden_states,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
slot_mappings=None,
|
||||
)
|
||||
|
||||
assert draft_tokens.shape == (batch_size, 1)
|
||||
assert torch.equal(draft_tokens[:, 0], sampled_token_ids)
|
||||
assert torch.equal(draft_tokens, sampled_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user