[ModelRunner V2] Misc minor simplifications and optimizations (#33467)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Nick Hill authored and GitHub committed on 2026-02-01 14:17:14 -08:00
parent 0b225fb7b2
commit e535d90deb
21 changed files with 86 additions and 220 deletions

View File

@@ -4,11 +4,7 @@
 import numpy as np
 import torch
-from vllm.v1.outputs import (
-    AsyncModelRunnerOutput,
-    LogprobsTensors,
-    ModelRunnerOutput,
-)
+from vllm.v1.outputs import AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput
 from vllm.v1.worker.gpu.sample.output import SamplerOutput

View File

@@ -32,9 +32,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
 def init_attn_backend(
-    kv_cache_config: KVCacheConfig,
-    vllm_config: VllmConfig,
-    device: torch.device,
+    kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
 ):
     attn_backends: dict[str, type[AttentionBackend]] = {}
     attn_metadata_builders: list[AttentionMetadataBuilder] = []
@@ -50,10 +48,7 @@ def init_attn_backend(
             attn_backends[layer_name] = attn_backend
         attn_metadata_builder = attn_backend.get_builder_cls()(
-            kv_cache_group_spec.kv_cache_spec,
-            layer_names,
-            vllm_config,
-            device,
+            kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
         )
         attn_metadata_builders.append(attn_metadata_builder)  # type: ignore
@@ -65,10 +60,7 @@ def init_attn_backend(
     return attn_backends, attn_metadata_builders
-def _allocate_kv_cache(
-    kv_cache_config: KVCacheConfig,
-    device: torch.device,
-):
+def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
     kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
     for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
         tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
@@ -141,12 +133,11 @@ def init_kv_cache(
 def build_slot_mappings_by_layer(
-    slot_mappings: torch.Tensor,
-    kv_cache_config: KVCacheConfig,
+    slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
 ) -> dict[str, torch.Tensor]:
     slot_mappings_by_layer: dict[str, torch.Tensor] = {}
-    for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
-        slot_mapping = slot_mappings[i]
+    kv_cache_groups = kv_cache_config.kv_cache_groups
+    for slot_mapping, kv_cache_group in zip(slot_mappings, kv_cache_groups):
         for layer_name in kv_cache_group.layer_names:
             slot_mappings_by_layer[layer_name] = slot_mapping
     return slot_mappings_by_layer
@@ -188,8 +179,7 @@ def build_attn_metadata(
         attn_metadata_builder = attn_metadata_builders[i]
         metadata = attn_metadata_builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
+            common_prefix_len=0, common_attn_metadata=common_attn_metadata
         )
         for layer_name in kv_cache_spec.layer_names:
             attn_metadata[layer_name] = metadata
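A minimal standalone sketch of the zip-based rewrite in the build_slot_mappings_by_layer hunk above, using hypothetical stand-in values rather than real vLLM objects: zip pairs each slot mapping with its KV cache group directly, replacing enumerate plus manual indexing.

# Hypothetical stand-ins for slot_mappings rows and kv_cache_groups.
slot_maps = ["slots_g0", "slots_g1"]
groups = [["layer.0", "layer.1"], ["layer.2"]]

by_layer = {}
for slot_map, layer_names in zip(slot_maps, groups):
    for name in layer_names:
        by_layer[name] = slot_map

assert by_layer == {"layer.0": "slots_g0", "layer.1": "slots_g0", "layer.2": "slots_g1"}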

View File

@@ -71,9 +71,7 @@ class BlockTables:
     def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
         # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
         return torch.tensor(
-            [t.data_ptr() for t in x],
-            dtype=torch.uint64,
-            device=self.device,
+            [t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
         )
     def append_block_ids(
@@ -96,8 +94,7 @@ class BlockTables:
         self.num_blocks.copy_to_uva()
     def gather_block_tables(
-        self,
-        idx_mapping: torch.Tensor,
+        self, idx_mapping: torch.Tensor
     ) -> tuple[torch.Tensor, ...]:
         num_reqs = idx_mapping.shape[0]
         _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Sequence
+from functools import partial
 import numpy as np
 import torch
@@ -81,10 +82,7 @@ class UvaBufferPool:
 class UvaBackedTensor:
     def __init__(
-        self,
-        size: int | Sequence[int],
-        dtype: torch.dtype,
-        max_concurrency: int = 2,
+        self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
     ):
         self.dtype = dtype
         self.max_concurrency = max_concurrency
@@ -135,25 +133,16 @@ class StagedWriteTensor:
         self._staged_write_contents: list[int | float] = []
         self._staged_write_cu_lens: list[int] = []
-        self.write_indices = UvaBufferPool(
-            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
-        )
-        self.write_starts = UvaBufferPool(
-            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
-        )
+        new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
+        self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
+        self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
         init_size = next_power_of_2(self.num_rows)
-        self.write_contents = UvaBufferPool(
-            init_size, dtype=dtype, max_concurrency=max_concurrency
-        )
-        self.write_cu_lens = UvaBufferPool(
-            self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
-        )
+        self.write_contents = new_buffer(init_size, dtype=dtype)
+        self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
     def stage_write(
-        self,
-        index: int,
-        start: int,
-        x: Iterable[int] | Iterable[float],
+        self, index: int, start: int, x: Iterable[int] | Iterable[float]
     ) -> None:
         assert index >= 0
         assert start >= 0
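A small self-contained sketch of the functools.partial pattern introduced in the StagedWriteTensor hunk above; Pool here is a toy stand-in for UvaBufferPool, not the real class. partial pre-binds the keyword shared by every allocation so each call site only spells out what actually differs.

from functools import partial

class Pool:  # toy stand-in for UvaBufferPool
    def __init__(self, size: int, dtype: str, max_concurrency: int = 2):
        self.size, self.dtype, self.max_concurrency = size, dtype, max_concurrency

max_concurrency = 4
# Pre-bind the shared keyword once; call sites pass only size and dtype.
new_buffer = partial(Pool, max_concurrency=max_concurrency)

write_indices = new_buffer(128, dtype="int32")
write_contents = new_buffer(256, dtype="float32")
assert write_indices.max_concurrency == write_contents.max_concurrency == 4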

View File

@@ -24,12 +24,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
 class CudaGraphManager:
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        uses_mrope: bool,
-        device: torch.device,
-    ):
+    def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
         self.uses_mrope = uses_mrope
@@ -41,11 +36,7 @@ class CudaGraphManager:
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
-        self.cudagraph_mode: CUDAGraphMode
-        if self.compilation_config.cudagraph_mode is None:
-            self.cudagraph_mode = CUDAGraphMode.NONE
-        else:
-            self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        self.cudagraph_mode = self.compilation_config.cudagraph_mode
         self.cudagraph_sizes = get_cudagraph_sizes(
             self.compilation_config.cudagraph_capture_sizes,
             self.max_num_reqs,

View File

@@ -13,10 +13,7 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
 def get_batch_metadata_across_dp(
-    num_tokens: int,
-    cudagraph_size: int,
-    dp_size: int,
-    dp_rank: int,
+    num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert dp_size > 1
     # Use CPU group to avoid CPU-GPU synchronization.
@@ -29,10 +26,7 @@ def get_batch_metadata_across_dp(
 def get_cudagraph_and_dp_padding(
-    num_tokens: int,
-    cudagraph_size: int | None,
-    dp_size: int,
-    dp_rank: int,
+    num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int
 ) -> tuple[bool, int, torch.Tensor | None]:
     if dp_size == 1:
         if cudagraph_size is not None:

View File

@@ -65,10 +65,10 @@ class ActiveKVConnector(KVConnector):
         if scheduler_output.preempted_req_ids:
             self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
-        assert scheduler_output.kv_connector_metadata is not None
-        self.kv_connector.bind_connector_metadata(
-            scheduler_output.kv_connector_metadata
-        )
+        kv_connector_metadata = scheduler_output.kv_connector_metadata
+        assert kv_connector_metadata is not None
+        self.kv_connector.bind_connector_metadata(kv_connector_metadata)
         # TODO: sort out KV Connectors' use of forward_context
         if is_forward_context_available():
             self.kv_connector.start_load_kv(get_forward_context())

View File

@@ -15,10 +15,7 @@ class LoraState:
         self.lora_requests: dict[str, LoRARequest] = {}
     def add_request(
-        self,
-        req_id: str,
-        req_index: int,
-        lora_request: LoRARequest | None,
+        self, req_id: str, req_index: int, lora_request: LoRARequest | None
     ) -> None:
         if lora_request is not None:
             self.lora_requests[req_id] = lora_request
@@ -41,7 +38,7 @@ class LoraState:
         active_lora_requests: set[LoRARequest] = set()
         for req_id in req_ids:
-            lora_request = self.lora_requests.get(req_id, None)
+            lora_request = self.lora_requests.get(req_id)
             if lora_request is not None:
                 active_lora_requests.add(lora_request)
         return prompt_lora_mapping, token_lora_mapping, active_lora_requests

View File

@@ -23,10 +23,7 @@ class EncoderRunner:
         self.device = device
         self.inputs_embeds = torch.zeros(
-            max_num_tokens,
-            hidden_size,
-            dtype=dtype,
-            device=device,
+            max_num_tokens, hidden_size, dtype=dtype, device=device
         )
         self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
         self.encoder_cache: dict[str, torch.Tensor] = {}
@@ -57,8 +54,7 @@ class EncoderRunner:
         self.req_id_to_mm_features.pop(req_id, None)
     def prepare_mm_inputs(
-        self,
-        scheduled_encoder_inputs: dict[str, list[int]],
+        self, scheduled_encoder_inputs: dict[str, list[int]]
     ) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
         mm_hashes: list[str] = []
         mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
@@ -85,20 +81,16 @@ class EncoderRunner:
         encoder_outputs: list[torch.Tensor] = []
         for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
-            mm_kwargs,
-            device=self.device,
-            pin_memory=False,
+            mm_kwargs, device=self.device, pin_memory=False
         ):
             curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
             sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=num_items,
+                curr_group_outputs, expected_num_items=num_items
             )
             encoder_outputs.extend(curr_group_outputs)
         # Cache the encoder outputs by mm_hash
-        for mm_hash, output in zip(mm_hashes, encoder_outputs):
-            self.encoder_cache[mm_hash] = output
+        self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
         return encoder_outputs
     def gather_mm_embeddings(
@@ -115,9 +107,7 @@ class EncoderRunner:
         if all_decode:
             # All decode requests, so no need to gather any embeddings.
             return [], torch.zeros(
-                total_num_scheduled_tokens,
-                dtype=torch.bool,
-                device=self.device,
+                total_num_scheduled_tokens, dtype=torch.bool, device=self.device
             )
         query_start = computed_prefill_lens.tolist()
@@ -125,10 +115,7 @@ class EncoderRunner:
         mm_embeds: list[torch.Tensor] = []
         is_mm_embed = torch.zeros(
-            total_num_scheduled_tokens,
-            dtype=torch.bool,
-            device="cpu",
-            pin_memory=True,
+            total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
         )
         for i, req_id in enumerate(req_ids):
             if not is_prefilling[i]:
@@ -189,9 +176,7 @@ class EncoderRunner:
         is_mm_embed: torch.Tensor,
     ) -> torch.Tensor:
         x = model.embed_input_ids(
-            input_ids,
-            multimodal_embeddings=mm_embeds,
-            is_multimodal=is_mm_embed,
+            input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
         )
         # Copy to the pre-allocated buffer for CUDA graphs.
         self.inputs_embeds[: x.shape[0]] = x

View File

@@ -51,10 +51,7 @@ class MRopeState:
         mm_features: list,
     ) -> None:
         prefill_mrope_positions, prefill_mrope_delta = (
-            mrope_model.get_mrope_input_positions(
-                prefill_token_ids,
-                mm_features,
-            )
+            mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
         )
         for i in range(3):
             pos = prefill_mrope_positions[i].tolist()

View File

@@ -339,10 +339,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         gc.collect()
     def reset_mm_cache(self) -> None:
-        self.encoder_runner.reset_mm_cache()
+        if self.supports_mm_inputs:
+            self.encoder_runner.reset_mm_cache()
     def reset_encoder_cache(self) -> None:
-        self.encoder_runner.reset_encoder_cache()
+        if self.supports_mm_inputs:
+            self.encoder_runner.reset_encoder_cache()
     def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
         # SP is not supported yet.
@@ -402,10 +404,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
         finished_req_ids = scheduler_output.finished_req_ids
-        if scheduler_output.preempted_req_ids:
-            finished_req_ids = finished_req_ids.union(
-                scheduler_output.preempted_req_ids
-            )
+        preempted_req_ids = scheduler_output.preempted_req_ids
+        if preempted_req_ids:
+            finished_req_ids = finished_req_ids.union(preempted_req_ids)
         for req_id in finished_req_ids:
             self.req_states.remove_request(req_id)
             if self.supports_mm_inputs:
@@ -477,28 +478,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
     def prepare_inputs(
-        self,
-        scheduler_output: SchedulerOutput,
-        num_tokens_after_padding: int,
+        self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
     ) -> InputBatch:
         num_tokens = scheduler_output.total_num_scheduled_tokens
         assert num_tokens > 0
-        num_reqs = len(scheduler_output.num_scheduled_tokens)
+        num_tokens_per_req = scheduler_output.num_scheduled_tokens
+        num_reqs = len(num_tokens_per_req)
         # Decode first, then prefill.
         # batch_idx -> req_id
-        req_ids = sorted(
-            scheduler_output.num_scheduled_tokens.keys(),
-            key=lambda k: scheduler_output.num_scheduled_tokens[k],
-        )
-        num_scheduled_tokens = np.array(
-            [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
-        )
+        req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get)  # type: ignore[arg-type]
+        numtoks_iter = map(num_tokens_per_req.get, req_ids)
+        num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
-        idx_mapping_list = [
-            self.req_states.req_id_to_index[req_id] for req_id in req_ids
-        ]
-        idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32)
+        idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
+        idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
         idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
         # Get the number of draft tokens for each request.
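A runnable sketch of the prepare_inputs changes above, with a made-up scheduler dict standing in for scheduler_output.num_scheduled_tokens: sorting the dict by its own .get orders decode requests (1 token) before prefills, and np.fromiter with an explicit count fills the int32 array straight from an iterator without building an intermediate Python list.

import numpy as np

num_tokens_per_req = {"req-a": 1, "req-b": 17, "req-c": 1}  # hypothetical values
num_reqs = len(num_tokens_per_req)

# Decode first (1 token), then prefill; dict.get replaces a lambda key.
req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get)

# np.fromiter(..., count=n) preallocates the array and consumes the iterator
# directly, avoiding the temporary list that np.array([...]) would create.
num_scheduled_tokens = np.fromiter(
    map(num_tokens_per_req.get, req_ids), dtype=np.int32, count=num_reqs
)
assert req_ids == ["req-a", "req-c", "req-b"]
assert num_scheduled_tokens.tolist() == [1, 1, 17]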
@@ -889,8 +883,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     @torch.inference_mode()
     def sample_tokens(
-        self,
-        grammar_output: GrammarOutput | None,
+        self, grammar_output: GrammarOutput | None
     ) -> AsyncOutput | ModelRunnerOutput:
         assert self.execute_model_state is not None
         hidden_states, input_batch, kv_connector_output = self.execute_model_state

View File

@@ -13,11 +13,7 @@ MAX_NUM_STOP_TOKEN_IDS = 128
 class LogitBiasState:
-    def __init__(
-        self,
-        max_num_reqs: int,
-        device: torch.device,
-    ):
+    def __init__(self, max_num_reqs: int, device: torch.device):
         self.max_num_reqs = max_num_reqs
         # Allowed token IDs.
@@ -54,10 +50,7 @@ class LogitBiasState:
         self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
     def add_request(
-        self,
-        req_idx: int,
-        prompt_len: int,
-        sampling_params: SamplingParams,
+        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
     ) -> None:
         # Using any logit bias.
         use_logit_bias = False

View File

@@ -73,19 +73,12 @@ def _ranks_kernel(
 def compute_token_logprobs(
-    logits: torch.Tensor,
-    token_ids: torch.Tensor,
+    logits: torch.Tensor, token_ids: torch.Tensor
 ) -> torch.Tensor:
-    batch_size = logits.shape[0]
-    vocab_size = logits.shape[1]
+    batch_size, vocab_size = logits.shape
     token_ids = token_ids.to(torch.int64)
     num_logprobs = token_ids.shape[1]
-    logprobs = torch.empty(
-        batch_size,
-        num_logprobs,
-        dtype=torch.float32,
-        device=logits.device,
-    )
+    logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
     _topk_log_softmax_kernel[(batch_size,)](
         logprobs,
         logits,
@@ -107,23 +100,16 @@ def compute_topk_logprobs(
 ) -> LogprobsTensors:
     assert num_logprobs >= 0
     batch_size, vocab_size = logits.shape
-    if num_logprobs == 0:
-        logprob_token_ids = sampled_token_ids.unsqueeze(-1)
-    else:
+    logprob_token_ids = sampled_token_ids.unsqueeze(-1)
+    if num_logprobs > 0:
         topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
-        logprob_token_ids = torch.cat(
-            (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
-        )
+        logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
     # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
     # logprobs tensor. Instead, we only compute and return the logprobs of
     # the topk + 1 tokens.
     logprobs = compute_token_logprobs(logits, logprob_token_ids)
-    token_ranks = torch.empty(
-        batch_size,
-        dtype=torch.int64,
-        device=logits.device,
-    )
+    token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
     _ranks_kernel[(batch_size,)](
         token_ranks,
         logits,
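A CPU-runnable sketch of the restructured compute_topk_logprobs flow above, with hypothetical shapes: the sampled-token column is built unconditionally and top-k columns are concatenated only when requested, and Tensor.new_empty inherits the source tensor's device so only the differing dtype needs to be spelled out.

import torch

logits = torch.randn(4, 1000)                     # hypothetical (batch, vocab) logits
sampled_token_ids = torch.randint(0, 1000, (4,))  # hypothetical sampled tokens
num_logprobs = 3

logprob_token_ids = sampled_token_ids.unsqueeze(-1)  # shape (4, 1)
if num_logprobs > 0:
    topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
    logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)

# new_empty allocates on the same device as `logits`, so no device argument.
logprobs = logits.new_empty(logprob_token_ids.shape, dtype=torch.float32)
assert logprob_token_ids.shape == (4, 1 + num_logprobs)
assert logprobs.device == logits.device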

View File

@@ -42,9 +42,7 @@ def _min_p_kernel(
 def apply_min_p(
-    logits: torch.Tensor,
-    idx_mapping: torch.Tensor,
-    min_p: torch.Tensor,
+    logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
 ) -> None:
     num_reqs, vocab_size = logits.shape
     BLOCK_SIZE = 1024

View File

@@ -23,11 +23,9 @@ class PromptLogprobsWorker:
     def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
         # For now, only support prompt logprobs for the prompt tokens (not top-k).
         uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
+        self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
         if uses_prompt_logprobs:
-            self.uses_prompt_logprobs[req_idx] = True
             self.in_progress_prompt_logprobs[req_id] = []
-        else:
-            self.uses_prompt_logprobs[req_idx] = False
     def remove_request(self, req_id: str) -> None:
         self.in_progress_prompt_logprobs.pop(req_id, None)

View File

@@ -26,7 +26,7 @@ class Sampler:
         device: torch.device,
         logprobs_mode: LogprobsMode = "raw_logprobs",
     ):
-        if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
+        if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
             raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
         self.logprobs_mode = logprobs_mode
         self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.
@@ -36,10 +36,7 @@ class Sampler:
         self.logit_bias_state = LogitBiasState(max_num_reqs, device)
     def add_request(
-        self,
-        req_idx: int,
-        prompt_len: int,
-        sampling_params: SamplingParams,
+        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
     ) -> None:
         self.sampling_states.add_request(req_idx, sampling_params)
         self.penalties_state.add_request(req_idx, sampling_params)
@@ -74,11 +71,8 @@ class Sampler:
         max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
         if max_num_logprobs != NO_LOGPROBS:
-            logits = (
-                processed_logits
-                if self.logprobs_mode == "processed_logprobs"
-                else logits
-            )
+            if self.logprobs_mode == "processed_logprobs":
+                logits = processed_logits
             expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
             cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
             logprobs_tensors = compute_topk_logprobs(

View File

@@ -35,22 +35,19 @@ class SamplingStates:
     def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
         self.temperature.np[req_idx] = sampling_params.temperature
         self.top_p.np[req_idx] = sampling_params.top_p
-        if 0 < sampling_params.top_k < self.vocab_size:
-            top_k = sampling_params.top_k
-        else:
+        top_k = sampling_params.top_k
+        if top_k <= 0 or top_k > self.vocab_size:
             top_k = self.vocab_size
         self.top_k.np[req_idx] = top_k
         self.min_p.np[req_idx] = sampling_params.min_p
-        if sampling_params.seed is not None:
-            seed = sampling_params.seed
-        else:
+        seed = sampling_params.seed
+        if seed is None:
             seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
         self.seeds.np[req_idx] = seed
-        if sampling_params.logprobs is not None:
-            num_logprobs = sampling_params.logprobs
-        else:
+        num_logprobs = sampling_params.logprobs
+        if num_logprobs is None:
             num_logprobs = NO_LOGPROBS
         self.num_logprobs[req_idx] = num_logprobs
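A short standalone sketch of the "read the value, then fall back to the default" shape that the SamplingStates hunk above switches to, using a hypothetical params object and an assumed NO_LOGPROBS sentinel (the real constant lives in vLLM):

class Params:  # hypothetical stand-in for SamplingParams
    top_k = 0        # 0 means top-k disabled
    logprobs = None  # no logprobs requested

NO_LOGPROBS = -1     # assumed sentinel value
vocab_size = 32_000
sampling_params = Params()

top_k = sampling_params.top_k
if top_k <= 0 or top_k > vocab_size:
    top_k = vocab_size  # disabled or out of range -> consider the whole vocab

num_logprobs = sampling_params.logprobs
if num_logprobs is None:
    num_logprobs = NO_LOGPROBS

assert (top_k, num_logprobs) == (32_000, -1)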

View File

@@ -5,10 +5,7 @@ import torch
 from vllm.config import VllmConfig
-def init_speculator(
-    vllm_config: VllmConfig,
-    device: torch.device,
-):
+def init_speculator(vllm_config: VllmConfig, device: torch.device):
     speculative_config = vllm_config.speculative_config
     assert speculative_config is not None
     if speculative_config.use_eagle():

View File

@@ -54,26 +54,15 @@ class EagleSpeculator:
             device=device,
         )
         self.hidden_states = torch.zeros(
-            self.max_num_tokens,
-            self.hidden_size,
-            dtype=self.dtype,
-            device=device,
+            self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
         )
         self.idx_mapping = torch.zeros(
-            self.max_num_reqs,
-            dtype=torch.int32,
-            device=device,
+            self.max_num_reqs, dtype=torch.int32, device=device
         )
         self.temperature = torch.zeros(
-            self.max_num_reqs,
-            dtype=torch.float32,
-            device=device,
-        )
-        self.seeds = torch.zeros(
-            self.max_num_reqs,
-            dtype=torch.int64,
-            device=device,
+            self.max_num_reqs, dtype=torch.float32, device=device
         )
+        self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
         self.draft_tokens = torch.zeros(
             self.max_num_reqs,
             self.num_speculative_steps,

View File

@@ -19,11 +19,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
 class EagleCudaGraphManager:
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        device: torch.device,
-    ):
+    def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
         self.device = device
@@ -35,16 +31,10 @@ class EagleCudaGraphManager:
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
-        cudagraph_mode: CUDAGraphMode
-        if self.compilation_config.cudagraph_mode is None:
-            cudagraph_mode = CUDAGraphMode.NONE
-        else:
-            cudagraph_mode = self.compilation_config.cudagraph_mode
-        if cudagraph_mode == CUDAGraphMode.FULL:
-            # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
-            cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
-        self.cudagraph_mode = cudagraph_mode
+        self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        if self.cudagraph_mode == CUDAGraphMode.FULL:
+            # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
+            self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
         self.cudagraph_sizes = get_cudagraph_sizes(
             self.compilation_config.cudagraph_capture_sizes,

View File

@@ -10,12 +10,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch
 class StructuredOutputsWorker:
-    def __init__(
-        self,
-        max_num_logits: int,
-        vocab_size: int,
-        device: torch.device,
-    ):
+    def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
         self.logits_indices = torch.zeros(
             max_num_logits, dtype=torch.int32, device=device
         )