[V1][Spec Decode] Eagle unit tests (#17350)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
wwl2755
2025-05-12 16:01:17 -07:00
committed by GitHub
parent ebab1ac37c
commit dc9905368d
2 changed files with 344 additions and 0 deletions

View File

@@ -223,6 +223,8 @@ class EagleProposer:
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
@@ -251,6 +253,8 @@ class EagleProposer:
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0