[ModelRunner V2] Misc minor simplifications and optimizations (#33467)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-01 14:17:14 -08:00
committed by GitHub
parent 0b225fb7b2
commit e535d90deb
21 changed files with 86 additions and 220 deletions

View File

@@ -4,11 +4,7 @@
import numpy as np
import torch
from vllm.v1.outputs import (
AsyncModelRunnerOutput,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.outputs import AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput
from vllm.v1.worker.gpu.sample.output import SamplerOutput

View File

@@ -32,9 +32,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
def init_attn_backend(
kv_cache_config: KVCacheConfig,
vllm_config: VllmConfig,
device: torch.device,
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
):
attn_backends: dict[str, type[AttentionBackend]] = {}
attn_metadata_builders: list[AttentionMetadataBuilder] = []
@@ -50,10 +48,7 @@ def init_attn_backend(
attn_backends[layer_name] = attn_backend
attn_metadata_builder = attn_backend.get_builder_cls()(
kv_cache_group_spec.kv_cache_spec,
layer_names,
vllm_config,
device,
kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
)
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
@@ -65,10 +60,7 @@ def init_attn_backend(
return attn_backends, attn_metadata_builders
def _allocate_kv_cache(
kv_cache_config: KVCacheConfig,
device: torch.device,
):
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
@@ -141,12 +133,11 @@ def init_kv_cache(
def build_slot_mappings_by_layer(
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]:
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
slot_mapping = slot_mappings[i]
kv_cache_groups = kv_cache_config.kv_cache_groups
for slot_mapping, kv_cache_group in zip(slot_mappings, kv_cache_groups):
for layer_name in kv_cache_group.layer_names:
slot_mappings_by_layer[layer_name] = slot_mapping
return slot_mappings_by_layer
@@ -188,8 +179,7 @@ def build_attn_metadata(
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata

View File

@@ -71,9 +71,7 @@ class BlockTables:
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
return torch.tensor(
[t.data_ptr() for t in x],
dtype=torch.uint64,
device=self.device,
[t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
)
def append_block_ids(
@@ -96,8 +94,7 @@ class BlockTables:
self.num_blocks.copy_to_uva()
def gather_block_tables(
self,
idx_mapping: torch.Tensor,
self, idx_mapping: torch.Tensor
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from functools import partial
import numpy as np
import torch
@@ -81,10 +82,7 @@ class UvaBufferPool:
class UvaBackedTensor:
def __init__(
self,
size: int | Sequence[int],
dtype: torch.dtype,
max_concurrency: int = 2,
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
):
self.dtype = dtype
self.max_concurrency = max_concurrency
@@ -135,25 +133,16 @@ class StagedWriteTensor:
self._staged_write_contents: list[int | float] = []
self._staged_write_cu_lens: list[int] = []
self.write_indices = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
self.write_starts = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
init_size = next_power_of_2(self.num_rows)
self.write_contents = UvaBufferPool(
init_size, dtype=dtype, max_concurrency=max_concurrency
)
self.write_cu_lens = UvaBufferPool(
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
)
self.write_contents = new_buffer(init_size, dtype=dtype)
self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
def stage_write(
self,
index: int,
start: int,
x: Iterable[int] | Iterable[float],
self, index: int, start: int, x: Iterable[int] | Iterable[float]
) -> None:
assert index >= 0
assert start >= 0

View File

@@ -24,12 +24,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
uses_mrope: bool,
device: torch.device,
):
def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
@@ -41,11 +36,7 @@ class CudaGraphManager:
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
self.cudagraph_mode = CUDAGraphMode.NONE
else:
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,

View File

@@ -13,10 +13,7 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
def get_batch_metadata_across_dp(
num_tokens: int,
cudagraph_size: int,
dp_size: int,
dp_rank: int,
num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int
) -> tuple[torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
@@ -29,10 +26,7 @@ def get_batch_metadata_across_dp(
def get_cudagraph_and_dp_padding(
num_tokens: int,
cudagraph_size: int | None,
dp_size: int,
dp_rank: int,
num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int
) -> tuple[bool, int, torch.Tensor | None]:
if dp_size == 1:
if cudagraph_size is not None:

View File

@@ -65,10 +65,10 @@ class ActiveKVConnector(KVConnector):
if scheduler_output.preempted_req_ids:
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
assert scheduler_output.kv_connector_metadata is not None
self.kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata
)
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
# TODO: sort out KV Connectors' use of forward_context
if is_forward_context_available():
self.kv_connector.start_load_kv(get_forward_context())

View File

@@ -15,10 +15,7 @@ class LoraState:
self.lora_requests: dict[str, LoRARequest] = {}
def add_request(
self,
req_id: str,
req_index: int,
lora_request: LoRARequest | None,
self, req_id: str, req_index: int, lora_request: LoRARequest | None
) -> None:
if lora_request is not None:
self.lora_requests[req_id] = lora_request
@@ -41,7 +38,7 @@ class LoraState:
active_lora_requests: set[LoRARequest] = set()
for req_id in req_ids:
lora_request = self.lora_requests.get(req_id, None)
lora_request = self.lora_requests.get(req_id)
if lora_request is not None:
active_lora_requests.add(lora_request)
return prompt_lora_mapping, token_lora_mapping, active_lora_requests

View File

@@ -23,10 +23,7 @@ class EncoderRunner:
self.device = device
self.inputs_embeds = torch.zeros(
max_num_tokens,
hidden_size,
dtype=dtype,
device=device,
max_num_tokens, hidden_size, dtype=dtype, device=device
)
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
@@ -57,8 +54,7 @@ class EncoderRunner:
self.req_id_to_mm_features.pop(req_id, None)
def prepare_mm_inputs(
self,
scheduled_encoder_inputs: dict[str, list[int]],
self, scheduled_encoder_inputs: dict[str, list[int]]
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
mm_hashes: list[str] = []
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
@@ -85,20 +81,16 @@ class EncoderRunner:
encoder_outputs: list[torch.Tensor] = []
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=False,
mm_kwargs, device=self.device, pin_memory=False
):
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=num_items,
curr_group_outputs, expected_num_items=num_items
)
encoder_outputs.extend(curr_group_outputs)
# Cache the encoder outputs by mm_hash
for mm_hash, output in zip(mm_hashes, encoder_outputs):
self.encoder_cache[mm_hash] = output
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
return encoder_outputs
def gather_mm_embeddings(
@@ -115,9 +107,7 @@ class EncoderRunner:
if all_decode:
# All decode requests, so no need to gather any embeddings.
return [], torch.zeros(
total_num_scheduled_tokens,
dtype=torch.bool,
device=self.device,
total_num_scheduled_tokens, dtype=torch.bool, device=self.device
)
query_start = computed_prefill_lens.tolist()
@@ -125,10 +115,7 @@ class EncoderRunner:
mm_embeds: list[torch.Tensor] = []
is_mm_embed = torch.zeros(
total_num_scheduled_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=True,
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
)
for i, req_id in enumerate(req_ids):
if not is_prefilling[i]:
@@ -189,9 +176,7 @@ class EncoderRunner:
is_mm_embed: torch.Tensor,
) -> torch.Tensor:
x = model.embed_input_ids(
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
)
# Copy to the pre-allocated buffer for CUDA graphs.
self.inputs_embeds[: x.shape[0]] = x

View File

@@ -51,10 +51,7 @@ class MRopeState:
mm_features: list,
) -> None:
prefill_mrope_positions, prefill_mrope_delta = (
mrope_model.get_mrope_input_positions(
prefill_token_ids,
mm_features,
)
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
)
for i in range(3):
pos = prefill_mrope_positions[i].tolist()

View File

@@ -339,10 +339,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
gc.collect()
def reset_mm_cache(self) -> None:
self.encoder_runner.reset_mm_cache()
if self.supports_mm_inputs:
self.encoder_runner.reset_mm_cache()
def reset_encoder_cache(self) -> None:
self.encoder_runner.reset_encoder_cache()
if self.supports_mm_inputs:
self.encoder_runner.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet.
@@ -402,10 +404,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
finished_req_ids = scheduler_output.finished_req_ids
if scheduler_output.preempted_req_ids:
finished_req_ids = finished_req_ids.union(
scheduler_output.preempted_req_ids
)
preempted_req_ids = scheduler_output.preempted_req_ids
if preempted_req_ids:
finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids:
self.req_states.remove_request(req_id)
if self.supports_mm_inputs:
@@ -477,28 +478,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
num_tokens_per_req = scheduler_output.num_scheduled_tokens
num_reqs = len(num_tokens_per_req)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(
scheduler_output.num_scheduled_tokens.keys(),
key=lambda k: scheduler_output.num_scheduled_tokens[k],
)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
)
req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get) # type: ignore[arg-type]
numtoks_iter = map(num_tokens_per_req.get, req_ids)
num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32)
idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
# Get the number of draft tokens for each request.
@@ -889,8 +883,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def sample_tokens(
self,
grammar_output: GrammarOutput | None,
self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput:
assert self.execute_model_state is not None
hidden_states, input_batch, kv_connector_output = self.execute_model_state

View File

@@ -13,11 +13,7 @@ MAX_NUM_STOP_TOKEN_IDS = 128
class LogitBiasState:
def __init__(
self,
max_num_reqs: int,
device: torch.device,
):
def __init__(self, max_num_reqs: int, device: torch.device):
self.max_num_reqs = max_num_reqs
# Allowed token IDs.
@@ -54,10 +50,7 @@ class LogitBiasState:
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
def add_request(
self,
req_idx: int,
prompt_len: int,
sampling_params: SamplingParams,
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
) -> None:
# Using any logit bias.
use_logit_bias = False

View File

@@ -73,19 +73,12 @@ def _ranks_kernel(
def compute_token_logprobs(
logits: torch.Tensor,
token_ids: torch.Tensor,
logits: torch.Tensor, token_ids: torch.Tensor
) -> torch.Tensor:
batch_size = logits.shape[0]
vocab_size = logits.shape[1]
batch_size, vocab_size = logits.shape
token_ids = token_ids.to(torch.int64)
num_logprobs = token_ids.shape[1]
logprobs = torch.empty(
batch_size,
num_logprobs,
dtype=torch.float32,
device=logits.device,
)
logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
_topk_log_softmax_kernel[(batch_size,)](
logprobs,
logits,
@@ -107,23 +100,16 @@ def compute_topk_logprobs(
) -> LogprobsTensors:
assert num_logprobs >= 0
batch_size, vocab_size = logits.shape
if num_logprobs == 0:
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
else:
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
if num_logprobs > 0:
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
)
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = compute_token_logprobs(logits, logprob_token_ids)
token_ranks = torch.empty(
batch_size,
dtype=torch.int64,
device=logits.device,
)
token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
_ranks_kernel[(batch_size,)](
token_ranks,
logits,

View File

@@ -42,9 +42,7 @@ def _min_p_kernel(
def apply_min_p(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
min_p: torch.Tensor,
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024

View File

@@ -23,11 +23,9 @@ class PromptLogprobsWorker:
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
# For now, only support prompt logprobs for the prompt tokens (not top-k).
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
if uses_prompt_logprobs:
self.uses_prompt_logprobs[req_idx] = True
self.in_progress_prompt_logprobs[req_id] = []
else:
self.uses_prompt_logprobs[req_idx] = False
def remove_request(self, req_id: str) -> None:
self.in_progress_prompt_logprobs.pop(req_id, None)

View File

@@ -26,7 +26,7 @@ class Sampler:
device: torch.device,
logprobs_mode: LogprobsMode = "raw_logprobs",
):
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
self.logprobs_mode = logprobs_mode
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
@@ -36,10 +36,7 @@ class Sampler:
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
def add_request(
self,
req_idx: int,
prompt_len: int,
sampling_params: SamplingParams,
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
) -> None:
self.sampling_states.add_request(req_idx, sampling_params)
self.penalties_state.add_request(req_idx, sampling_params)
@@ -74,11 +71,8 @@ class Sampler:
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
logits = (
processed_logits
if self.logprobs_mode == "processed_logprobs"
else logits
)
if self.logprobs_mode == "processed_logprobs":
logits = processed_logits
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
logprobs_tensors = compute_topk_logprobs(

View File

@@ -35,22 +35,19 @@ class SamplingStates:
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = sampling_params.top_k
if top_k <= 0 or top_k > self.vocab_size:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = sampling_params.seed
if seed is None:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = sampling_params.logprobs
if num_logprobs is None:
num_logprobs = NO_LOGPROBS
self.num_logprobs[req_idx] = num_logprobs

View File

@@ -5,10 +5,7 @@ import torch
from vllm.config import VllmConfig
def init_speculator(
vllm_config: VllmConfig,
device: torch.device,
):
def init_speculator(vllm_config: VllmConfig, device: torch.device):
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
if speculative_config.use_eagle():

View File

@@ -54,26 +54,15 @@ class EagleSpeculator:
device=device,
)
self.hidden_states = torch.zeros(
self.max_num_tokens,
self.hidden_size,
dtype=self.dtype,
device=device,
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
)
self.idx_mapping = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=device,
self.max_num_reqs, dtype=torch.int32, device=device
)
self.temperature = torch.zeros(
self.max_num_reqs,
dtype=torch.float32,
device=device,
)
self.seeds = torch.zeros(
self.max_num_reqs,
dtype=torch.int64,
device=device,
self.max_num_reqs, dtype=torch.float32, device=device
)
self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
self.draft_tokens = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,

View File

@@ -19,11 +19,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
class EagleCudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
@@ -35,16 +31,10 @@ class EagleCudaGraphManager:
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.compilation_config.cudagraph_mode
if cudagraph_mode == CUDAGraphMode.FULL:
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
self.cudagraph_mode = cudagraph_mode
self.cudagraph_mode = self.compilation_config.cudagraph_mode
if self.cudagraph_mode == CUDAGraphMode.FULL:
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,

View File

@@ -10,12 +10,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch
class StructuredOutputsWorker:
def __init__(
self,
max_num_logits: int,
vocab_size: int,
device: torch.device,
):
def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
self.logits_indices = torch.zeros(
max_num_logits, dtype=torch.int32, device=device
)