[cohere][fix][spec-decode]: fix crash when allowed_token_ids is set without penalties (#35654)

Signed-off-by: kkt-cohere <komal@cohere.com>
This commit is contained in:
Komal Kumar Teru
2026-03-04 12:50:15 +05:30
committed by GitHub
parent 097eb544e9
commit 9e0f44bec4

View File

@@ -271,7 +271,7 @@ class RejectionSampler(nn.Module):
# Calculate indices of target logits.
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
num_requests = len(sampling_metadata.output_token_ids)
num_requests = len(metadata.num_draft_tokens)
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
original_indices = torch.arange(num_requests, device="cpu")
repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)