[perf][async] Support getting logprob tensors for spec decode without a CPU sync (#31336)
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
@@ -69,6 +69,14 @@ class LogprobsTensors(NamedTuple):
             self.selected_token_ranks.to("cpu", non_blocking=True),
         )
 
+    def filter(self, mask: torch.Tensor) -> "LogprobsTensors":
+        """Filter the logprobs tensors with the given bool mask."""
+        return LogprobsTensors(
+            self.logprob_token_ids[mask],
+            self.logprobs[mask],
+            self.selected_token_ranks[mask],
+        )
+
     @staticmethod
     def empty_cpu(
         num_positions: int, num_tokens_per_position: int
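For context on the new helper: `filter` applies one boolean row mask to all three tensors at once, which is what lets `parse_output` drop rejected draft positions in a single indexing pass. A minimal usage sketch, assuming only the field names visible in this diff (shapes and values are invented):

    import torch

    from vllm.v1.outputs import LogprobsTensors

    # Hypothetical batch: 4 sampled positions, 3 logprobs per position.
    tensors = LogprobsTensors(
        logprob_token_ids=torch.arange(12).reshape(4, 3),
        logprobs=torch.randn(4, 3),
        selected_token_ranks=torch.tensor([0, 5, 2, 1]),
    )
    # Keep positions 0 and 2, e.g. the accepted draft tokens.
    mask = torch.tensor([True, False, True, False])
    kept = tensors.filter(mask)
    assert kept.logprob_token_ids.shape == (2, 3)
    assert kept.selected_token_ranks.tolist() == [0, 2]
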
@@ -9,7 +9,7 @@ import torch.nn as nn
 
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
-from vllm.v1.outputs import LogprobsTensors, SamplerOutput
+from vllm.v1.outputs import LogprobsLists, LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
 from vllm.v1.sample.ops.penalties import apply_all_penalties
@@ -185,13 +185,22 @@ class RejectionSampler(nn.Module):
         final_logits[target_logits_indices] = target_logits.to(torch.float32)
         final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
 
-        # Compute accepted token indices.
-        accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
-        num_accepted_tokens = accepted_mask.sum(dim=-1)
-        accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
-        accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
-            num_accepted_tokens
-        )
+        # NOTE: To avoid cpu-gpu synchronization, we now simply compute indices for
+        # all draft tokens, including the rejected ones. The rejected tokens will
+        # be filtered out in the `parse_output`.
+        logit_start_indices = cu_num_sampled_tokens
+        offsets = torch.arange(
+            sampled_token_ids.shape[-1],
+            device=logit_start_indices.device,
+            dtype=logit_start_indices.dtype,
+        )
+        accepted_logit_indices = (
+            logit_start_indices.unsqueeze(1) + offsets.unsqueeze(0)
+        ).flatten()
+        accepted_logit_indices.clamp_(max=final_logits.shape[0] - 1)
+        accepted_tokens = sampled_token_ids.clone().flatten()
+        # we replace rejected token ids with 0 to avoid gather_logprobs error
+        accepted_tokens[accepted_tokens == PLACEHOLDER_TOKEN_ID] = 0
 
         # Compute logprobs for accepted tokens.
         accepted_logits = final_logits[accepted_logit_indices]
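The perf-relevant detail in this hunk: `nonzero()` produces a tensor whose length depends on runtime data, so under CUDA the host must wait for the kernel to learn the result size, stalling the async pipeline. The replacement builds an index for every draft slot with a broadcasted `arange`, whose shape is known up front. A standalone sketch with toy inputs (variable names mirror the diff; the values are invented):

    import torch

    # 2 requests, up to 3 sampled tokens each; -1 marks rejected drafts.
    sampled_token_ids = torch.tensor([[7, 9, -1], [4, -1, -1]])
    cu_num_sampled_tokens = torch.tensor([0, 3])  # flat row offsets into final_logits

    offsets = torch.arange(
        sampled_token_ids.shape[-1], dtype=cu_num_sampled_tokens.dtype
    )
    # (num_reqs, 1) + (1, max_gen_len) -> one logit row index per slot,
    # including rejected ones; the shape is fixed, so no device sync is needed.
    indices = (cu_num_sampled_tokens.unsqueeze(1) + offsets.unsqueeze(0)).flatten()
    print(indices.tolist())  # [0, 1, 2, 3, 4, 5]
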
@@ -200,7 +209,6 @@ class RejectionSampler(nn.Module):
             if self.is_logits_logprobs_mode
             else self.sampler.compute_logprobs(accepted_logits)
         )
-        accepted_tokens = sampled_token_ids[accepted_mask]
         return self.sampler.gather_logprobs(
             accepted_logprobs,
             max_num_logprobs,
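The line removed here was the other hidden synchronization point: boolean-mask indexing has a data-dependent output length, so on CUDA the host must block to size the result. The PR instead clones the fixed-shape token tensor and zero-fills rejected slots (see the previous hunk). A quick demonstration with toy values:

    import torch

    PLACEHOLDER_TOKEN_ID = -1  # stand-in for the sampler's sentinel value

    sampled_token_ids = torch.tensor([[7, 9, -1], [4, -1, -1]])
    accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
    # The result's length depends on the mask's *values*:
    print(sampled_token_ids[accepted_mask])  # tensor([7, 9, 4])
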
@@ -212,8 +220,8 @@ class RejectionSampler(nn.Module):
         output_token_ids: torch.Tensor,
         vocab_size: int,
         discard_req_indices: Sequence[int] = (),
-        return_cu_num_tokens: bool = False,
-    ) -> tuple[list[list[int]], list[int] | None]:
+        logprobs_tensors: LogprobsTensors | None = None,
+    ) -> tuple[list[list[int]], LogprobsLists | None]:
         """Parse the output of the rejection sampler.
 
         Args:
             output_token_ids: The sampled token IDs in shape
@@ -222,7 +230,7 @@ class RejectionSampler(nn.Module):
                 and will be filtered out in this function.
             vocab_size: The size of the vocabulary.
             discard_req_indices: Optional row indices to discard tokens in.
-            return_cu_num_tokens: Whether to also return cumulative token counts.
+            logprobs_tensors: Optional logprobs tensors to filter.
 
         Returns:
             A list of lists of token IDs.
         """
@@ -231,15 +239,18 @@ class RejectionSampler(nn.Module):
         valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
             output_token_ids_np < vocab_size
         )
-        cu_num_tokens = None
-        if return_cu_num_tokens:
+        output_logprobs = None
+        if logprobs_tensors is not None:
             cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
+            filtered_tensors = logprobs_tensors.filter(valid_mask.flatten())
+            output_logprobs = filtered_tensors.tolists(cu_num_tokens)
 
         if len(discard_req_indices) > 0:
             valid_mask[discard_req_indices] = False
         outputs = [
             row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
         ]
-        return outputs, cu_num_tokens
+        return outputs, output_logprobs
 
     def apply_logits_processors(
         self,
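A worked example of the masking arithmetic above, with hand-written arrays (values invented for illustration):

    import numpy as np

    PLACEHOLDER_TOKEN_ID = -1  # assumed sentinel value
    vocab_size = 100

    output_token_ids_np = np.array([[7, 9, -1], [4, -1, -1]])
    valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
        output_token_ids_np < vocab_size
    )
    # Per-request accepted counts -> cumulative offsets used by `tolists`
    # to split the flat, filtered logprobs back into per-request lists.
    cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
    print(cu_num_tokens)                  # [0, 2, 3]: request 0 kept 2, request 1 kept 1
    print(valid_mask.flatten().tolist())  # [True, True, False, True, False, False]
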
@@ -237,19 +237,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
             valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
             for i in self._invalid_req_indices:
                 valid_sampled_token_ids[i].clear()
-            cu_num_tokens = None
+            logprobs_lists = None
+            if self._logprobs_tensors_cpu is not None:
+                logprobs_lists = self._logprobs_tensors_cpu.tolists()
         else:
-            valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
+            valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
                 self.sampled_token_ids_cpu,
                 self.vocab_size,
                 self._invalid_req_indices,
-                return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
+                logprobs_tensors=self._logprobs_tensors_cpu,
             )
 
         output = self._model_runner_output
         output.sampled_token_ids = valid_sampled_token_ids
-        if self._logprobs_tensors_cpu:
-            output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
+        output.logprobs = logprobs_lists
         return output
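The `_logprobs_tensors_cpu` consumed here is produced by the kind of non-blocking `.to("cpu", non_blocking=True)` copy visible in the first hunk. The general pattern, sketched standalone under assumed names (none of these helpers are vLLM API):

    import torch

    def async_copy_to_cpu(t: torch.Tensor) -> tuple[torch.Tensor, torch.cuda.Event]:
        """Enqueue a device-to-host copy on a side stream without blocking the host."""
        copy_stream = torch.cuda.Stream()
        copy_stream.wait_stream(torch.cuda.current_stream())  # order after the producer
        with torch.cuda.stream(copy_stream):
            # Truly asynchronous only when the destination is pinned host memory.
            cpu_t = t.to("cpu", non_blocking=True)
            ready = torch.cuda.Event()
            ready.record(copy_stream)
        return cpu_t, ready

    # Consumer side: block only when the host values are actually needed.
    #   ready.synchronize()
    #   values = cpu_t.tolist()
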
@@ -395,6 +396,9 @@ class GPUModelRunner(
         else:
             self.max_encoder_len = 0
 
+        # Async scheduling
+        self.use_async_scheduling = self.scheduler_config.async_scheduling
+
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -504,7 +508,6 @@ class GPUModelRunner(
             cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
         )
 
-        self.use_async_scheduling = self.scheduler_config.async_scheduling
         # Separate cuda stream for overlapping transfer of sampled token ids from
         # GPU to CPU when async scheduling is enabled.
         self.async_output_copy_stream: torch.cuda.Stream | None = None
@@ -2784,7 +2787,7 @@ class GPUModelRunner(
         sampled_token_ids = sampler_output.sampled_token_ids
         logprobs_tensors = sampler_output.logprobs_tensors
         invalid_req_indices = []
-        cu_num_tokens: list[int] | None = None
+        logprobs_lists = None
         if not self.use_async_scheduling:
             # Get the valid generated tokens.
             max_gen_len = sampled_token_ids.shape[-1]
@@ -2794,13 +2797,16 @@ class GPUModelRunner(
                 # Mask out the sampled tokens that should not be sampled.
                 for i in discard_sampled_tokens_req_indices:
                     valid_sampled_token_ids[int(i)].clear()
+
+                if logprobs_tensors is not None:
+                    logprobs_lists = logprobs_tensors.tolists()
             else:
                 # Includes spec decode tokens.
-                valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
+                valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
                     sampled_token_ids,
                     self.input_batch.vocab_size,
                     discard_sampled_tokens_req_indices,
-                    return_cu_num_tokens=logprobs_tensors is not None,
+                    logprobs_tensors=logprobs_tensors,
                 )
         else:
             valid_sampled_token_ids = []
@@ -2853,12 +2859,6 @@ class GPUModelRunner(
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
 
-        logprobs_lists = (
-            logprobs_tensors.tolists(cu_num_tokens)
-            if not self.use_async_scheduling and logprobs_tensors is not None
-            else None
-        )
-
         # Compute prompt logprobs if needed.
         prompt_logprobs_dict = self._get_prompt_logprobs_dict(
             hidden_states[:num_scheduled_tokens],