[ModelRunner V2] Misc minor simplifications and optimizations (#33467)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -4,11 +4,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import (
|
||||
AsyncModelRunnerOutput,
|
||||
LogprobsTensors,
|
||||
ModelRunnerOutput,
|
||||
)
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput, LogprobsTensors, ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
|
||||
|
||||
|
||||
@@ -32,9 +32,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
|
||||
|
||||
def init_attn_backend(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
|
||||
):
|
||||
attn_backends: dict[str, type[AttentionBackend]] = {}
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
||||
@@ -50,10 +48,7 @@ def init_attn_backend(
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
attn_metadata_builder = attn_backend.get_builder_cls()(
|
||||
kv_cache_group_spec.kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
|
||||
)
|
||||
attn_metadata_builders.append(attn_metadata_builder) # type: ignore
|
||||
|
||||
@@ -65,10 +60,7 @@ def init_attn_backend(
|
||||
return attn_backends, attn_metadata_builders
|
||||
|
||||
|
||||
def _allocate_kv_cache(
|
||||
kv_cache_config: KVCacheConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
|
||||
@@ -141,12 +133,11 @@ def init_kv_cache(
|
||||
|
||||
|
||||
def build_slot_mappings_by_layer(
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig
|
||||
) -> dict[str, torch.Tensor]:
|
||||
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
|
||||
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
slot_mapping = slot_mappings[i]
|
||||
kv_cache_groups = kv_cache_config.kv_cache_groups
|
||||
for slot_mapping, kv_cache_group in zip(slot_mappings, kv_cache_groups):
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
slot_mappings_by_layer[layer_name] = slot_mapping
|
||||
return slot_mappings_by_layer
|
||||
@@ -188,8 +179,7 @@ def build_attn_metadata(
|
||||
|
||||
attn_metadata_builder = attn_metadata_builders[i]
|
||||
metadata = attn_metadata_builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
for layer_name in kv_cache_spec.layer_names:
|
||||
attn_metadata[layer_name] = metadata
|
||||
|
||||
@@ -71,9 +71,7 @@ class BlockTables:
|
||||
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
|
||||
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
|
||||
return torch.tensor(
|
||||
[t.data_ptr() for t in x],
|
||||
dtype=torch.uint64,
|
||||
device=self.device,
|
||||
[t.data_ptr() for t in x], dtype=torch.uint64, device=self.device
|
||||
)
|
||||
|
||||
def append_block_ids(
|
||||
@@ -96,8 +94,7 @@ class BlockTables:
|
||||
self.num_blocks.copy_to_uva()
|
||||
|
||||
def gather_block_tables(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
self, idx_mapping: torch.Tensor
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Sequence
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -81,10 +82,7 @@ class UvaBufferPool:
|
||||
|
||||
class UvaBackedTensor:
|
||||
def __init__(
|
||||
self,
|
||||
size: int | Sequence[int],
|
||||
dtype: torch.dtype,
|
||||
max_concurrency: int = 2,
|
||||
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
|
||||
):
|
||||
self.dtype = dtype
|
||||
self.max_concurrency = max_concurrency
|
||||
@@ -135,25 +133,16 @@ class StagedWriteTensor:
|
||||
self._staged_write_contents: list[int | float] = []
|
||||
self._staged_write_cu_lens: list[int] = []
|
||||
|
||||
self.write_indices = UvaBufferPool(
|
||||
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
|
||||
)
|
||||
self.write_starts = UvaBufferPool(
|
||||
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
|
||||
)
|
||||
new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
|
||||
|
||||
self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
|
||||
self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
|
||||
init_size = next_power_of_2(self.num_rows)
|
||||
self.write_contents = UvaBufferPool(
|
||||
init_size, dtype=dtype, max_concurrency=max_concurrency
|
||||
)
|
||||
self.write_cu_lens = UvaBufferPool(
|
||||
self.num_rows, dtype=torch.int32, max_concurrency=max_concurrency
|
||||
)
|
||||
self.write_contents = new_buffer(init_size, dtype=dtype)
|
||||
self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
|
||||
|
||||
def stage_write(
|
||||
self,
|
||||
index: int,
|
||||
start: int,
|
||||
x: Iterable[int] | Iterable[float],
|
||||
self, index: int, start: int, x: Iterable[int] | Iterable[float]
|
||||
) -> None:
|
||||
assert index >= 0
|
||||
assert start >= 0
|
||||
|
||||
@@ -24,12 +24,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
uses_mrope: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.uses_mrope = uses_mrope
|
||||
@@ -41,11 +36,7 @@ class CudaGraphManager:
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
self.cudagraph_mode: CUDAGraphMode
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
|
||||
@@ -13,10 +13,7 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
|
||||
|
||||
|
||||
def get_batch_metadata_across_dp(
|
||||
num_tokens: int,
|
||||
cudagraph_size: int,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert dp_size > 1
|
||||
# Use CPU group to avoid CPU-GPU synchronization.
|
||||
@@ -29,10 +26,7 @@ def get_batch_metadata_across_dp(
|
||||
|
||||
|
||||
def get_cudagraph_and_dp_padding(
|
||||
num_tokens: int,
|
||||
cudagraph_size: int | None,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int
|
||||
) -> tuple[bool, int, torch.Tensor | None]:
|
||||
if dp_size == 1:
|
||||
if cudagraph_size is not None:
|
||||
|
||||
@@ -65,10 +65,10 @@ class ActiveKVConnector(KVConnector):
|
||||
|
||||
if scheduler_output.preempted_req_ids:
|
||||
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
|
||||
assert scheduler_output.kv_connector_metadata is not None
|
||||
self.kv_connector.bind_connector_metadata(
|
||||
scheduler_output.kv_connector_metadata
|
||||
)
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
|
||||
|
||||
# TODO: sort out KV Connectors' use of forward_context
|
||||
if is_forward_context_available():
|
||||
self.kv_connector.start_load_kv(get_forward_context())
|
||||
|
||||
@@ -15,10 +15,7 @@ class LoraState:
|
||||
self.lora_requests: dict[str, LoRARequest] = {}
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_id: str,
|
||||
req_index: int,
|
||||
lora_request: LoRARequest | None,
|
||||
self, req_id: str, req_index: int, lora_request: LoRARequest | None
|
||||
) -> None:
|
||||
if lora_request is not None:
|
||||
self.lora_requests[req_id] = lora_request
|
||||
@@ -41,7 +38,7 @@ class LoraState:
|
||||
|
||||
active_lora_requests: set[LoRARequest] = set()
|
||||
for req_id in req_ids:
|
||||
lora_request = self.lora_requests.get(req_id, None)
|
||||
lora_request = self.lora_requests.get(req_id)
|
||||
if lora_request is not None:
|
||||
active_lora_requests.add(lora_request)
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
@@ -23,10 +23,7 @@ class EncoderRunner:
|
||||
self.device = device
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
max_num_tokens,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
max_num_tokens, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
@@ -57,8 +54,7 @@ class EncoderRunner:
|
||||
self.req_id_to_mm_features.pop(req_id, None)
|
||||
|
||||
def prepare_mm_inputs(
|
||||
self,
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
self, scheduled_encoder_inputs: dict[str, list[int]]
|
||||
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
|
||||
mm_hashes: list[str] = []
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
|
||||
@@ -85,20 +81,16 @@ class EncoderRunner:
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=False,
|
||||
mm_kwargs, device=self.device, pin_memory=False
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs,
|
||||
expected_num_items=num_items,
|
||||
curr_group_outputs, expected_num_items=num_items
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
for mm_hash, output in zip(mm_hashes, encoder_outputs):
|
||||
self.encoder_cache[mm_hash] = output
|
||||
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
|
||||
return encoder_outputs
|
||||
|
||||
def gather_mm_embeddings(
|
||||
@@ -115,9 +107,7 @@ class EncoderRunner:
|
||||
if all_decode:
|
||||
# All decode requests, so no need to gather any embeddings.
|
||||
return [], torch.zeros(
|
||||
total_num_scheduled_tokens,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
total_num_scheduled_tokens, dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
query_start = computed_prefill_lens.tolist()
|
||||
@@ -125,10 +115,7 @@ class EncoderRunner:
|
||||
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
is_mm_embed = torch.zeros(
|
||||
total_num_scheduled_tokens,
|
||||
dtype=torch.bool,
|
||||
device="cpu",
|
||||
pin_memory=True,
|
||||
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
|
||||
)
|
||||
for i, req_id in enumerate(req_ids):
|
||||
if not is_prefilling[i]:
|
||||
@@ -189,9 +176,7 @@ class EncoderRunner:
|
||||
is_mm_embed: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = model.embed_input_ids(
|
||||
input_ids,
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||
)
|
||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||
self.inputs_embeds[: x.shape[0]] = x
|
||||
|
||||
@@ -51,10 +51,7 @@ class MRopeState:
|
||||
mm_features: list,
|
||||
) -> None:
|
||||
prefill_mrope_positions, prefill_mrope_delta = (
|
||||
mrope_model.get_mrope_input_positions(
|
||||
prefill_token_ids,
|
||||
mm_features,
|
||||
)
|
||||
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
|
||||
)
|
||||
for i in range(3):
|
||||
pos = prefill_mrope_positions[i].tolist()
|
||||
|
||||
@@ -339,10 +339,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
gc.collect()
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
self.encoder_runner.reset_mm_cache()
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.reset_mm_cache()
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
self.encoder_runner.reset_encoder_cache()
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.reset_encoder_cache()
|
||||
|
||||
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
||||
# SP is not supported yet.
|
||||
@@ -402,10 +404,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||
finished_req_ids = scheduler_output.finished_req_ids
|
||||
if scheduler_output.preempted_req_ids:
|
||||
finished_req_ids = finished_req_ids.union(
|
||||
scheduler_output.preempted_req_ids
|
||||
)
|
||||
preempted_req_ids = scheduler_output.preempted_req_ids
|
||||
if preempted_req_ids:
|
||||
finished_req_ids = finished_req_ids.union(preempted_req_ids)
|
||||
for req_id in finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
if self.supports_mm_inputs:
|
||||
@@ -477,28 +478,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
num_tokens_after_padding: int,
|
||||
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
|
||||
) -> InputBatch:
|
||||
num_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert num_tokens > 0
|
||||
num_reqs = len(scheduler_output.num_scheduled_tokens)
|
||||
num_tokens_per_req = scheduler_output.num_scheduled_tokens
|
||||
num_reqs = len(num_tokens_per_req)
|
||||
|
||||
# Decode first, then prefill.
|
||||
# batch_idx -> req_id
|
||||
req_ids = sorted(
|
||||
scheduler_output.num_scheduled_tokens.keys(),
|
||||
key=lambda k: scheduler_output.num_scheduled_tokens[k],
|
||||
)
|
||||
num_scheduled_tokens = np.array(
|
||||
[scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
|
||||
)
|
||||
req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get) # type: ignore[arg-type]
|
||||
numtoks_iter = map(num_tokens_per_req.get, req_ids)
|
||||
num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
|
||||
|
||||
idx_mapping_list = [
|
||||
self.req_states.req_id_to_index[req_id] for req_id in req_ids
|
||||
]
|
||||
idx_mapping_np = np.array(idx_mapping_list, dtype=np.int32)
|
||||
idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
|
||||
idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
|
||||
idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
|
||||
|
||||
# Get the number of draft tokens for each request.
|
||||
@@ -889,8 +883,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample_tokens(
|
||||
self,
|
||||
grammar_output: GrammarOutput | None,
|
||||
self, grammar_output: GrammarOutput | None
|
||||
) -> AsyncOutput | ModelRunnerOutput:
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, input_batch, kv_connector_output = self.execute_model_state
|
||||
|
||||
@@ -13,11 +13,7 @@ MAX_NUM_STOP_TOKEN_IDS = 128
|
||||
|
||||
|
||||
class LogitBiasState:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
device: torch.device,
|
||||
):
|
||||
def __init__(self, max_num_reqs: int, device: torch.device):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
|
||||
# Allowed token IDs.
|
||||
@@ -54,10 +50,7 @@ class LogitBiasState:
|
||||
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_idx: int,
|
||||
prompt_len: int,
|
||||
sampling_params: SamplingParams,
|
||||
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
|
||||
) -> None:
|
||||
# Using any logit bias.
|
||||
use_logit_bias = False
|
||||
|
||||
@@ -73,19 +73,12 @@ def _ranks_kernel(
|
||||
|
||||
|
||||
def compute_token_logprobs(
|
||||
logits: torch.Tensor,
|
||||
token_ids: torch.Tensor,
|
||||
logits: torch.Tensor, token_ids: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
batch_size = logits.shape[0]
|
||||
vocab_size = logits.shape[1]
|
||||
batch_size, vocab_size = logits.shape
|
||||
token_ids = token_ids.to(torch.int64)
|
||||
num_logprobs = token_ids.shape[1]
|
||||
logprobs = torch.empty(
|
||||
batch_size,
|
||||
num_logprobs,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
|
||||
_topk_log_softmax_kernel[(batch_size,)](
|
||||
logprobs,
|
||||
logits,
|
||||
@@ -107,23 +100,16 @@ def compute_topk_logprobs(
|
||||
) -> LogprobsTensors:
|
||||
assert num_logprobs >= 0
|
||||
batch_size, vocab_size = logits.shape
|
||||
if num_logprobs == 0:
|
||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||
else:
|
||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||
if num_logprobs > 0:
|
||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||
logprob_token_ids = torch.cat(
|
||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
|
||||
)
|
||||
logprob_token_ids = torch.cat((logprob_token_ids, topk_indices), dim=1)
|
||||
|
||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||
# the topk + 1 tokens.
|
||||
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
||||
token_ranks = torch.empty(
|
||||
batch_size,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
|
||||
_ranks_kernel[(batch_size,)](
|
||||
token_ranks,
|
||||
logits,
|
||||
|
||||
@@ -42,9 +42,7 @@ def _min_p_kernel(
|
||||
|
||||
|
||||
def apply_min_p(
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
min_p: torch.Tensor,
|
||||
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
|
||||
) -> None:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
|
||||
@@ -23,11 +23,9 @@ class PromptLogprobsWorker:
|
||||
def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
|
||||
# For now, only support prompt logprobs for the prompt tokens (not top-k).
|
||||
uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
|
||||
self.uses_prompt_logprobs[req_idx] = uses_prompt_logprobs
|
||||
if uses_prompt_logprobs:
|
||||
self.uses_prompt_logprobs[req_idx] = True
|
||||
self.in_progress_prompt_logprobs[req_id] = []
|
||||
else:
|
||||
self.uses_prompt_logprobs[req_idx] = False
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.in_progress_prompt_logprobs.pop(req_id, None)
|
||||
|
||||
@@ -26,7 +26,7 @@ class Sampler:
|
||||
device: torch.device,
|
||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||
):
|
||||
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
|
||||
if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
|
||||
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
||||
self.logprobs_mode = logprobs_mode
|
||||
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
|
||||
@@ -36,10 +36,7 @@ class Sampler:
|
||||
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_idx: int,
|
||||
prompt_len: int,
|
||||
sampling_params: SamplingParams,
|
||||
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
|
||||
) -> None:
|
||||
self.sampling_states.add_request(req_idx, sampling_params)
|
||||
self.penalties_state.add_request(req_idx, sampling_params)
|
||||
@@ -74,11 +71,8 @@ class Sampler:
|
||||
|
||||
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
|
||||
if max_num_logprobs != NO_LOGPROBS:
|
||||
logits = (
|
||||
processed_logits
|
||||
if self.logprobs_mode == "processed_logprobs"
|
||||
else logits
|
||||
)
|
||||
if self.logprobs_mode == "processed_logprobs":
|
||||
logits = processed_logits
|
||||
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
|
||||
cu_num_logits = cu_num_logits_np.tolist() if expanded_logits else None
|
||||
logprobs_tensors = compute_topk_logprobs(
|
||||
|
||||
@@ -35,22 +35,19 @@ class SamplingStates:
|
||||
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
|
||||
self.temperature.np[req_idx] = sampling_params.temperature
|
||||
self.top_p.np[req_idx] = sampling_params.top_p
|
||||
if 0 < sampling_params.top_k < self.vocab_size:
|
||||
top_k = sampling_params.top_k
|
||||
else:
|
||||
top_k = sampling_params.top_k
|
||||
if top_k <= 0 or top_k > self.vocab_size:
|
||||
top_k = self.vocab_size
|
||||
self.top_k.np[req_idx] = top_k
|
||||
self.min_p.np[req_idx] = sampling_params.min_p
|
||||
|
||||
if sampling_params.seed is not None:
|
||||
seed = sampling_params.seed
|
||||
else:
|
||||
seed = sampling_params.seed
|
||||
if seed is None:
|
||||
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
|
||||
self.seeds.np[req_idx] = seed
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
num_logprobs = sampling_params.logprobs
|
||||
else:
|
||||
num_logprobs = sampling_params.logprobs
|
||||
if num_logprobs is None:
|
||||
num_logprobs = NO_LOGPROBS
|
||||
self.num_logprobs[req_idx] = num_logprobs
|
||||
|
||||
|
||||
@@ -5,10 +5,7 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
def init_speculator(
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
def init_speculator(vllm_config: VllmConfig, device: torch.device):
|
||||
speculative_config = vllm_config.speculative_config
|
||||
assert speculative_config is not None
|
||||
if speculative_config.use_eagle():
|
||||
|
||||
@@ -54,26 +54,15 @@ class EagleSpeculator:
|
||||
device=device,
|
||||
)
|
||||
self.hidden_states = torch.zeros(
|
||||
self.max_num_tokens,
|
||||
self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
|
||||
)
|
||||
self.idx_mapping = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
self.temperature = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.seeds = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
self.max_num_reqs, dtype=torch.float32, device=device
|
||||
)
|
||||
self.seeds = torch.zeros(self.max_num_reqs, dtype=torch.int64, device=device)
|
||||
self.draft_tokens = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
|
||||
@@ -19,11 +19,7 @@ from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
|
||||
|
||||
class EagleCudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device = device
|
||||
@@ -35,16 +31,10 @@ class EagleCudaGraphManager:
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
cudagraph_mode: CUDAGraphMode
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
if cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
|
||||
self.cudagraph_mode = cudagraph_mode
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
if self.cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
|
||||
self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
|
||||
@@ -10,12 +10,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
|
||||
|
||||
class StructuredOutputsWorker:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_logits: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
):
|
||||
def __init__(self, max_num_logits: int, vocab_size: int, device: torch.device):
|
||||
self.logits_indices = torch.zeros(
|
||||
max_num_logits, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user