[V1][Spec Decode] Update target_logits in place for rejection sampling (#15427)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Some checks failed
Create Release / Create Release (push) Has been cancelled
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -1059,7 +1059,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
else:
|
||||
# TODO(woosuk): Optimize the memory usage.
|
||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.model.sample(
|
||||
logits=bonus_logits,
|
||||
@@ -1067,7 +1070,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
bonus_token_ids = sampler_output.sampled_token_ids
|
||||
|
||||
# TODO(woosuk): Optimize the memory usage.
|
||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||
# separate storage from the original `logits` tensor. Therefore,
|
||||
# it is safe to update `target_logits` in place.
|
||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||
output_token_ids = self.rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
|
||||
Reference in New Issue
Block a user