[V1][Spec Decode] Update target_logits in place for rejection sampling (#15427)
Some checks failed
Create Release / Create Release (push) Has been cancelled

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-03-24 21:04:41 -07:00
committed by GitHub
parent a09ad90a72
commit 25f560a62c
2 changed files with 12 additions and 4 deletions

View File

@@ -1059,7 +1059,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
)
else:
# TODO(woosuk): Optimize the memory usage.
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.model.sample(
logits=bonus_logits,
@@ -1067,7 +1070,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
bonus_token_ids = sampler_output.sampled_token_ids
# TODO(woosuk): Optimize the memory usage.
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,