[V1][Bugfix][Spec Decode] Fix incorrect outputs in V1 speculative decoding due to batch indexing (#14645)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Committed by: GitHub
Parent: e22ee1e7a2
Commit: 5c538c37b2
@@ -1015,11 +1015,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             target_probs = self.model.sampler.compute_probs(
                 logits, sampling_metadata)
-            scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
-            )
             draft_token_ids = [
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
-                for req_id in scheduled_request_ids
+                for req_id in self.input_batch.req_ids
             ]
             sampler_output = self.rejection_sampler(draft_token_ids,
                                                     target_probs,
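Why the one-line change matters, shown with a minimal standalone sketch (the request ids and token values below are hypothetical and not vLLM internals; only the variable names mirror the diff): target_probs is laid out in the persistent input batch's row order, so the gathered draft_token_ids must follow that same order. Iterating the scheduler's num_scheduled_tokens dict keys can yield a different ordering and pair a probability row with another request's draft tokens.

    # Order of requests in the persistent input batch (row order of target_probs).
    input_batch_req_ids = ["req-2", "req-0", "req-1"]

    # Scheduler bookkeeping, keyed by request id; key order need not match the batch.
    num_scheduled_tokens = {"req-0": 3, "req-1": 2, "req-2": 4}
    scheduled_spec_decode_tokens = {
        "req-0": [11, 12],
        "req-1": [21],
        "req-2": [31, 32, 33],
    }

    # Buggy gather: dict key order, so row i of target_probs may receive
    # another request's draft tokens whenever the two orderings diverge.
    buggy = [scheduled_spec_decode_tokens.get(r, []) for r in num_scheduled_tokens.keys()]

    # Fixed gather: batch order, so draft tokens line up with target_probs rows.
    fixed = [scheduled_spec_decode_tokens.get(r, []) for r in input_batch_req_ids]

    assert buggy != fixed          # the two orderings disagree in this example
    print(fixed)                   # [[31, 32, 33], [11, 12], [21]]

Because rejection sampling compares each draft token against its request's row of target_probs, a misaligned ordering does not raise an error; it silently accepts or rejects tokens against the wrong distribution, which is the "incorrect outputs" symptom named in the commit title.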