[Bugfix][Async] Fix async spec decoding with hybrid models (#38556)

Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com>
(cherry picked from commit 757068dc65)
This commit is contained in:
Matthew Bonanni
2026-03-31 11:08:54 -04:00
committed by khluu
parent bcc0fdd0f3
commit 268bed9cf3
6 changed files with 177 additions and 36 deletions

View File

@@ -3,6 +3,7 @@
from unittest import mock
import numpy as np
import pytest
import torch
@@ -111,16 +112,14 @@ def test_prepare_next_token_ids():
num_requests = 4
num_speculative_tokens = 4
batch_spec = BatchSpec(
seq_lens=[num_speculative_tokens + 1] * num_requests,
query_lens=[num_speculative_tokens + 1] * num_requests,
)
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100
mock_input_batch.num_tokens_no_spec = np.array(
[num_speculative_tokens + 1] * num_requests
)
mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
mock_requests = {}
@@ -165,19 +164,12 @@ def test_prepare_next_token_ids():
assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=BLOCK_SIZE,
device=device,
)
expected_valid_sampled_tokens_count = torch.tensor(
[2, 5, 0, 0], dtype=torch.int32, device=device
)
next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata.seq_lens_cpu,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,