[V1][Spec Decode] Update target_logits in place for rejection sampling (#15427)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -67,6 +67,7 @@ class RejectionSampler(nn.Module):
                 Shape is [num_tokens, vocab_size]. Here, probabilities from
                 different requests are flattened into a single tensor because
                 this is the shape of the output logits.
+                NOTE: `target_logits` can be updated in place to save memory.
             bonus_token_ids_tensor (torch.Tensor):
                 A tensor containing bonus tokens. Shape is [batch_size, 1].
                 Bonus tokens are added to the end of the sequence if all
@@ -83,6 +84,8 @@ class RejectionSampler(nn.Module):
         '''
         assert metadata.max_spec_len <= MAX_SPEC_LEN
         # [num_tokens, vocab_size]
+        # NOTE(woosuk): `target_logits` can be updated in place inside the
+        # `compute_probs` function.
         target_probs = compute_probs(
             target_logits,
             metadata.cu_num_draft_tokens,
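For a sense of scale, here is a back-of-the-envelope estimate of the allocation that reusing the `target_logits` storage avoids. The token count and vocabulary size below are illustrative assumptions, not values from this commit:

```python
# Rough size of one [num_tokens, vocab_size] float32 tensor that no longer
# needs to be allocated. All numbers here are illustrative assumptions.
num_tokens = 8192        # draft tokens flattened across all requests
vocab_size = 128 * 1024  # vocabulary size
bytes_per_elem = 4       # float32

saved_bytes = num_tokens * vocab_size * bytes_per_elem
print(f"~{saved_bytes / 2**30:.1f} GiB saved")  # ~4.0 GiB
```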
@@ -252,8 +255,8 @@ def compute_probs(
         replace_from=GREEDY_TEMPERATURE,
         replace_to=1,
     )
-    # TODO(woosuk): Consider using in-place op to reduce memory usage.
-    logits = logits / temperature.unsqueeze(-1)
+    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
+    logits.div_(temperature.unsqueeze(-1))
 
     # Get expanded top_k and top_p tensors.
     top_k = None
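A minimal sketch of what the change above buys: out-of-place `/` allocates a fresh result tensor, while `Tensor.div_` writes into the existing storage. Shapes and values are illustrative, not taken from the commit:

```python
import torch

logits = torch.randn(4, 8)           # stand-in for [num_tokens, vocab_size]
temperature = torch.full((4,), 2.0)  # one temperature per token

# Out-of-place: `/` allocates a new tensor for the result.
out = logits / temperature.unsqueeze(-1)
assert out.data_ptr() != logits.data_ptr()

# In-place: `div_` reuses the original storage, so no new
# [num_tokens, vocab_size] allocation is made.
ptr_before = logits.data_ptr()
logits.div_(temperature.unsqueeze(-1))
assert logits.data_ptr() == ptr_before
```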
@@ -1059,7 +1059,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 sampling_metadata=sampling_metadata,
             )
         else:
-            # TODO(woosuk): Optimize the memory usage.
+            # When indexing with a tensor (bonus_logits_indices), PyTorch
+            # creates a new tensor with separate storage from the original
+            # logits tensor. This means any in-place operations on bonus_logits
+            # won't affect the original logits tensor.
             bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
             sampler_output = self.model.sample(
                 logits=bonus_logits,
@@ -1067,7 +1070,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             )
             bonus_token_ids = sampler_output.sampled_token_ids
 
-            # TODO(woosuk): Optimize the memory usage.
+            # Just like `bonus_logits`, `target_logits` is a new tensor with
+            # separate storage from the original `logits` tensor. Therefore,
+            # it is safe to update `target_logits` in place.
             target_logits = logits[spec_decode_metadata.target_logits_indices]
             output_token_ids = self.rejection_sampler(
                 spec_decode_metadata,
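The new comments rest on a standard PyTorch property: indexing with an integer tensor (advanced indexing) returns a copy with its own storage, so in-place updates to the result cannot corrupt the source tensor. A small self-contained check, with illustrative shapes:

```python
import torch

logits = torch.zeros(6, 8)                       # stand-in for the full logits
target_logits_indices = torch.tensor([1, 3, 4])  # rows consumed by the sampler

# Advanced indexing copies: the result owns separate storage.
target_logits = logits[target_logits_indices]
assert target_logits.data_ptr() != logits.data_ptr()

# Hence mutating `target_logits` in place leaves `logits` untouched.
target_logits.add_(1.0)
assert logits.sum().item() == 0.0
```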