[Bugfix][Async] Fix async spec decoding with hybrid models (#38556)

Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com>
(cherry picked from commit 757068dc65)
This commit is contained in:
Matthew Bonanni
2026-03-31 11:08:54 -04:00
committed by khluu
parent bcc0fdd0f3
commit 268bed9cf3
6 changed files with 177 additions and 36 deletions

View File

@@ -3,6 +3,7 @@
from unittest import mock
import numpy as np
import pytest
import torch
@@ -132,16 +133,12 @@ def test_prepare_next_token_ids_padded():
device = torch.device(current_platform.device_type)
num_requests = 4
batch_spec = BatchSpec(
seq_lens=[5] * num_requests,
query_lens=[5] * num_requests,
)
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100
mock_input_batch.num_tokens_no_spec = np.array([5] * num_requests)
mock_requests = {}
for req_id in req_ids:
@@ -174,12 +171,6 @@ def test_prepare_next_token_ids_padded():
proposer = _create_proposer(num_speculative_tokens=1)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
# valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
# It doesn't depend on whether the request is discarded
expected_valid_sampled_tokens_count = torch.tensor(
@@ -187,7 +178,6 @@ def test_prepare_next_token_ids_padded():
)
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata.seq_lens_cpu,
sampled_token_ids,
mock_requests,
mock_input_batch,