Revert "[Core] Performance: Use list[np.ndarray] instead of list[list[int]]" (#28773)

This commit is contained in:
Nick Hill
2025-11-14 20:24:00 -08:00
committed by GitHub
parent edfe498189
commit ac86bff8cb
12 changed files with 76 additions and 102 deletions

View File

@@ -1010,8 +1010,8 @@ class Scheduler(SchedulerInterface):
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids: list[int] = (
sampled_token_ids[req_index].tolist() if sampled_token_ids else []
generated_token_ids = (
sampled_token_ids[req_index] if sampled_token_ids else []
)
scheduled_spec_token_ids = (

View File

@@ -158,7 +158,7 @@ class ModelRunnerOutput:
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: list[np.ndarray]
sampled_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]

View File

@@ -3,7 +3,6 @@
from dataclasses import replace
import numpy as np
import torch
import torch.nn as nn
@@ -205,7 +204,7 @@ class RejectionSampler(nn.Module):
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[np.ndarray]:
) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
@@ -221,7 +220,10 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
return [row[valid_mask[i]] for i, row in enumerate(output_token_ids_np)]
outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
def apply_logits_processors(
self,

View File

@@ -484,7 +484,7 @@ class EagleProposer:
def prepare_next_token_ids_cpu(
self,
sampled_token_ids: list[np.ndarray],
sampled_token_ids: list[list[int]],
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
num_scheduled_tokens: dict[str, int],
@@ -499,7 +499,7 @@ class EagleProposer:
req_ids = gpu_input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids.shape[0] > 0:
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
@@ -510,9 +510,10 @@ class EagleProposer:
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
return torch.tensor(
next_token_ids = torch.tensor(
next_token_ids, dtype=torch.int32, device=self.input_ids.device
)
return next_token_ids
def prepare_next_token_ids_padded(
self,

View File

@@ -54,7 +54,7 @@ class NgramProposer:
# Trigger Numba JIT compilation for N-gram proposer.
# This usually takes less than 1 second.
self.propose(
[np.array([])] * 1024,
[[]] * 1024,
[""] * 1024,
np.zeros(1024, dtype=np.int32),
np.zeros((1024, self.max_model_len), dtype=np.int32),
@@ -131,7 +131,7 @@ class NgramProposer:
def propose(
self,
sampled_token_ids: list[np.ndarray],
sampled_token_ids: list[list[int]],
req_ids: list[str],
num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray,
@@ -140,7 +140,7 @@ class NgramProposer:
# find which requests need ngram proposals
valid_ngram_requests = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = sampled_ids.shape[0]
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
continue

View File

@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -34,16 +32,16 @@ class SuffixDecodingProposer:
def propose(
self,
input_batch: InputBatch,
sampled_token_ids: list[np.ndarray],
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
"""
Propose speculative tokens for each request in the input batch. Suffix Decoding
will speculate a dynamic number of tokens for each request every decoding step,
so each entry in the returned list may have different lengths.
"""
draft_token_ids: list[np.ndarray] = []
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
if sampled_ids.shape[0] == 0:
if not sampled_ids:
# Skip speculative decoding for partial prefills.
draft_token_ids.append([])
continue
@@ -72,7 +70,7 @@ class SuffixDecodingProposer:
self.suffix_cache.start_request(req_id, prompt_token_ids)
# Append the newly sampled ids to the suffix cache for this request.
self.suffix_cache.add_active_response(req_id, sampled_ids.tolist())
self.suffix_cache.add_active_response(req_id, sampled_ids)
# Suffix decoding only uses the most recent tokens up to max_tree_depth, so
# we extract the pattern from the end of the input.

View File

@@ -216,11 +216,9 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
del self._logprobs_tensors
del self._sampled_token_ids
valid_sampled_token_ids: list[np.ndarray] = [
row for row in self.sampled_token_ids_cpu.numpy()
]
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i] = np.array([])
valid_sampled_token_ids[i].clear()
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
@@ -2341,7 +2339,7 @@ class GPUModelRunner(
) -> tuple[
dict[str, int],
LogprobsLists | None,
list[np.ndarray],
list[list[int]],
dict[str, LogprobsTensors | None],
list[str],
dict[str, int],
@@ -2367,7 +2365,6 @@ class GPUModelRunner(
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
invalid_req_indices = []
valid_sampled_token_ids: list[np.ndarray]
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2382,7 +2379,7 @@ class GPUModelRunner(
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)] = np.array([])
valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@@ -2410,24 +2407,19 @@ class GPUModelRunner(
[0] if spec_decode_metadata and logprobs_tensors else None
)
for req_idx in range(num_sampled_tokens):
sampled_ids: np.ndarray | None
if self.use_async_scheduling:
sampled_ids = (
np.array([-1]) if req_idx not in invalid_req_indices_set else None
)
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
num_sampled_ids: int = (
sampled_ids.shape[0] if sampled_ids is not None else 0
)
num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
if cu_num_accepted_tokens is not None:
cu_num_accepted_tokens.append(
cu_num_accepted_tokens[-1] + num_sampled_ids
)
if sampled_ids is None or num_sampled_ids == 0:
if not sampled_ids:
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
@@ -2769,9 +2761,7 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
def propose_draft_token_ids(
sampled_token_ids: torch.Tensor | list[np.ndarray],
) -> None:
def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("gpu_model_runner: draft"):
self._draft_token_ids = self.propose_draft_token_ids(
@@ -2893,14 +2883,14 @@ class GPUModelRunner(
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
sampled_token_ids: torch.Tensor | list[np.ndarray],
sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
) -> torch.Tensor | list[list[int]]:
) -> list[list[int]] | torch.Tensor:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(sampled_token_ids, list)
@@ -2932,7 +2922,7 @@ class GPUModelRunner(
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens, sampled_token_ids
):
indices.append(offset + tokens.shape[0] - 1)
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
@@ -4872,7 +4862,7 @@ class GPUModelRunner(
return kv_cache_spec
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[np.ndarray]:
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
# This is a short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/22754.
# `tolist` would trigger a cuda wise stream sync, which
@@ -4885,4 +4875,4 @@ class GPUModelRunner(
pinned.copy_(sampled_token_ids, non_blocking=True)
self.transfer_event.record()
self.transfer_event.synchronize()
return [row for row in pinned.numpy()]
return pinned.tolist()