[Spec Decode] Efficient padded speculation (#24539)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
5c65a72bb1
commit
b7433ca1a4
@@ -19,6 +19,8 @@ from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
@@ -64,6 +66,86 @@ def _create_proposer(
|
||||
device=current_platform.device_type)
|
||||
|
||||
|
||||
def test_prepare_next_token_ids():
|
||||
"""
|
||||
Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
|
||||
Each will produce a device tensor of next_token_ids, taking as input
|
||||
either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
|
||||
or the CPU python list[list[int]] with the rejected tokens removed.
|
||||
"""
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
num_requests = 4
|
||||
num_speculative_tokens = 4
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[num_speculative_tokens + 1] * num_requests,
|
||||
query_lens=[num_speculative_tokens + 1] * num_requests,
|
||||
)
|
||||
|
||||
req_ids = [f"req_{i+1}" for i in range(num_requests)]
|
||||
mock_input_batch = mock.MagicMock(spec=InputBatch)
|
||||
mock_input_batch.req_ids = req_ids
|
||||
mock_input_batch.num_reqs = num_requests
|
||||
mock_input_batch.vocab_size = 100
|
||||
|
||||
mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
|
||||
mock_requests = {}
|
||||
for req_id in req_ids:
|
||||
mock_request = mock.MagicMock(spec=CachedRequestState)
|
||||
# Each request will have a backup next token id of 10, 20, 30, 40
|
||||
mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
|
||||
mock_request.num_computed_tokens = 0
|
||||
mock_requests[req_id] = mock_request
|
||||
|
||||
sampled_token_ids = [
|
||||
[0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled
|
||||
[0, 1, 2, 3, 4], # all accepted, "4" sampled
|
||||
[-1, -1, -1, -1, -1], # sampling skipped, use backup token "30"
|
||||
[-1, -1, -1, -1, -1] # this request will be discarded
|
||||
]
|
||||
sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
sampled_token_ids_cpu = [[i for i in seq if i != -1]
|
||||
for seq in sampled_token_ids]
|
||||
|
||||
expected_next_token_ids_cpu = [1, 4, 30, 40]
|
||||
expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
proposer = _create_proposer("eagle", num_speculative_tokens)
|
||||
|
||||
next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
|
||||
sampled_token_ids_cpu, mock_requests, mock_input_batch,
|
||||
mock_num_scheduled_tokens)
|
||||
|
||||
assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
|
||||
num_discarded_reqs = 1
|
||||
|
||||
expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
next_token_ids_from_padded, valid_sampled_tokens_count = \
|
||||
proposer.prepare_next_token_ids_padded(
|
||||
common_attn_metadata, sampled_token_ids_tensor, mock_requests,
|
||||
mock_input_batch, discarded_req_indices, num_discarded_reqs)
|
||||
|
||||
assert torch.equal(next_token_ids_from_padded,
|
||||
expected_next_token_ids_tensor)
|
||||
assert torch.equal(valid_sampled_tokens_count,
|
||||
expected_valid_sampled_tokens_count)
|
||||
|
||||
|
||||
def test_prepare_inputs():
|
||||
"""
|
||||
cu_target_query_lens: [0, a, a + b, a + b + c]
|
||||
@@ -90,10 +172,24 @@ def test_prepare_inputs():
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Rejected tokens per request: [1, 3, 2]
|
||||
num_rejected_tokens = torch.tensor([1, 3, 2],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
# If there are `k` sampled tokens, then `k-1` tokens are draft tokens
|
||||
# from the previous iteration, and the last token is the bonus token sampled
|
||||
# from the base model.
|
||||
num_draft_tokens = [3, 6, 4] # one less than query_lens
|
||||
# num rejected tokens is [1, 3, 2]
|
||||
ACCEPT_TOKEN = 0
|
||||
BONUS_TOKEN = 1
|
||||
REJECT_TOKEN = -1
|
||||
sampled_token_ids = [
|
||||
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
|
||||
[
|
||||
ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
|
||||
REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
|
||||
],
|
||||
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
|
||||
]
|
||||
sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
|
||||
for seq in sampled_token_ids]
|
||||
|
||||
# Expected calculations:
|
||||
# query_len_per_req = [4, 7, 5]
|
||||
@@ -125,7 +221,7 @@ def test_prepare_inputs():
|
||||
proposer = _create_proposer("eagle", 1)
|
||||
|
||||
updated_metadata, token_indices = proposer.prepare_inputs(
|
||||
common_attn_metadata, num_rejected_tokens.cpu())
|
||||
common_attn_metadata, sampled_token_ids, num_draft_tokens)
|
||||
|
||||
assert torch.equal(updated_metadata.query_start_loc,
|
||||
expected_cu_num_tokens)
|
||||
@@ -133,6 +229,77 @@ def test_prepare_inputs():
|
||||
assert torch.equal(token_indices, expected_token_indices)
|
||||
|
||||
|
||||
def test_prepare_inputs_padded():
|
||||
"""
|
||||
Input scenario is 3 requests with num_speculative_tokens == 2 and:
|
||||
- Request 1: query_len = 3, rejected = 1
|
||||
- Request 2: query_len = 3, rejected = 0
|
||||
- Request 3: query_len = 3, rejected = 2
|
||||
|
||||
Expected outputs:
|
||||
token_indices: [0, 1, 2,
|
||||
3, 4, 5,
|
||||
6, 7, 8]
|
||||
Reason: Deferred computation should not disturb the original indices.
|
||||
|
||||
token_indices_to_sample: [1, 5, 6]
|
||||
Reason: After accounting for rejections, these are the valid token positions
|
||||
from the original indices to sample from.
|
||||
"""
|
||||
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
expected_token_indices_to_sample = torch.tensor([1, 5, 6],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
num_speculative_tokens = 2
|
||||
batch_spec = BatchSpec(
|
||||
seq_lens=[3, 3, 3],
|
||||
query_lens=[3, 3, 3],
|
||||
)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
|
||||
expected_query_start_loc = torch.tensor([0, 3, 6, 9],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids=[[0] * num_speculative_tokens] * 3,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# num_rejected_tokens = [1, 0, 2]
|
||||
# num_draft_tokens = [2, 2, 2]
|
||||
# valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
|
||||
valid_sampled_tokens_count = torch.tensor([2, 3, 1],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
proposer = _create_proposer("eagle", num_speculative_tokens)
|
||||
|
||||
output_metadata, token_indices, token_indices_to_sample = \
|
||||
proposer.prepare_inputs_padded(
|
||||
common_attn_metadata,
|
||||
spec_decode_metadata,
|
||||
valid_sampled_tokens_count)
|
||||
|
||||
assert output_metadata.max_query_len == 3
|
||||
assert torch.equal(output_metadata.query_start_loc,
|
||||
expected_query_start_loc)
|
||||
assert torch.equal(token_indices, expected_token_indices)
|
||||
assert torch.equal(token_indices_to_sample,
|
||||
expected_token_indices_to_sample)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
@@ -373,6 +540,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
last_token_indices=None,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
@@ -526,6 +694,7 @@ def test_propose_tree(spec_token_tree):
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
last_token_indices=None,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata)
|
||||
assert result.shape == (batch_size, num_speculative_tokens)
|
||||
|
||||
Reference in New Issue
Block a user