[V1][Bugfix][Spec Decode] Fix incorrect outputs in V1 speculative decoding due to batch indexing (#14645)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
Benjamin Chislett
2025-03-12 01:12:41 -04:00
committed by GitHub
parent e22ee1e7a2
commit 5c538c37b2
2 changed files with 50 additions and 15 deletions

View File

@@ -1015,11 +1015,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
target_probs = self.model.sampler.compute_probs(
logits, sampling_metadata)
scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
)
draft_token_ids = [
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
for req_id in scheduled_request_ids
for req_id in self.input_batch.req_ids
]
sampler_output = self.rejection_sampler(draft_token_ids,
target_probs,