[perf][async] support non cpu sync get logprob tensors for spec (#31336)

Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Authored by zhrrr on 2026-01-10 05:24:51 +08:00, committed by GitHub
Parent: 94578127a4
Commit: 97ba96fbe9
3 changed files with 48 additions and 29 deletions


@@ -69,6 +69,14 @@ class LogprobsTensors(NamedTuple):
self.selected_token_ranks.to("cpu", non_blocking=True),
)
def filter(self, mask: torch.Tensor) -> "LogprobsTensors":
"""Filter the logprobs tensors with the given bool mask."""
return LogprobsTensors(
self.logprob_token_ids[mask],
self.logprobs[mask],
self.selected_token_ranks[mask],
)
@staticmethod
def empty_cpu(
num_positions: int, num_tokens_per_position: int

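For reference, the new filter() applies a single boolean mask row-wise to all three tensors. A standalone sketch of the same operation on plain tensors (shapes and values invented for illustration, not taken from vLLM):

import torch

num_positions, num_logprobs = 6, 4
logprob_token_ids = torch.randint(0, 32_000, (num_positions, num_logprobs))
logprobs = torch.randn(num_positions, num_logprobs)
selected_token_ranks = torch.randint(0, num_logprobs, (num_positions,))

# One bool per sampled position; True keeps the row, False drops it.
mask = torch.tensor([True, False, True, True, False, True])

filtered = (
    logprob_token_ids[mask],     # (num_kept, num_logprobs)
    logprobs[mask],              # (num_kept, num_logprobs)
    selected_token_ranks[mask],  # (num_kept,)
)
assert filtered[0].shape[0] == int(mask.sum())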

@@ -9,7 +9,7 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
@@ -185,13 +185,22 @@ class RejectionSampler(nn.Module):
final_logits[target_logits_indices] = target_logits.to(torch.float32)
final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
# Compute accepted token indices.
accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
num_accepted_tokens = accepted_mask.sum(dim=-1)
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
num_accepted_tokens
)
# NOTE: To avoid cpu-gpu synchronization, we now simply compute indices for
# all draft tokens, including the rejected ones. The rejected tokens will
# be filtered out in the `parse_output`.
logit_start_indices = cu_num_sampled_tokens
offsets = torch.arange(
sampled_token_ids.shape[-1],
device=logit_start_indices.device,
dtype=logit_start_indices.dtype,
)
accepted_logit_indices = (
logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
).flatten()
accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)
accepted_tokens = sampled_token_ids.clone().flatten()
# we replace rejected token ids with 0 to avoid gather_logprobs error
accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0
# Compute logprobs for accepted tokens.
accepted_logits = final_logits[accepted_logit_indices]
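The core of the change: nonzero() and boolean-mask indexing on CUDA tensors produce data-dependent shapes, which forces a device-to-host synchronization before the host can continue queuing work. The broadcasted arange below has a static shape, so the index computation stays asynchronous; rejected rows are dropped later on the CPU side. A standalone sketch with invented shapes (not the sampler's actual tensors):

import torch

PLACEHOLDER_TOKEN_ID = -1
# [num_reqs, max_spec_len + 1]; -1 marks rejected draft positions.
sampled_token_ids = torch.tensor([[5, 7, -1], [3, -1, -1]])
cu_num_sampled_tokens = torch.tensor([0, 3])  # start offset of each request's logits

offsets = torch.arange(sampled_token_ids.shape[-1])
# One index per candidate position, accepted or not: [num_reqs * (max_spec_len + 1)]
logit_indices = (cu_num_sampled_tokens.unsqueeze(1) + offsets.unsqueeze(0)).flatten()
logit_indices.clamp_(max=5)  # keep in bounds of the (here 6-row) logits tensor

tokens = sampled_token_ids.clone().flatten()
tokens[tokens == PLACEHOLDER_TOKEN_ID] = 0  # dummy ids; these rows are filtered out later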
@@ -200,7 +209,6 @@ class RejectionSampler(nn.Module):
if self.is_logits_logprobs_mode
else self.sampler.compute_logprobs(accepted_logits)
)
accepted_tokens = sampled_token_ids[accepted_mask]
return self.sampler.gather_logprobs(
accepted_logprobs,
max_num_logprobs,
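For context on why the placeholder ids are overwritten with 0 above: torch.gather with a negative index is out of range, so rejected positions need some valid dummy id before logprobs are gathered, and the garbage values gathered there are discarded later by parse_output(). A standalone sketch with a tiny invented vocabulary:

import torch

logprobs = torch.log_softmax(torch.randn(4, 8), dim=-1)  # 4 positions, vocab of 8
token_ids = torch.tensor([3, -1, 5, -1])                  # -1 marks rejected drafts

safe_ids = token_ids.clamp(min=0)  # same effect as writing 0 into the rejected slots
token_logprobs = logprobs.gather(-1, safe_ids.unsqueeze(-1)).squeeze(-1)
# token_logprobs[1] and token_logprobs[3] are meaningless, but those rows are
# removed before the results are returned to the caller.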
@@ -212,8 +220,8 @@ class RejectionSampler(nn.Module):
output_token_ids: torch.Tensor,
vocab_size: int,
discard_req_indices: Sequence[int] = (),
return_cu_num_tokens: bool = False,
) -> tuple[list[list[int]], list[int] | None]:
logprobs_tensors: LogprobsTensors | None = None,
) -> tuple[list[list[int]], LogprobsLists | None]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
@@ -222,7 +230,7 @@ class RejectionSampler(nn.Module):
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
discard_req_indices: Optional row indices to discard tokens in.
return_cu_num_tokens: Whether to also return cumulative token counts.
logprobs_tensors: Optional logprobs tensors to filter.
Returns:
A list of lists of token IDs.
"""
@@ -231,15 +239,18 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
cu_num_tokens = None
if return_cu_num_tokens:
output_logprobs = None
if logprobs_tensors is not None:
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
output_logprobs = filtered_tensors.tolists(cu_num_tokens)
if len(discard_req_indices) > 0:
valid_mask[discard_req_indices] = False
outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs, cu_num_tokens
return outputs, output_logprobs
def apply_logits_processors(
self,

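To make the bookkeeping above concrete: the boolean mask drops rejected positions, and the cumulative counts mark where each request's rows start and end in the flattened tensors, which is what tolists() needs to rebuild per-request lists. A standalone numpy sketch with invented values (the real code operates on LogprobsTensors):

import numpy as np

PLACEHOLDER_TOKEN_ID = -1
vocab_size = 32_000
output_token_ids = np.array([[11, 12, PLACEHOLDER_TOKEN_ID],
                             [21, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]])

valid_mask = (output_token_ids != PLACEHOLDER_TOKEN_ID) & (output_token_ids < vocab_size)
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()  # [0, 2, 3]

flat_rows = np.arange(output_token_ids.size, dtype=float)  # stand-in for flattened logprob rows
kept = flat_rows[valid_mask.flatten()]                      # keep accepted positions only
per_request = [kept[s:e].tolist() for s, e in zip(cu_num_tokens, cu_num_tokens[1:])]
assert per_request == [[0.0, 1.0], [3.0]]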

@@ -237,19 +237,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
cu_num_tokens = None
logprobs_lists = None
if self._logprobs_tensors_cpu is not None:
logprobs_lists = self._logprobs_tensors_cpu.tolists()
else:
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
self._invalid_req_indices,
return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
logprobs_tensors=self._logprobs_tensors_cpu,
)
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu:
output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
output.logprobs = logprobs_lists
return output
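The _logprobs_tensors_cpu consumed here comes from the earlier .to("cpu", non_blocking=True) copies, which are only safe to read once the copy stream's event has fired. Roughly this pattern, sketched standalone (names like copy_stream and copy_event are illustrative, not the runner's actual attributes):

import torch

if torch.cuda.is_available():
    device_tensor = torch.randn(4, 8, device="cuda")

    copy_stream = torch.cuda.Stream()
    copy_event = torch.cuda.Event()
    default_stream = torch.cuda.current_stream()

    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(default_stream)  # wait for the producing kernels
        # non_blocking D2H overlaps with compute only if the destination is pinned.
        host_tensor = device_tensor.to("cpu", non_blocking=True)
        copy_event.record()

    # ... the default stream keeps launching work for the next step ...

    copy_event.synchronize()  # only now is host_tensor guaranteed to be filled
    values = host_tensor.tolist()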
@@ -395,6 +396,9 @@ class GPUModelRunner(
else:
self.max_encoder_len = 0
# Async scheduling
self.use_async_scheduling = self.scheduler_config.async_scheduling
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -504,7 +508,6 @@ class GPUModelRunner(
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
# Separate cuda stream for overlapping transfer of sampled token ids from
# GPU to CPU when async scheduling is enabled.
self.async_output_copy_stream: torch.cuda.Stream | None = None
@@ -2784,7 +2787,7 @@ class GPUModelRunner(
sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
cu_num_tokens: list[int] | None = None
logprobs_lists = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2794,13 +2797,16 @@ class GPUModelRunner(
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
if logprobs_tensors is not None:
logprobs_lists = logprobs_tensors.tolists()
else:
# Includes spec decode tokens.
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
discard_sampled_tokens_req_indices,
return_cu_num_tokens=logprobs_tensors is not None,
logprobs_tensors=logprobs_tensors,
)
else:
valid_sampled_token_ids = []
@@ -2853,12 +2859,6 @@ class GPUModelRunner(
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (
logprobs_tensors.tolists(cu_num_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],