[cohere][fix][spec-decode]: fix crash when allowed_token_ids is set without penalties (#35654)
Signed-off-by: kkt-cohere <komal@cohere.com>
This commit is contained in:
@@ -271,7 +271,7 @@ class RejectionSampler(nn.Module):
|
||||
|
||||
# Calculate indices of target logits.
|
||||
if sampling_metadata.allowed_token_ids_mask is not None or has_penalties:
|
||||
num_requests = len(sampling_metadata.output_token_ids)
|
||||
num_requests = len(metadata.num_draft_tokens)
|
||||
num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu")
|
||||
original_indices = torch.arange(num_requests, device="cpu")
|
||||
repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)
|
||||
|
||||
Reference in New Issue
Block a user