diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index a3aa54634..8d04848f8 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -36,12 +36,15 @@ FILES = [ "vllm/transformers_utils", "vllm/triton_utils", "vllm/usage", + "vllm/v1/core", + "vllm/v1/engine", ] # After fixing errors resulting from changing follow_imports # from "skip" to "silent", move the following directories to FILES SEPARATE_GROUPS = [ "tests", + # v0 related "vllm/attention", "vllm/compilation", "vllm/engine", @@ -50,7 +53,16 @@ SEPARATE_GROUPS = [ "vllm/model_executor", "vllm/plugins", "vllm/worker", - "vllm/v1", + # v1 related + "vllm/v1/attention", + "vllm/v1/executor", + "vllm/v1/kv_offload", + "vllm/v1/metrics", + "vllm/v1/pool", + "vllm/v1/sample", + "vllm/v1/spec_decode", + "vllm/v1/structured_output", + "vllm/v1/worker", ] # TODO(woosuk): Include the code from Megatron and HuggingFace. diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f592a708a..1acac70c3 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -84,7 +84,9 @@ class VllmConfig: default_factory=StructuredOutputsConfig ) """Structured outputs configuration.""" - observability_config: ObservabilityConfig | None = None + observability_config: ObservabilityConfig = Field( + default_factory=ObservabilityConfig + ) """Observability configuration.""" quant_config: QuantizationConfig | None = None """Quantization configuration.""" @@ -170,10 +172,7 @@ class VllmConfig: vllm_factors.append(self.structured_outputs_config.compute_hash()) else: vllm_factors.append("None") - if self.observability_config: - vllm_factors.append(self.observability_config.compute_hash()) - else: - vllm_factors.append("None") + vllm_factors.append(self.observability_config.compute_hash()) if self.quant_config: pass # should be captured by model_config.quantization if self.compilation_config: diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 959a03428..24fcd9fe1 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -77,6 +77,7 @@ class EngineClient(ABC): lora_request: LoRARequest | None = None, trace_headers: Mapping[str, str] | None = None, priority: int = 0, + truncate_prompt_tokens: int | None = None, tokenization_kwargs: dict[str, Any] | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model.""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c794886bc..ad6fbee2e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -167,7 +167,7 @@ class Scheduler(SchedulerInterface): self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - enable_caching=self.cache_config.enable_prefix_caching, + enable_caching=bool(self.cache_config.enable_prefix_caching), use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, @@ -407,13 +407,13 @@ class Scheduler(SchedulerInterface): # Get externally-cached tokens if using a KVConnector. if self.connector is not None: - num_external_computed_tokens, load_kv_async = ( + ext_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( request, num_new_local_computed_tokens ) ) - if num_external_computed_tokens is None: + if ext_tokens is None: # The request cannot be scheduled because # the KVConnector couldn't determine # the number of matched tokens. @@ -421,6 +421,8 @@ class Scheduler(SchedulerInterface): skipped_waiting_requests.prepend_request(request) continue + num_external_computed_tokens = ext_tokens + # Total computed tokens (local + external). num_computed_tokens = ( num_new_local_computed_tokens + num_external_computed_tokens @@ -905,13 +907,13 @@ class Scheduler(SchedulerInterface): outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None - kv_connector_stats = ( + kv_connector_stats: KVConnectorStats | None = ( kv_connector_output.kv_connector_stats if kv_connector_output else None ) if kv_connector_stats and self.connector: - stats = self.connector.get_kv_connector_stats() - if stats: - kv_connector_stats = kv_connector_stats.aggregate(stats) + kv_stats = self.connector.get_kv_connector_stats() + if kv_stats: + kv_connector_stats = kv_connector_stats.aggregate(kv_stats) failed_kv_load_req_ids = None if kv_connector_output and kv_connector_output.invalid_block_ids: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 761c37504..dc61d4501 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -6,7 +6,7 @@ import socket import time from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy -from typing import Any +from typing import Any, cast import numpy as np import torch @@ -131,10 +131,9 @@ class AsyncLLM(EngineClient): self.output_processor = OutputProcessor( self.tokenizer, log_stats=self.log_stats ) - if self.observability_config.otlp_traces_endpoint is not None: - tracer = init_tracer( - "vllm.llm_engine", self.observability_config.otlp_traces_endpoint - ) + endpoint = self.observability_config.otlp_traces_endpoint + if endpoint is not None: + tracer = init_tracer("vllm.llm_engine", endpoint) self.output_processor.tracer = tracer # EngineCore (starts the engine in background process). @@ -266,7 +265,9 @@ class AsyncLLM(EngineClient): if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() - cancel_task_threadsafe(getattr(self, "output_handler", None)) + handler = getattr(self, "output_handler", None) + if handler is not None: + cancel_task_threadsafe(handler) async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return await self.engine_core.get_supported_tasks_async() @@ -314,7 +315,10 @@ class AsyncLLM(EngineClient): priority, data_parallel_rank, ) - prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") + if isinstance(prompt, str): + prompt_text = prompt + elif isinstance(prompt, Mapping): + prompt_text = cast(str | None, prompt.get("prompt")) if is_pooling or params.n == 1: await self._add_request(request, prompt_text, None, 0, queue) @@ -436,6 +440,7 @@ class AsyncLLM(EngineClient): # Note: both OutputProcessor and EngineCore handle their # own request cleanup based on finished. finished = out.finished + assert isinstance(out, RequestOutput) yield out # If the request is disconnected by the client, generate() @@ -653,7 +658,7 @@ class AsyncLLM(EngineClient): return self.tokenizer async def is_tracing_enabled(self) -> bool: - return self.observability_config.otlp_traces_endpoint is not None + return self.observability_config.otlp_traces_endpoint is not None # type: ignore async def do_log_stats(self) -> None: if self.logger_manager: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 85cab32eb..6cbd986b3 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1075,6 +1075,7 @@ class DPEngineCoreProc(EngineCoreProc): local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 + assert local_dp_rank is not None assert 0 <= local_dp_rank <= dp_rank < dp_size if vllm_config.kv_transfer_config is not None: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 7b554ca99..9b440505b 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -385,10 +385,11 @@ class BackgroundResources: with contextlib.suppress(Exception): task.cancel() - if in_loop(loop): - close_sockets_and_tasks() - elif loop and not loop.is_closed(): - loop.call_soon_threadsafe(close_sockets_and_tasks) + if loop is not None: + if in_loop(loop): + close_sockets_and_tasks() + elif not loop.is_closed(): + loop.call_soon_threadsafe(close_sockets_and_tasks) else: # Loop has been closed, try to clean up directly. del tasks @@ -1044,6 +1045,7 @@ class DPAsyncMPClient(AsyncMPClient): return assert self.stats_update_address is not None + stats_addr: str = self.stats_update_address assert len(self.engine_ranks_managed) > 0 # NOTE: running and waiting counts are all global from # the Coordinator include all global EngineCores. This @@ -1054,9 +1056,7 @@ class DPAsyncMPClient(AsyncMPClient): async def run_engine_stats_update_task(): with ( - make_zmq_socket( - self.ctx, self.stats_update_address, zmq.XSUB, linger=0 - ) as socket, + make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket, make_zmq_socket( self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0 ) as first_req_rcv_socket, diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 5f66e3689..b7a24096b 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -69,14 +69,21 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # Stop strings params = request.sampling_params assert params is not None - self.stop = stop = params.stop + stop_list: list[str] + if params.stop is None: + stop_list = [] + elif isinstance(params.stop, str): + stop_list = [params.stop] + else: + stop_list = params.stop + self.stop = stop_list self.min_tokens = params.min_tokens self.include_stop_str_in_output = params.include_stop_str_in_output # Number of chars to hold back when stop strings are to be excluded # from streamed output. - if stop and not self.include_stop_str_in_output: - self.stop_buffer_length = max(len(s) for s in stop) - 1 + if self.stop and not self.include_stop_str_in_output: + self.stop_buffer_length = max(len(s) for s in self.stop) - 1 else: self.stop_buffer_length = 0 self._last_output_text_offset: int = 0 diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 0fce34370..c2ca9579d 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -4,7 +4,7 @@ import time from collections.abc import Callable, Mapping from copy import copy -from typing import Any +from typing import Any, cast import torch.nn as nn from typing_extensions import TypeVar @@ -112,10 +112,9 @@ class LLMEngine: self.output_processor = OutputProcessor( self.tokenizer, log_stats=self.log_stats ) - if self.observability_config.otlp_traces_endpoint is not None: - tracer = init_tracer( - "vllm.llm_engine", self.observability_config.otlp_traces_endpoint - ) + endpoint = self.observability_config.otlp_traces_endpoint + if endpoint is not None: + tracer = init_tracer("vllm.llm_engine", endpoint) self.output_processor.tracer = tracer # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) @@ -259,7 +258,10 @@ class LLMEngine: trace_headers, priority, ) - prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") + if isinstance(prompt, str): + prompt_text = prompt + elif isinstance(prompt, Mapping): + prompt_text = cast(str | None, prompt.get("prompt")) n = params.n if isinstance(params, SamplingParams) else 1 @@ -285,7 +287,7 @@ class LLMEngine: # Add the request to EngineCore. self.engine_core.add_request(child_request) - def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]: + def step(self) -> list[RequestOutput | PoolingRequestOutput]: if self.should_execute_dummy_batch: self.should_execute_dummy_batch = False self.engine_core.execute_dummy_batch() diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 44e4eadce..07c8113dd 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -44,10 +44,16 @@ class RequestOutputCollector: if self.output is None or isinstance(output, Exception): self.output = output self.ready.set() - elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)): + elif isinstance(self.output, RequestOutput) and isinstance( + output, RequestOutput + ): # This ensures that request outputs with different request indexes # (if n > 1) do not override each other. self.output.add(output, aggregate=self.aggregate) + elif isinstance(self.output, PoolingRequestOutput) and isinstance( + output, PoolingRequestOutput + ): + self.output = output async def get(self) -> RequestOutput | PoolingRequestOutput: """Get operation blocks on put event.""" @@ -408,7 +414,7 @@ class OutputProcessor: within the loop below. """ - request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = [] + request_outputs: list[RequestOutput | PoolingRequestOutput] = [] reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 2a47befec..26ee10d2b 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import copy -from typing import Optional +from typing import Optional, cast from vllm.outputs import CompletionOutput from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -37,7 +37,7 @@ class ParentRequest: self.child_requests = set() self.output_aggregator = ( - [None] * sampling_params.n + [cast(CompletionOutput, None)] * sampling_params.n if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY) else [] ) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index de15677ae..c49fd1bde 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping -from typing import Any, Literal +from typing import Any, Literal, cast from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs @@ -208,9 +208,9 @@ class Processor: enc = prompt.get("encoder_prompt") dec = prompt.get("decoder_prompt") if enc is not None: - _validate_single_prompt(enc) + _validate_single_prompt(cast(dict | str, enc)) if dec is not None: - _validate_single_prompt(dec) + _validate_single_prompt(cast(dict | str, dec)) else: _validate_single_prompt(prompt) # type: ignore[arg-type] @@ -332,7 +332,7 @@ class Processor: if not mm_data: return None - mm_uuids: MultiModalUUIDDict = {} + mm_uuids: dict[str, list[str | None] | str] = {} for modality, data in mm_data.items(): n = len(data) if isinstance(data, list) else 1 mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] @@ -384,7 +384,9 @@ class Processor: # if provided. self._validate_multi_modal_uuids(prompt) if isinstance(prompt, dict): - mm_uuids = prompt.get("multi_modal_uuids") + mm_uuids = cast( + MultiModalUUIDDict | None, prompt.get("multi_modal_uuids") + ) else: mm_uuids = None @@ -410,20 +412,13 @@ class Processor: encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs) - # Mypy does not always properly infer the types of some elements of - # discriminated unions of TypedDicts, because of how it handles - # inheritance of TypedDict. If we explicitly extract the items we want - # we can avoid type errors from using `dict.get` later in the method. - prompt_token_ids = ( - decoder_inputs["prompt_token_ids"] - if decoder_inputs["type"] != "embeds" - else None - ) - prompt_embeds = ( - decoder_inputs["prompt_embeds"] - if decoder_inputs["type"] == "embeds" - else None - ) + # Mypy can be conservative for TypedDict unions; normalize access. + if decoder_inputs["type"] == "embeds": + prompt_token_ids = None + prompt_embeds = decoder_inputs["prompt_embeds"] + else: + prompt_token_ids = decoder_inputs["prompt_token_ids"] + prompt_embeds = None sampling_params = None pooling_params = None