[Perf] Optimize EAGLE prepare_inputs_padded with triton kernels (#28597)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
parent 3461e7efd8
commit 1986de1375
@@ -103,16 +103,23 @@ def test_prepare_next_token_ids():
         mock_request.num_computed_tokens = 0
         mock_requests[req_id] = mock_request
 
+    # explicitly discard the last request
+    discarded_req_mask = torch.tensor(
+        [False, False, False, True], dtype=torch.bool, device=device
+    )
     sampled_token_ids = [
         [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
         [0, 1, 2, 3, 4],  # all accepted, "4" sampled
         [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
-        [-1, -1, -1, -1, -1],  # this request will be discarded
+        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
     ]
     sampled_token_ids_tensor = torch.tensor(
         sampled_token_ids, dtype=torch.int32, device=device
     )
     sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
+    for i in range(len(sampled_token_ids_cpu)):
+        if discarded_req_mask[i]:
+            sampled_token_ids_cpu[i] = []
 
     expected_next_token_ids_cpu = [1, 4, 30, 40]
     expected_next_token_ids_tensor = torch.tensor(
@@ -136,9 +143,6 @@ def test_prepare_next_token_ids():
         device=device,
     )
 
-    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
-    num_discarded_reqs = 1
-
     expected_valid_sampled_tokens_count = torch.tensor(
         [2, 5, 0, 0], dtype=torch.int32, device=device
     )
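To see where the expected [2, 5, 0, 0] comes from, here is a minimal torch sketch of the counting rule the test encodes (a host-side illustration, not the PR's GPU kernel): a request's count is the length of its accepted prefix, zeroed when the discard mask is set.

import torch

sampled_token_ids = torch.tensor(
    [
        [0, 1, -1, -1, -1],    # 2 accepted before the first rejection
        [0, 1, 2, 3, 4],       # all 5 accepted
        [-1, -1, -1, -1, -1],  # nothing sampled
        [0, 1, 2, -1, -1],     # explicitly discarded below
    ],
    dtype=torch.int32,
)
discarded_req_mask = torch.tensor([False, False, False, True])

# cumprod zeroes everything at and after the first -1, so the row sum
# is exactly the length of the accepted prefix
accepted_prefix = (sampled_token_ids != -1).int().cumprod(dim=1)
valid_counts = accepted_prefix.sum(dim=1).to(torch.int32)
valid_counts[discarded_req_mask] = 0
print(valid_counts)  # tensor([2, 5, 0, 0], dtype=torch.int32)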
@@ -149,8 +153,7 @@ def test_prepare_next_token_ids():
             sampled_token_ids_tensor,
             mock_requests,
             mock_input_batch,
-            discarded_req_indices,
-            num_discarded_reqs,
+            discarded_req_mask,
         )
     )
 
@@ -256,11 +259,6 @@ def test_prepare_inputs_padded():
     - Request 3: query_len = 3, rejected = 2
 
     Expected outputs:
-    token_indices: [0, 1, 2,
-                    3, 4, 5,
-                    6, 7, 8]
-    Reason: Deferred computation should not disturb the original indices.
-
     token_indices_to_sample: [1, 5, 6]
     Reason: After accounting for rejections, these are the valid token positions
     from the original indices to sample from.
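The docstring's expected [1, 5, 6] follows from simple per-request arithmetic. A worked sketch, assuming query_start_loc = [0, 3, 6, 9] (three requests of query_len 3) and rejection counts [1, 0, 2]; only request 3's rejected = 2 is stated in the hunk above, the other two counts are inferred from the expected output:

import torch

query_start_loc = torch.tensor([0, 3, 6, 9], dtype=torch.int32)
num_rejected = torch.tensor([1, 0, 2], dtype=torch.int32)

# Padded speculation keeps every request's token layout intact, so each
# request samples at its last *accepted* position:
#   end_of_request - 1 - num_rejected
token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected
print(token_indices_to_sample)  # tensor([1, 5, 6], dtype=torch.int32)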
@@ -268,9 +266,6 @@ def test_prepare_inputs_padded():
 
     device = torch.device(current_platform.device_type)
 
-    expected_token_indices = torch.tensor(
-        [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device
-    )
     expected_token_indices_to_sample = torch.tensor(
         [1, 5, 6], dtype=torch.int32, device=device
     )
@@ -305,15 +300,12 @@ def test_prepare_inputs_padded():
 
     proposer = _create_proposer("eagle", num_speculative_tokens)
 
-    output_metadata, token_indices, token_indices_to_sample = (
-        proposer.prepare_inputs_padded(
-            common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
-        )
-    )
+    output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
+        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+    )
 
     assert output_metadata.max_query_len == 3
     assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
-    assert torch.equal(token_indices, expected_token_indices)
     assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
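The Triton kernels the commit title refers to are not part of this test diff. As a rough, hypothetical sketch of the kind of per-request indexing math such a kernel fuses on-device (kernel name, signature, and block size are illustrative, not the PR's actual code):

import torch
import triton
import triton.language as tl

@triton.jit
def sample_indices_kernel(
    qsl_ptr,       # query_start_loc, shape [num_reqs + 1]
    rejected_ptr,  # per-request rejected-token counts, shape [num_reqs]
    out_ptr,       # token_indices_to_sample, shape [num_reqs]
    num_reqs,
    BLOCK: tl.constexpr,
):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < num_reqs
    end = tl.load(qsl_ptr + offs + 1, mask=mask, other=0)
    rej = tl.load(rejected_ptr + offs, mask=mask, other=0)
    # last accepted position of each request
    tl.store(out_ptr + offs, end - 1 - rej, mask=mask)

def token_indices_to_sample(query_start_loc, num_rejected):
    n = num_rejected.numel()
    out = torch.empty(n, dtype=torch.int32, device=query_start_loc.device)
    sample_indices_kernel[(triton.cdiv(n, 128),)](
        query_start_loc, num_rejected, out, n, BLOCK=128
    )
    return out

Doing this indexing inside one kernel avoids the host round-trips and the series of small CUDA launches that building the index tensors with eager torch ops would incur, which is the usual motivation for this kind of optimization.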