[Bugfix][Async] Fix async spec decoding with hybrid models (#38556)
Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com>
(cherry picked from commit 757068dc65)
@@ -3,6 +3,7 @@
from unittest import mock

import numpy as np
import pytest
import torch

@@ -132,16 +133,12 @@ def test_prepare_next_token_ids_padded():
    device = torch.device(current_platform.device_type)

    num_requests = 4
    batch_spec = BatchSpec(
        seq_lens=[5] * num_requests,
        query_lens=[5] * num_requests,
    )

    req_ids = [f"req_{i + 1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100
    mock_input_batch.num_tokens_no_spec = np.array([5] * num_requests)

    mock_requests = {}
    for req_id in req_ids:
@@ -174,12 +171,6 @@ def test_prepare_next_token_ids_padded():

    proposer = _create_proposer(num_speculative_tokens=1)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # valid_sampled_tokens_count tracks if token is valid (not -1 and in vocab range)
    # It doesn't depend on whether the request is discarded
    expected_valid_sampled_tokens_count = torch.tensor(
@@ -187,7 +178,6 @@ def test_prepare_next_token_ids_padded():
    )

    next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
        common_attn_metadata.seq_lens_cpu,
        sampled_token_ids,
        mock_requests,
        mock_input_batch,
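Note (not part of the diff): the test comment above says valid_sampled_tokens_count only counts sampled tokens that are not the -1 padding value and fall inside the vocab range, regardless of whether the request is discarded. Below is a minimal sketch of that counting semantics; count_valid_sampled_tokens is a hypothetical helper for illustration only, not the vLLM implementation.

import torch

def count_valid_sampled_tokens(sampled_token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # sampled_token_ids: [num_requests, max_sampled_tokens], padded with -1.
    # A token counts as valid when it is not the -1 padding value and lies in [0, vocab_size).
    valid = (sampled_token_ids >= 0) & (sampled_token_ids < vocab_size)
    return valid.sum(dim=1)

# Example mirroring the test setup above: 4 requests, vocab_size=100.
sampled = torch.tensor([[10, 20], [30, -1], [150, 7], [-1, -1]])
print(count_valid_sampled_tokens(sampled, vocab_size=100))  # tensor([2, 1, 1, 0])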