[ModelRunner V2] Misc minor simplifications and optimizations (#33467)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -4,11 +4,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.v1.outputs import (
|
from vllm.v1.outputs import AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput
|
||||||
AsyncModelRunnerOutput,
|
|
||||||
LogprobsTensors,
|
|
||||||
ModelRunnerOutput,
|
|
||||||
)
|
|
||||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
|||||||
|
|
||||||
|
|
||||||
def init_attn_backend(
|
def init_attn_backend(
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
|
||||||
vllm_config: VllmConfig,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
):
|
||||||
attn_backends: dict[str, type[AttentionBackend]] = {}
|
attn_backends: dict[str, type[AttentionBackend]] = {}
|
||||||
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||||
@@ -50,10 +48,7 @@ def init_attn_backend(
|
|||||||
attn_backends[layer_name] = attn_backend
|
attn_backends[layer_name] = attn_backend
|
||||||
|
|
||||||
attn_metadata_builder = attn_backend.get_builder_cls()(
|
attn_metadata_builder = attn_backend.get_builder_cls()(
|
||||||
kv_cache_group_spec.kv_cache_spec,
|
kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
|
||||||
layer_names,
|
|
||||||
vllm_config,
|
|
||||||
device,
|
|
||||||
)
|
)
|
||||||
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
|
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
|
||||||
|
|
||||||
@@ -65,10 +60,7 @@ def init_attn_backend(
|
|||||||
return attn_backends, attn_metadata_builders
|
return attn_backends, attn_metadata_builders
|
||||||
|
|
||||||
|
|
||||||
def _allocate_kv_cache(
|
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
|
||||||
kv_cache_config: KVCacheConfig,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||||
@@ -141,12 +133,11 @@ def init_kv_cache(
|
|||||||
|
|
||||||
|
|
||||||
def build_slot_mappings_by_layer(
|
def build_slot_mappings_by_layer(
|
||||||
slot_mappings: torch.Tensor,
|
slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
|
||||||
kv_cache_config: KVCacheConfig,
|
|
||||||
) -> dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
|
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
|
||||||
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||||
slot_mapping = slot_mappings[i]
|
for slot_mapping, kv_cache_group in zip(slot_mappings, kv_cache_groups):
|
||||||
for layer_name in kv_cache_group.layer_names:
|
for layer_name in kv_cache_group.layer_names:
|
||||||
slot_mappings_by_layer[layer_name] = slot_mapping
|
slot_mappings_by_layer[layer_name] = slot_mapping
|
||||||
return slot_mappings_by_layer
|
return slot_mappings_by_layer
|
||||||
@@ -188,8 +179,7 @@ def build_attn_metadata(
|
|||||||
|
|
||||||
attn_metadata_builder = attn_metadata_builders[i]
|
attn_metadata_builder = attn_metadata_builders[i]
|
||||||
metadata = attn_metadata_builder.build(
|
metadata = attn_metadata_builder.build(
|
||||||
common_prefix_len=0,
|
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||||
common_attn_metadata=common_attn_metadata,
|
|
||||||
)
|
)
|
||||||
for layer_name in kv_cache_spec.layer_names:
|
for layer_name in kv_cache_spec.layer_names:
|
||||||
attn_metadata[layer_name] = metadata
|
attn_metadata[layer_name] = metadata
|
||||||
|
|||||||
@@ -71,9 +71,7 @@ class BlockTables:
|
|||||||
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||||
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
|
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
|
||||||
return torch.tensor(
|
return torch.tensor(
|
||||||
[t.data_ptr() for t in x],
|
[t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
|
||||||
dtype=torch.uint64,
|
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def append_block_ids(
|
def append_block_ids(
|
||||||
@@ -96,8 +94,7 @@ class BlockTables:
|
|||||||
self.num_blocks.copy_to_uva()
|
self.num_blocks.copy_to_uva()
|
||||||
|
|
||||||
def gather_block_tables(
|
def gather_block_tables(
|
||||||
self,
|
self, idx_mapping: torch.Tensor
|
||||||
idx_mapping: torch.Tensor,
|
|
||||||
) -> tuple[torch.Tensor, ...]:
|
) -> tuple[torch.Tensor, ...]:
|
||||||
num_reqs = idx_mapping.shape[0]
|
num_reqs = idx_mapping.shape[0]
|
||||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Iterable, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -81,10 +82,7 @@ class UvaBufferPool:
|
|||||||
|
|
||||||
class UvaBackedTensor:
|
class UvaBackedTensor:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
|
||||||
size: int | Sequence[int],
|
|
||||||
dtype: torch.dtype,
|
|
||||||
max_concurrency: int = 2,
|
|
||||||
):
|
):
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.max_concurrency = max_concurrency
|
self.max_concurrency = max_concurrency
|
||||||
@@ -135,25 +133,16 @@ class StagedWriteTensor:
|
|||||||
self._staged_write_contents: list[int | float] = []
|
self._staged_write_contents: list[int | float] = []
|
||||||
self._staged_write_cu_lens: list[int] = []
|
self._staged_write_cu_lens: list[int] = []
|
||||||
|
|
||||||
self.write_indices = UvaBufferPool(
|
new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
|
||||||
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
|
|
||||||
)
|
self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
|
||||||
self.write_starts = UvaBufferPool(
|
self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
|
||||||
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
|
|
||||||
)
|
|
||||||
init_size = next_power_of_2(self.num_rows)
|
init_size = next_power_of_2(self.num_rows)
|
||||||
self.write_contents = UvaBufferPool(
|
self.write_contents = new_buffer(init_size, dtype=dtype)
|
||||||
init_size, dtype=dtype, max_concurrency=max_concurrency
|
self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
|
||||||
)
|
|
||||||
self.write_cu_lens = UvaBufferPool(
|
|
||||||
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
|
|
||||||
)
|
|
||||||
|
|
||||||
def stage_write(
|
def stage_write(
|
||||||
self,
|
self, index: int, start: int, x: Iterable[int] | Iterable[float]
|
||||||
index: int,
|
|
||||||
start: int,
|
|
||||||
x: Iterable[int] | Iterable[float],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
assert index >= 0
|
assert index >= 0
|
||||||
assert start >= 0
|
assert start >= 0
|
||||||
|
|||||||
@@ -24,12 +24,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
|
|||||||
|
|
||||||
|
|
||||||
class CudaGraphManager:
|
class CudaGraphManager:
|
||||||
def __init__(
|
def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
|
||||||
self,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
uses_mrope: bool,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
self.uses_mrope = uses_mrope
|
self.uses_mrope = uses_mrope
|
||||||
@@ -41,10 +36,6 @@ class CudaGraphManager:
|
|||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
assert self.compilation_config is not None
|
assert self.compilation_config is not None
|
||||||
self.cudagraph_mode: CUDAGraphMode
|
|
||||||
if self.compilation_config.cudagraph_mode is None:
|
|
||||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
|
||||||
else:
|
|
||||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||||
self.cudagraph_sizes = get_cudagraph_sizes(
|
self.cudagraph_sizes = get_cudagraph_sizes(
|
||||||
self.compilation_config.cudagraph_capture_sizes,
|
self.compilation_config.cudagraph_capture_sizes,
|
||||||
|
|||||||
@@ -13,10 +13,7 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
|
|||||||
|
|
||||||
|
|
||||||
def get_batch_metadata_across_dp(
|
def get_batch_metadata_across_dp(
|
||||||
num_tokens: int,
|
num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int
|
||||||
cudagraph_size: int,
|
|
||||||
dp_size: int,
|
|
||||||
dp_rank: int,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert dp_size > 1
|
assert dp_size > 1
|
||||||
# Use CPU group to avoid CPU-GPU synchronization.
|
# Use CPU group to avoid CPU-GPU synchronization.
|
||||||
@@ -29,10 +26,7 @@ def get_batch_metadata_across_dp(
|
|||||||
|
|
||||||
|
|
||||||
def get_cudagraph_and_dp_padding(
|
def get_cudagraph_and_dp_padding(
|
||||||
num_tokens: int,
|
num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int
|
||||||
cudagraph_size: int | None,
|
|
||||||
dp_size: int,
|
|
||||||
dp_rank: int,
|
|
||||||
) -> tuple[bool, int, torch.Tensor | None]:
|
) -> tuple[bool, int, torch.Tensor | None]:
|
||||||
if dp_size == 1:
|
if dp_size == 1:
|
||||||
if cudagraph_size is not None:
|
if cudagraph_size is not None:
|
||||||
|
|||||||
@@ -65,10 +65,10 @@ class ActiveKVConnector(KVConnector):
|
|||||||
|
|
||||||
if scheduler_output.preempted_req_ids:
|
if scheduler_output.preempted_req_ids:
|
||||||
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
|
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
|
||||||
assert scheduler_output.kv_connector_metadata is not None
|
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||||
self.kv_connector.bind_connector_metadata(
|
assert kv_connector_metadata is not None
|
||||||
scheduler_output.kv_connector_metadata
|
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
|
||||||
)
|
|
||||||
# TODO: sort out KV Connectors' use of forward_context
|
# TODO: sort out KV Connectors' use of forward_context
|
||||||
if is_forward_context_available():
|
if is_forward_context_available():
|
||||||
self.kv_connector.start_load_kv(get_forward_context())
|
self.kv_connector.start_load_kv(get_forward_context())
|
||||||
|
|||||||
@@ -15,10 +15,7 @@ class LoraState:
|
|||||||
self.lora_requests: dict[str, LoRARequest] = {}
|
self.lora_requests: dict[str, LoRARequest] = {}
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self, req_id: str, req_index: int, lora_request: LoRARequest | None
|
||||||
req_id: str,
|
|
||||||
req_index: int,
|
|
||||||
lora_request: LoRARequest | None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if lora_request is not None:
|
if lora_request is not None:
|
||||||
self.lora_requests[req_id] = lora_request
|
self.lora_requests[req_id] = lora_request
|
||||||
@@ -41,7 +38,7 @@ class LoraState:
|
|||||||
|
|
||||||
active_lora_requests: set[LoRARequest] = set()
|
active_lora_requests: set[LoRARequest] = set()
|
||||||
for req_id in req_ids:
|
for req_id in req_ids:
|
||||||
lora_request = self.lora_requests.get(req_id, None)
|
lora_request = self.lora_requests.get(req_id)
|
||||||
if lora_request is not None:
|
if lora_request is not None:
|
||||||
active_lora_requests.add(lora_request)
|
active_lora_requests.add(lora_request)
|
||||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||||
|
|||||||
@@ -23,10 +23,7 @@ class EncoderRunner:
|
|||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.inputs_embeds = torch.zeros(
|
self.inputs_embeds = torch.zeros(
|
||||||
max_num_tokens,
|
max_num_tokens, hidden_size, dtype=dtype, device=device
|
||||||
hidden_size,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
||||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||||
@@ -57,8 +54,7 @@ class EncoderRunner:
|
|||||||
self.req_id_to_mm_features.pop(req_id, None)
|
self.req_id_to_mm_features.pop(req_id, None)
|
||||||
|
|
||||||
def prepare_mm_inputs(
|
def prepare_mm_inputs(
|
||||||
self,
|
self, scheduled_encoder_inputs: dict[str, list[int]]
|
||||||
scheduled_encoder_inputs: dict[str, list[int]],
|
|
||||||
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
|
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
|
||||||
mm_hashes: list[str] = []
|
mm_hashes: list[str] = []
|
||||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
|
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
|
||||||
@@ -85,20 +81,16 @@ class EncoderRunner:
|
|||||||
|
|
||||||
encoder_outputs: list[torch.Tensor] = []
|
encoder_outputs: list[torch.Tensor] = []
|
||||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||||
mm_kwargs,
|
mm_kwargs, device=self.device, pin_memory=False
|
||||||
device=self.device,
|
|
||||||
pin_memory=False,
|
|
||||||
):
|
):
|
||||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||||
sanity_check_mm_encoder_outputs(
|
sanity_check_mm_encoder_outputs(
|
||||||
curr_group_outputs,
|
curr_group_outputs, expected_num_items=num_items
|
||||||
expected_num_items=num_items,
|
|
||||||
)
|
)
|
||||||
encoder_outputs.extend(curr_group_outputs)
|
encoder_outputs.extend(curr_group_outputs)
|
||||||
|
|
||||||
# Cache the encoder outputs by mm_hash
|
# Cache the encoder outputs by mm_hash
|
||||||
for mm_hash, output in zip(mm_hashes, encoder_outputs):
|
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
|
||||||
self.encoder_cache[mm_hash] = output
|
|
||||||
return encoder_outputs
|
return encoder_outputs
|
||||||
|
|
||||||
def gather_mm_embeddings(
|
def gather_mm_embeddings(
|
||||||
@@ -115,9 +107,7 @@ class EncoderRunner:
|
|||||||
if all_decode:
|
if all_decode:
|
||||||
# All decode requests, so no need to gather any embeddings.
|
# All decode requests, so no need to gather any embeddings.
|
||||||
return [], torch.zeros(
|
return [], torch.zeros(
|
||||||
total_num_scheduled_tokens,
|
total_num_scheduled_tokens, dtype=torch.bool, device=self.device
|
||||||
dtype=torch.bool,
|
|
||||||
device=self.device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
query_start = computed_prefill_lens.tolist()
|
query_start = computed_prefill_lens.tolist()
|
||||||
@@ -125,10 +115,7 @@ class EncoderRunner:
|
|||||||
|
|
||||||
mm_embeds: list[torch.Tensor] = []
|
mm_embeds: list[torch.Tensor] = []
|
||||||
is_mm_embed = torch.zeros(
|
is_mm_embed = torch.zeros(
|
||||||
total_num_scheduled_tokens,
|
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
|
||||||
dtype=torch.bool,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=True,
|
|
||||||
)
|
)
|
||||||
for i, req_id in enumerate(req_ids):
|
for i, req_id in enumerate(req_ids):
|
||||||
if not is_prefilling[i]:
|
if not is_prefilling[i]:
|
||||||
@@ -189,9 +176,7 @@ class EncoderRunner:
|
|||||||
is_mm_embed: torch.Tensor,
|
is_mm_embed: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
x = model.embed_input_ids(
|
x = model.embed_input_ids(
|
||||||
input_ids,
|
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||||
multimodal_embeddings=mm_embeds,
|
|
||||||
is_multimodal=is_mm_embed,
|
|
||||||
)
|
)
|
||||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||||
self.inputs_embeds[: x.shape[0]] = x
|
self.inputs_embeds[: x.shape[0]] = x
|
||||||
|
|||||||
@@ -51,10 +51,7 @@ class MRopeState:
|
|||||||
mm_features: list,
|
mm_features: list,
|
||||||
) -> None:
|
) -> None:
|
||||||
prefill_mrope_positions, prefill_mrope_delta = (
|
prefill_mrope_positions, prefill_mrope_delta = (
|
||||||
mrope_model.get_mrope_input_positions(
|
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
|
||||||
prefill_token_ids,
|
|
||||||
mm_features,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
pos = prefill_mrope_positions[i].tolist()
|
pos = prefill_mrope_positions[i].tolist()
|
||||||
|
|||||||
@@ -339,9 +339,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
def reset_mm_cache(self) -> None:
|
def reset_mm_cache(self) -> None:
|
||||||
|
if self.supports_mm_inputs:
|
||||||
self.encoder_runner.reset_mm_cache()
|
self.encoder_runner.reset_mm_cache()
|
||||||
|
|
||||||
def reset_encoder_cache(self) -> None:
|
def reset_encoder_cache(self) -> None:
|
||||||
|
if self.supports_mm_inputs:
|
||||||
self.encoder_runner.reset_encoder_cache()
|
self.encoder_runner.reset_encoder_cache()
|
||||||
|
|
||||||
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
||||||
@@ -402,10 +404,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
|
def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||||
finished_req_ids = scheduler_output.finished_req_ids
|
finished_req_ids = scheduler_output.finished_req_ids
|
||||||
if scheduler_output.preempted_req_ids:
|
preempted_req_ids = scheduler_output.preempted_req_ids
|
||||||
finished_req_ids = finished_req_ids.union(
|
if preempted_req_ids:
|
||||||
scheduler_output.preempted_req_ids
|
finished_req_ids = finished_req_ids.union(preempted_req_ids)
|
||||||
)
|
|
||||||
for req_id in finished_req_ids:
|
for req_id in finished_req_ids:
|
||||||
self.req_states.remove_request(req_id)
|
self.req_states.remove_request(req_id)
|
||||||
if self.supports_mm_inputs:
|
if self.supports_mm_inputs:
|
||||||
@@ -477,28 +478,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs(
|
def prepare_inputs(
|
||||||
self,
|
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
|
||||||
scheduler_output: SchedulerOutput,
|
|
||||||
num_tokens_after_padding: int,
|
|
||||||
) -> InputBatch:
|
) -> InputBatch:
|
||||||
num_tokens = scheduler_output.total_num_scheduled_tokens
|
num_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert num_tokens > 0
|
assert num_tokens > 0
|
||||||
num_reqs = len(scheduler_output.num_scheduled_tokens)
|
num_tokens_per_req = scheduler_output.num_scheduled_tokens
|
||||||
|
num_reqs = len(num_tokens_per_req)
|
||||||
|
|
||||||
# Decode first, then prefill.
|
# Decode first, then prefill.
|
||||||
# batch_idx -> req_id
|
# batch_idx -> req_id
|
||||||
req_ids = sorted(
|
req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get) # type: ignore[arg-type]
|
||||||
scheduler_output.num_scheduled_tokens.keys(),
|
numtoks_iter = map(num_tokens_per_req.get, req_ids)
|
||||||
key=lambda k: scheduler_output.num_scheduled_tokens[k],
|
num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
|
||||||
)
|
|
||||||
num_scheduled_tokens = np.array(
|
|
||||||
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
idx_mapping_list = [
|
idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
|
||||||
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
|
||||||
]
|
|
||||||
idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32)
|
|
||||||
idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
|
idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
|
||||||
|
|
||||||
# Get the number of draft tokens for each request.
|
# Get the number of draft tokens for each request.
|
||||||
@@ -889,8 +883,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def sample_tokens(
|
def sample_tokens(
|
||||||
self,
|
self, grammar_output: GrammarOutput | None
|
||||||
grammar_output: GrammarOutput | None,
|
|
||||||
) -> AsyncOutput | ModelRunnerOutput:
|
) -> AsyncOutput | ModelRunnerOutput:
|
||||||
assert self.execute_model_state is not None
|
assert self.execute_model_state is not None
|
||||||
hidden_states, input_batch, kv_connector_output = self.execute_model_state
|
hidden_states, input_batch, kv_connector_output = self.execute_model_state
|
||||||
|
|||||||
@@ -13,11 +13,7 @@ MAX_NUM_STOP_TOKEN_IDS = 128
|
|||||||
|
|
||||||
|
|
||||||
class LogitBiasState:
|
class LogitBiasState:
|
||||||
def __init__(
|
def __init__(self, max_num_reqs: int, device: torch.device):
|
||||||
self,
|
|
||||||
max_num_reqs: int,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
self.max_num_reqs = max_num_reqs
|
self.max_num_reqs = max_num_reqs
|
||||||
|
|
||||||
# Allowed token IDs.
|
# Allowed token IDs.
|
||||||
@@ -54,10 +50,7 @@ class LogitBiasState:
|
|||||||
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
|
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
|
||||||
req_idx: int,
|
|
||||||
prompt_len: int,
|
|
||||||
sampling_params: SamplingParams,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
# Using any logit bias.
|
# Using any logit bias.
|
||||||
use_logit_bias = False
|
use_logit_bias = False
|
||||||
|
|||||||
@@ -73,19 +73,12 @@ def _ranks_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def compute_token_logprobs(
|
def compute_token_logprobs(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor, token_ids: torch.Tensor
|
||||||
token_ids: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
batch_size = logits.shape[0]
|
batch_size, vocab_size = logits.shape
|
||||||
vocab_size = logits.shape[1]
|
|
||||||
token_ids = token_ids.to(torch.int64)
|
token_ids = token_ids.to(torch.int64)
|
||||||
num_logprobs = token_ids.shape[1]
|
num_logprobs = token_ids.shape[1]
|
||||||
logprobs = torch.empty(
|
logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
|
||||||
batch_size,
|
|
||||||
num_logprobs,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
_topk_log_softmax_kernel[(batch_size,)](
|
_topk_log_softmax_kernel[(batch_size,)](
|
||||||
logprobs,
|
logprobs,
|
||||||
logits,
|
logits,
|
||||||
@@ -107,23 +100,16 @@ def compute_topk_logprobs(
|
|||||||
) -> LogprobsTensors:
|
) -> LogprobsTensors:
|
||||||
assert num_logprobs >= 0
|
assert num_logprobs >= 0
|
||||||
batch_size, vocab_size = logits.shape
|
batch_size, vocab_size = logits.shape
|
||||||
if num_logprobs == 0:
|
|
||||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||||
else:
|
if num_logprobs > 0:
|
||||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||||
logprob_token_ids = torch.cat(
|
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
|
||||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||||
# the topk + 1 tokens.
|
# the topk + 1 tokens.
|
||||||
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
||||||
token_ranks = torch.empty(
|
token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
|
||||||
batch_size,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
_ranks_kernel[(batch_size,)](
|
_ranks_kernel[(batch_size,)](
|
||||||
token_ranks,
|
token_ranks,
|
||||||
logits,
|
logits,
|
||||||
|
|||||||
@@ -42,9 +42,7 @@ def _min_p_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def apply_min_p(
|
def apply_min_p(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||||
idx_mapping: torch.Tensor,
|
|
||||||
min_p: torch.Tensor,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
num_reqs, vocab_size = logits.shape
|
num_reqs, vocab_size = logits.shape
|
||||||
BLOCK_SIZE = 1024
|
BLOCK_SIZE = 1024
|
||||||
|
|||||||
@@ -23,11 +23,9 @@ class PromptLogprobsWorker:
|
|||||||
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
|
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
|
||||||
# For now, only support prompt logprobs for the prompt tokens (not top-k).
|
# For now, only support prompt logprobs for the prompt tokens (not top-k).
|
||||||
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
|
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
|
||||||
|
self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
|
||||||
if uses_prompt_logprobs:
|
if uses_prompt_logprobs:
|
||||||
self.uses_prompt_logprobs[req_idx] = True
|
|
||||||
self.in_progress_prompt_logprobs[req_id] = []
|
self.in_progress_prompt_logprobs[req_id] = []
|
||||||
else:
|
|
||||||
self.uses_prompt_logprobs[req_idx] = False
|
|
||||||
|
|
||||||
def remove_request(self, req_id: str) -> None:
|
def remove_request(self, req_id: str) -> None:
|
||||||
self.in_progress_prompt_logprobs.pop(req_id, None)
|
self.in_progress_prompt_logprobs.pop(req_id, None)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class Sampler:
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||||
):
|
):
|
||||||
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
|
if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
|
||||||
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
||||||
self.logprobs_mode = logprobs_mode
|
self.logprobs_mode = logprobs_mode
|
||||||
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
|
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
|
||||||
@@ -36,10 +36,7 @@ class Sampler:
|
|||||||
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
|
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
|
||||||
|
|
||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
|
||||||
req_idx: int,
|
|
||||||
prompt_len: int,
|
|
||||||
sampling_params: SamplingParams,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.sampling_states.add_request(req_idx, sampling_params)
|
self.sampling_states.add_request(req_idx, sampling_params)
|
||||||
self.penalties_state.add_request(req_idx, sampling_params)
|
self.penalties_state.add_request(req_idx, sampling_params)
|
||||||
@@ -74,11 +71,8 @@ class Sampler:
|
|||||||
|
|
||||||
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
|
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
|
||||||
if max_num_logprobs != NO_LOGPROBS:
|
if max_num_logprobs != NO_LOGPROBS:
|
||||||
logits = (
|
if self.logprobs_mode == "processed_logprobs":
|
||||||
processed_logits
|
logits = processed_logits
|
||||||
if self.logprobs_mode == "processed_logprobs"
|
|
||||||
else logits
|
|
||||||
)
|
|
||||||
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
|
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
|
||||||
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
|
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
|
||||||
logprobs_tensors = compute_topk_logprobs(
|
logprobs_tensors = compute_topk_logprobs(
|
||||||
|
|||||||
@@ -35,22 +35,19 @@ class SamplingStates:
|
|||||||
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
|
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
|
||||||
self.temperature.np[req_idx] = sampling_params.temperature
|
self.temperature.np[req_idx] = sampling_params.temperature
|
||||||
self.top_p.np[req_idx] = sampling_params.top_p
|
self.top_p.np[req_idx] = sampling_params.top_p
|
||||||
if 0 < sampling_params.top_k < self.vocab_size:
|
|
||||||
top_k = sampling_params.top_k
|
top_k = sampling_params.top_k
|
||||||
else:
|
if top_k <= 0 or top_k > self.vocab_size:
|
||||||
top_k = self.vocab_size
|
top_k = self.vocab_size
|
||||||
self.top_k.np[req_idx] = top_k
|
self.top_k.np[req_idx] = top_k
|
||||||
self.min_p.np[req_idx] = sampling_params.min_p
|
self.min_p.np[req_idx] = sampling_params.min_p
|
||||||
|
|
||||||
if sampling_params.seed is not None:
|
|
||||||
seed = sampling_params.seed
|
seed = sampling_params.seed
|
||||||
else:
|
if seed is None:
|
||||||
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
||||||
self.seeds.np[req_idx] = seed
|
self.seeds.np[req_idx] = seed
|
||||||
|
|
||||||
if sampling_params.logprobs is not None:
|
|
||||||
num_logprobs = sampling_params.logprobs
|
num_logprobs = sampling_params.logprobs
|
||||||
else:
|
if num_logprobs is None:
|
||||||
num_logprobs = NO_LOGPROBS
|
num_logprobs = NO_LOGPROBS
|
||||||
self.num_logprobs[req_idx] = num_logprobs
|
self.num_logprobs[req_idx] = num_logprobs
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ import torch
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
|
||||||
def init_speculator(
|
def init_speculator(vllm_config: VllmConfig, device: torch.device):
|
||||||
vllm_config: VllmConfig,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
speculative_config = vllm_config.speculative_config
|
speculative_config = vllm_config.speculative_config
|
||||||
assert speculative_config is not None
|
assert speculative_config is not None
|
||||||
if speculative_config.use_eagle():
|
if speculative_config.use_eagle():
|
||||||
|
|||||||
@@ -54,26 +54,15 @@ class EagleSpeculator:
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
self.hidden_states = torch.zeros(
|
self.hidden_states = torch.zeros(
|
||||||
self.max_num_tokens,
|
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
|
||||||
self.hidden_size,
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
self.idx_mapping = torch.zeros(
|
self.idx_mapping = torch.zeros(
|
||||||
self.max_num_reqs,
|
self.max_num_reqs, dtype=torch.int32, device=device
|
||||||
dtype=torch.int32,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
self.temperature = torch.zeros(
|
self.temperature = torch.zeros(
|
||||||
self.max_num_reqs,
|
self.max_num_reqs, dtype=torch.float32, device=device
|
||||||
dtype=torch.float32,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
self.seeds = torch.zeros(
|
|
||||||
self.max_num_reqs,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
|
self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
|
||||||
self.draft_tokens = torch.zeros(
|
self.draft_tokens = torch.zeros(
|
||||||
self.max_num_reqs,
|
self.max_num_reqs,
|
||||||
self.num_speculative_steps,
|
self.num_speculative_steps,
|
||||||
|
|||||||
@@ -19,11 +19,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
|
|||||||
|
|
||||||
|
|
||||||
class EagleCudaGraphManager:
|
class EagleCudaGraphManager:
|
||||||
def __init__(
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||||
self,
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -35,16 +31,10 @@ class EagleCudaGraphManager:
|
|||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
assert self.compilation_config is not None
|
assert self.compilation_config is not None
|
||||||
|
|
||||||
cudagraph_mode: CUDAGraphMode
|
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||||
if self.compilation_config.cudagraph_mode is None:
|
if self.cudagraph_mode == CUDAGraphMode.FULL:
|
||||||
cudagraph_mode = CUDAGraphMode.NONE
|
|
||||||
else:
|
|
||||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
|
||||||
if cudagraph_mode == CUDAGraphMode.FULL:
|
|
||||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||||
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||||
|
|
||||||
self.cudagraph_mode = cudagraph_mode
|
|
||||||
|
|
||||||
self.cudagraph_sizes = get_cudagraph_sizes(
|
self.cudagraph_sizes = get_cudagraph_sizes(
|
||||||
self.compilation_config.cudagraph_capture_sizes,
|
self.compilation_config.cudagraph_capture_sizes,
|
||||||
|
|||||||
@@ -10,12 +10,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch
|
|||||||
|
|
||||||
|
|
||||||
class StructuredOutputsWorker:
|
class StructuredOutputsWorker:
|
||||||
def __init__(
|
def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
|
||||||
self,
|
|
||||||
max_num_logits: int,
|
|
||||||
vocab_size: int,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
self.logits_indices = torch.zeros(
|
self.logits_indices = torch.zeros(
|
||||||
max_num_logits, dtype=torch.int32, device=device
|
max_num_logits, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user