[V1][Spec Decode] Update target_logits in place for rejection sampling (#15427)
Some checks failed
Create Release / Create Release (push) Has been cancelled

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-03-24 21:04:41 -07:00
committed by GitHub
parent a09ad90a72
commit 25f560a62c
2 changed files with 12 additions and 4 deletions

View File

@@ -67,6 +67,7 @@ class RejectionSampler(nn.Module):
Shape is [num_tokens, vocab_size]. Here, probabilities from Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because different requests are flattened into a single tensor because
this is the shape of the output logits. this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor): bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1]. A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all Bonus tokens are added to the end of the sequence if all
@@ -83,6 +84,8 @@ class RejectionSampler(nn.Module):
''' '''
assert metadata.max_spec_len <= MAX_SPEC_LEN assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_probs = compute_probs( target_probs = compute_probs(
target_logits, target_logits,
metadata.cu_num_draft_tokens, metadata.cu_num_draft_tokens,
@@ -252,8 +255,8 @@ def compute_probs(
replace_from=GREEDY_TEMPERATURE, replace_from=GREEDY_TEMPERATURE,
replace_to=1, replace_to=1,
) )
# TODO(woosuk): Consider using in-place op to reduce memory usage. # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits = logits / temperature.unsqueeze(-1) logits.div_(temperature.unsqueeze(-1))
# Get expanded top_k and top_p tensors. # Get expanded top_k and top_p tensors.
top_k = None top_k = None

View File

@@ -1059,7 +1059,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
else: else:
# TODO(woosuk): Optimize the memory usage. # When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.model.sample( sampler_output = self.model.sample(
logits=bonus_logits, logits=bonus_logits,
@@ -1067,7 +1070,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
bonus_token_ids = sampler_output.sampled_token_ids bonus_token_ids = sampler_output.sampled_token_ids
# TODO(woosuk): Optimize the memory usage. # Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices] target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler( output_token_ids = self.rejection_sampler(
spec_decode_metadata, spec_decode_metadata,