diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 278d421eb..d3e857345 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -271,7 +271,7 @@ class RejectionSampler(nn.Module): # Calculate indices of target logits. if sampling_metadata.allowed_token_ids_mask is not None or has_penalties: - num_requests = len(sampling_metadata.output_token_ids) + num_requests = len(metadata.num_draft_tokens) num_draft_tokens = torch.tensor(metadata.num_draft_tokens, device="cpu") original_indices = torch.arange(num_requests, device="cpu") repeat_indices_cpu = original_indices.repeat_interleave(num_draft_tokens)