[Perf] Optimize EAGLE prepare_inputs_padded with triton kernels (#28597)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Benjamin Chislett
2025-11-28 17:25:05 -05:00
committed by GitHub
parent 3461e7efd8
commit 1986de1375
4 changed files with 199 additions and 108 deletions
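
The test diffs below track two interface changes: prepare_next_token_ids now takes a dense boolean discarded_req_mask in place of a discarded_req_indices tensor plus a host-side num_discarded_reqs count, and prepare_inputs_padded no longer returns the identity token_indices tensor. As a minimal sketch of the kind of Triton kernel this enables, assuming the sample position of request i is query_start_loc[i] + valid_count[i] - 1 (the kernel and helper names here are hypothetical, not from this commit):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _sample_index_kernel(
        query_start_loc_ptr,  # int32 [num_reqs + 1]: prefix sums of query lengths
        valid_count_ptr,      # int32 [num_reqs]: non-rejected tokens per request
        out_ptr,              # int32 [num_reqs]: flat index of the token to sample
        num_reqs,
        BLOCK: tl.constexpr,
    ):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < num_reqs
        start = tl.load(query_start_loc_ptr + offs, mask=mask, other=0)
        valid = tl.load(valid_count_ptr + offs, mask=mask, other=1)
        # Last non-rejected token of request i sits at start[i] + valid[i] - 1.
        tl.store(out_ptr + offs, start + valid - 1, mask=mask)

    def token_indices_to_sample(query_start_loc, valid_count):
        num_reqs = valid_count.numel()
        out = torch.empty(num_reqs, dtype=torch.int32, device=valid_count.device)
        grid = (triton.cdiv(num_reqs, 128),)
        _sample_index_kernel[grid](query_start_loc, valid_count, out, num_reqs, BLOCK=128)
        return out

With the values from the test docstring below (query_start_loc = [0, 3, 6, 9], valid counts [2, 3, 1]), this returns [1, 5, 6], matching expected_token_indices_to_sample. Everything stays on device; no count or index list crosses back to the host.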

@@ -103,16 +103,23 @@ def test_prepare_next_token_ids():
         mock_request.num_computed_tokens = 0
         mock_requests[req_id] = mock_request
 
+    # explicitly discard the last request
+    discarded_req_mask = torch.tensor(
+        [False, False, False, True], dtype=torch.bool, device=device
+    )
     sampled_token_ids = [
         [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
         [0, 1, 2, 3, 4],  # all accepted, "4" sampled
         [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
-        [-1, -1, -1, -1, -1],  # this request will be discarded
+        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
     ]
     sampled_token_ids_tensor = torch.tensor(
         sampled_token_ids, dtype=torch.int32, device=device
     )
     sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
+    for i in range(len(sampled_token_ids_cpu)):
+        if discarded_req_mask[i]:
+            sampled_token_ids_cpu[i] = []
 
     expected_next_token_ids_cpu = [1, 4, 30, 40]
     expected_next_token_ids_tensor = torch.tensor(
@@ -136,9 +143,6 @@ def test_prepare_next_token_ids():
         device=device,
     )
 
-    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
-    num_discarded_reqs = 1
-
     expected_valid_sampled_tokens_count = torch.tensor(
         [2, 5, 0, 0], dtype=torch.int32, device=device
     )
@@ -149,8 +153,7 @@ def test_prepare_next_token_ids():
             sampled_token_ids_tensor,
             mock_requests,
             mock_input_batch,
-            discarded_req_indices,
-            num_discarded_reqs,
+            discarded_req_mask,
         )
     )
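
A boolean mask is shape-stable: it always has num_reqs elements, so downstream kernels can consume it without the host ever learning how many requests were discarded (which the old num_discarded_reqs argument required). A small illustration with made-up values, not code from the commit:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_reqs = 4

    # Old style: explicit indices plus a separate host-side count.
    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)

    # New style: a dense per-request mask, derivable from the indices.
    discarded_req_mask = torch.zeros(num_reqs, dtype=torch.bool, device=device)
    discarded_req_mask[discarded_req_indices] = True

    # The mask composes with elementwise ops without a host round-trip, e.g.
    # zeroing the valid-token counts of discarded requests, as the expected
    # [2, 5, 0, 0] above suggests:
    valid_sampled_tokens_count = torch.tensor(
        [2, 5, 0, 3], dtype=torch.int32, device=device
    )
    valid_sampled_tokens_count.masked_fill_(discarded_req_mask, 0)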
@@ -256,11 +259,6 @@ def test_prepare_inputs_padded():
     - Request 3: query_len = 3, rejected = 2
     Expected outputs:
-    token_indices: [0, 1, 2,
-                    3, 4, 5,
-                    6, 7, 8]
-    Reason: Deferred computation should not disturb the original indices.
     token_indices_to_sample: [1, 5, 6]
     Reason: After accounting for rejections, these are the valid token positions
     from the original indices to sample from.
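
The expected token_indices_to_sample values follow directly from the cumulative query layout; a quick arithmetic check, assuming rejected tokens occupy the tail of each request's window:

    query_start_loc = [0, 3, 6, 9]  # three requests, query_len = 3 each
    rejected = [1, 0, 2]            # per-request rejected counts from the docstring

    # The last valid token of request i sits just before its rejected tail.
    indices = [query_start_loc[i + 1] - 1 - rejected[i] for i in range(3)]
    assert indices == [1, 5, 6]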
@@ -268,9 +266,6 @@ def test_prepare_inputs_padded():
     device = torch.device(current_platform.device_type)
 
-    expected_token_indices = torch.tensor(
-        [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device
-    )
     expected_token_indices_to_sample = torch.tensor(
         [1, 5, 6], dtype=torch.int32, device=device
     )
@@ -305,15 +300,12 @@ def test_prepare_inputs_padded():
     proposer = _create_proposer("eagle", num_speculative_tokens)
-    output_metadata, token_indices, token_indices_to_sample = (
-        proposer.prepare_inputs_padded(
-            common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
-        )
-    )
+    output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
+        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+    )
 
     assert output_metadata.max_query_len == 3
     assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
-    assert torch.equal(token_indices, expected_token_indices)
     assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
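
Dropping token_indices from the return value mirrors the docstring change above: in the padded path the gather indices are the identity range (the removed expectation [0, 1, 2, 3, 4, 5, 6, 7, 8]), so materializing and returning them was redundant. A caller that still needs flat indices can presumably rebuild them with torch.arange(total_num_tokens, device=device).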