[Misc] Various code simplifications (#31666)

Signed-off-by: njhill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-01-04 18:35:56 -08:00
committed by GitHub
parent bb4337b34c
commit 43e3f8e4a9
7 changed files with 66 additions and 135 deletions

View File

@@ -10,10 +10,7 @@ logger = init_logger(__name__)
class AsyncScheduler(Scheduler):
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
@@ -41,9 +38,7 @@ class AsyncScheduler(Scheduler):
)
def _update_request_with_output(
self,
request: Request,
new_token_ids: list[int],
self, request: Request, new_token_ids: list[int]
) -> tuple[list[int], bool]:
if request.discard_latest_async_tokens:
# If the request is force preempted in reset_prefix_cache, we

View File

@@ -85,10 +85,7 @@ class SchedulerInterface(ABC):
raise NotImplementedError
@abstractmethod
def update_draft_token_ids(
self,
draft_token_ids: "DraftTokenIds",
) -> None:
def update_draft_token_ids(self, draft_token_ids: "DraftTokenIds") -> None:
"""Update the draft token ids for the scheduled requests."""
raise NotImplementedError

View File

@@ -762,11 +762,7 @@ class Scheduler(SchedulerInterface):
self._update_after_schedule(scheduler_output)
return scheduler_output
def _preempt_request(
self,
request: Request,
timestamp: float,
) -> None:
def _preempt_request(self, request: Request, timestamp: float) -> None:
"""Preempt a request and put it back to the waiting queue.
NOTE: The request should be popped from the running queue outside of this
@@ -786,10 +782,7 @@ class Scheduler(SchedulerInterface):
# Put the request back to the waiting queue.
self.waiting.prepend_request(request)
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
@@ -1006,8 +999,7 @@ class Scheduler(SchedulerInterface):
)
curr_embeds_start, curr_embeds_end = (
mm_feature.mm_position.get_embeds_indices_in_range(
start_idx_rel,
end_idx_rel,
start_idx_rel, end_idx_rel
)
)
# There's no embeddings in the current range of encoder placeholder tokens
@@ -1034,8 +1026,7 @@ class Scheduler(SchedulerInterface):
)
def get_grammar_bitmask(
self,
scheduler_output: SchedulerOutput,
self, scheduler_output: SchedulerOutput
) -> GrammarOutput | None:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
@@ -1285,9 +1276,7 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs
def _update_request_with_output(
self,
request: Request,
new_token_ids: list[int],
self, request: Request, new_token_ids: list[int]
) -> tuple[list[int], bool]:
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
@@ -1328,10 +1317,7 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input(request, input_id)
def update_draft_token_ids(
self,
draft_token_ids: DraftTokenIds,
) -> None:
def update_draft_token_ids(self, draft_token_ids: DraftTokenIds) -> None:
for req_id, spec_token_ids in zip(
draft_token_ids.req_ids,
draft_token_ids.draft_token_ids,
@@ -1361,9 +1347,7 @@ class Scheduler(SchedulerInterface):
request.record_event(EngineCoreEventType.QUEUED)
def finish_requests(
self,
request_ids: str | Iterable[str],
finished_status: RequestStatus,
self, request_ids: str | Iterable[str], finished_status: RequestStatus
) -> None:
"""Handles the finish signal from outside the scheduler.

View File

@@ -204,10 +204,7 @@ class EagleProposer:
)
# Precompute draft position offsets in flattened tree.
self.tree_draft_pos_offsets = torch.arange(
1,
len(self.tree_choices) + 1,
device=device,
dtype=torch.int32,
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1)
def _get_positions(self, num_tokens: int):
@@ -287,8 +284,7 @@ class EagleProposer:
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens,
num_tokens_padded=num_tokens,
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode = CUDAGraphMode.NONE
@@ -391,8 +387,7 @@ class EagleProposer:
draft_token_ids_list = [draft_token_ids]
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size,
num_tokens_padded=batch_size,
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
)
if (
@@ -610,10 +605,8 @@ class EagleProposer:
assert discard_request_mask.dtype == torch.bool
assert backup_tokens_gpu.dtype == torch.int32
next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device)
valid_sampled_tokens_count = torch.empty(
(batch_size,), dtype=torch.int32, device=device
)
next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)
# Kernel grid: one program per request (row)
grid = (batch_size,)
@@ -782,8 +775,7 @@ class EagleProposer:
max_query_len=query_len,
)
attn_metadata = tree_attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=level + 1,
common_attn_metadata=common_attn_metadata, draft_index=level + 1
)
# Apply new attention metadata to all layers.
@@ -1161,8 +1153,8 @@ class EagleProposer:
def dummy_run(
self,
num_tokens: int,
use_cudagraphs=True,
is_graph_capturing=False,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
) -> None:
# Determine if CUDA graphs should be used for this run.
cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
@@ -1174,8 +1166,7 @@ class EagleProposer:
):
if fwd_idx <= 1:
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens,
num_tokens_padded=num_tokens,
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
if (
cudagraphs_enabled
@@ -1342,9 +1333,5 @@ def compute_probs_and_sample_next_token(
next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
if not sampling_metadata.all_random:
greedy_token_ids = probs.argmax(dim=-1)
next_token_ids = torch.where(
is_greedy,
greedy_token_ids,
next_token_ids,
)
next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids)
return next_token_ids, probs

View File

@@ -28,8 +28,6 @@ if TYPE_CHECKING:
else:
torch = LazyLoader("torch", globals(), "torch")
ReasoningParser = object
Request = object
logger = init_logger(__name__)
@@ -98,7 +96,7 @@ class StructuredOutputManager:
self.vllm_config.structured_outputs_config.enable_in_reasoning
)
def grammar_init(self, request: Request) -> None:
def grammar_init(self, request: "Request") -> None:
if request.structured_output_request is None:
return
@@ -156,10 +154,7 @@ class StructuredOutputManager:
grammar = self._create_grammar(request) # type: ignore[assignment]
request.structured_output_request.grammar = grammar # type: ignore[assignment]
def _create_grammar(
self,
request: Request,
) -> StructuredOutputGrammar:
def _create_grammar(self, request: "Request") -> StructuredOutputGrammar:
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
# Note that the request was validated in the engine core client,
@@ -173,8 +168,7 @@ class StructuredOutputManager:
return self.backend.compile_grammar(request_type, grammar_spec)
def _fill_bitmasks(
self,
batch: Iterable[tuple[StructuredOutputGrammar, int, bool]],
self, batch: Iterable[tuple[StructuredOutputGrammar, int, bool]]
) -> None:
assert self._grammar_bitmask is not None
for grammar, index, apply_bitmask in batch:
@@ -187,14 +181,13 @@ class StructuredOutputManager:
self._grammar_bitmask[index].fill_(self._full_mask)
def _async_submit_fill_bitmask(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
self, batch: list[tuple[StructuredOutputGrammar, int, bool]]
) -> Future:
return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
def grammar_bitmask(
self,
requests: dict[str, Request],
requests: dict[str, "Request"],
structured_output_request_ids: list[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> "npt.NDArray[np.int32] | None":
@@ -239,11 +232,10 @@ class StructuredOutputManager:
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
grammar = structured_output_request.grammar
apply_bitmask = self.should_fill_bitmask(request)
batch.append(
(structured_output_request.grammar, cumulative_index, apply_bitmask)
)
batch.append((grammar, cumulative_index, apply_bitmask))
if len(batch) == self.fill_bitmask_parallel_batch_size:
promises.append(self._async_submit_fill_bitmask(batch))
batch = []
@@ -264,34 +256,23 @@ class StructuredOutputManager:
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
grammar = structured_output_request.grammar
apply_bitmask = self.should_fill_bitmask(request)
state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, ())
for token in itertools.chain(req_tokens, (None,)):
self._fill_bitmasks(
(
(
structured_output_request.grammar,
cumulative_index,
apply_bitmask,
),
)
)
if (
apply_bitmask
and token is not None
and not structured_output_request.grammar.is_terminated()
):
accepted = structured_output_request.grammar.accept_tokens(
req_id, [token]
)
for token in itertools.chain(req_tokens, (-1,)):
self._fill_bitmasks(((grammar, cumulative_index, apply_bitmask),))
if token == -1:
# Stop advancing the grammar once we hit a padding token.
apply_bitmask = False
if apply_bitmask and not grammar.is_terminated():
accepted = grammar.accept_tokens(req_id, [token])
assert accepted, (token, req_id, scheduled_spec_decode_tokens)
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
structured_output_request.grammar.rollback(state_advancements)
grammar.rollback(state_advancements)
bitmask_tensor = self._grammar_bitmask
if cumulative_index < bitmask_tensor.shape[0]:
@@ -302,7 +283,7 @@ class StructuredOutputManager:
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()
def should_fill_bitmask(self, request: Request) -> bool:
def should_fill_bitmask(self, request: "Request") -> bool:
# NOTE (Hanchen) if enable_in_reasoning is True, it means that
# the model needs to be constrained in reasoning. So we should always
# enable the bitmask filling.
@@ -318,7 +299,7 @@ class StructuredOutputManager:
return request.structured_output_request.reasoning_ended
return True
def should_advance(self, request: Request) -> bool:
def should_advance(self, request: "Request") -> bool:
if not request.use_structured_output:
return False

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import hashlib
import importlib.metadata
import os
import tempfile
from typing import TYPE_CHECKING
import numpy as np
@@ -34,9 +35,6 @@ else:
"convert_slow_tokenizer", globals(), "transformers.convert_slow_tokenizer"
)
TokenizerLike = object
SchedulerOutput = object
InputBatch = object
logger = init_logger(__name__)
@@ -72,13 +70,12 @@ def apply_grammar_bitmask(
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
for req_id, batch_index in seq:
spec_tokens = scheduler_output.scheduled_spec_decode_tokens
struct_out_req_ids = set(grammar_output.structured_output_request_ids)
for batch_index, req_id in enumerate(input_batch.req_ids):
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
if req_id in grammar_output.structured_output_request_ids:
cumulative_offset += len(spec_tokens.get(req_id, ()))
if req_id in struct_out_req_ids:
struct_out_req_batch_indices[req_id] = logit_index
out_indices = []
@@ -91,14 +88,12 @@ def apply_grammar_bitmask(
)
cumulative_index = 0
for req_id in grammar_output.structured_output_request_ids:
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
if req_id in struct_out_req_batch_indices:
logit_index = struct_out_req_batch_indices[req_id]
num_spec_tokens = len(spec_tokens.get(req_id, ()))
if (logit_idx := struct_out_req_batch_indices.get(req_id)) is not None:
for i in range(1 + num_spec_tokens):
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
bitmask_index = logit_idx + i
sorted_bitmask[bitmask_index] = grammar_bitmask[cumulative_index + i]
out_indices.append(bitmask_index)
cumulative_index += 1 + num_spec_tokens
# Copy async to device as tensor.
@@ -149,21 +144,19 @@ def get_outlines_cache_path() -> str:
if outlines_cache_dir:
# OUTLINES_CACHE_DIR takes precedence
return outlines_cache_dir
elif xdg_cache_home:
if xdg_cache_home:
return os.path.join(xdg_cache_home, ".cache", "outlines")
# If homedir is "/", we may be inside a container, and thus writing to
# root would be problematic, so we fall back to using a tempfile.
# Also validate the path exists, since os.path.expanduser does
# not guarantee existence.
elif os.path.isdir(home_dir) and home_dir != "/":
if os.path.isdir(home_dir) and home_dir != "/":
# Default Unix fallback: ~/.cache/outlines
return os.path.join(home_dir, ".cache", "outlines")
else:
import tempfile
# home_dir may be / inside a docker container without existing user
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, ".cache", "outlines")
# home_dir may be / inside a docker container without existing user
tempdir = tempfile.gettempdir()
return os.path.join(tempdir, ".cache", "outlines")
def get_outlines_cache():
@@ -184,8 +177,8 @@ def get_outlines_cache():
cache.clear()
cache.set("__version__", outlines_version)
return cache
else:
return LRUCache(maxsize=128)
return LRUCache(maxsize=128)
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
@@ -193,8 +186,7 @@ re_replacement_seq = re.compile(r"^.{0,6}<7D>+.{0,6}$")
def _reduced_vocabulary(
tokenizer: TokenizerLike,
eos_token_id: int,
tokenizer: TokenizerLike, eos_token_id: int
) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.
@@ -267,17 +259,13 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
return tokenizer._outlines_vocabulary # type: ignore
try:
if (
hasattr(
tokenizer,
"eos_token_id",
)
and tokenizer.eos_token_id is not None
):
if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
eos_token_id = tokenizer.eos_token_id
else:
raise ValueError(
f"Error during structured outputs setup for outlines: Tokenizer ({type(tokenizer)}) has no `eos_token_id` property, but `eos_token_id` is required for structured outputs to work properly." # noqa: E501
"Error during structured outputs setup for outlines: Tokenizer "
f"({type(tokenizer)}) has no `eos_token_id` property, but "
"`eos_token_id` is required for structured outputs to work properly."
)
reduced_vocab = _reduced_vocabulary(
@@ -290,7 +278,7 @@ def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
return vocabulary
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
"Cannot get the vocabulary of the tokenizer "
f"({type(tokenizer)}). The tokenizer should have a "
"get_vocab method."
) from e

View File

@@ -3564,14 +3564,13 @@ class GPUModelRunner(
def _get_valid_sampled_token_count(self) -> list[int]:
# Wait until valid_sampled_tokens_count is copied to cpu,
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
if (
self.valid_sampled_token_count_event is None
or prev_sampled_token_ids is None
):
sampled_count_event = self.valid_sampled_token_count_event
if sampled_count_event is None or prev_sampled_token_ids is None:
return []
counts_cpu = self.valid_sampled_token_count_cpu
self.valid_sampled_token_count_event.synchronize()
assert counts_cpu is not None
sampled_count_event.synchronize()
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
def propose_draft_token_ids(