[Misc] Replace Optional[X] with X | None syntax (#33332)
Signed-off-by: carlory <baofa.fan@daocloud.io> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -27,7 +27,7 @@ class BeamSearchSequence:
|
||||
text: str | None = None
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = None
|
||||
multi_modal_data: Optional["MultiModalDataDict"] = None
|
||||
multi_modal_data: "MultiModalDataDict | None" = None
|
||||
mm_processor_kwargs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import (
|
||||
KVConnectorBase,
|
||||
@@ -44,7 +44,7 @@ class KVConnectorFactory:
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
) -> KVConnectorBase:
|
||||
kv_transfer_config = config.kv_transfer_config
|
||||
if kv_transfer_config is None:
|
||||
|
||||
@@ -41,7 +41,7 @@ The class provides the following primitives:
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
|
||||
@@ -161,7 +161,7 @@ class KVConnectorBase_V1(ABC):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
logger.warning(
|
||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||
@@ -383,13 +383,13 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
|
||||
def get_kv_connector_stats(self) -> "KVConnectorStats | None":
|
||||
"""
|
||||
Get the KV connector stats collected during the last interval.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> Optional["KVConnectorKVEvents"]:
|
||||
def get_kv_connector_kv_cache_events(self) -> "KVConnectorKVEvents | None":
|
||||
"""
|
||||
Get the KV connector kv cache events collected during the last interval.
|
||||
This function should be called by the model runner every time after the
|
||||
@@ -558,7 +558,7 @@ class KVConnectorBase_V1(ABC):
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> Optional["KVConnectorStats"]:
|
||||
) -> "KVConnectorStats | None":
|
||||
"""
|
||||
KVConnectorStats resolution method. This method allows dynamically
|
||||
registered connectors to return their own KVConnectorStats object,
|
||||
@@ -584,7 +584,7 @@ class KVConnectorBase_V1(ABC):
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> Optional["KVConnectorPromMetrics"]:
|
||||
) -> "KVConnectorPromMetrics | None":
|
||||
"""
|
||||
Create a KVConnectorPromMetrics subclass which should register
|
||||
per-connector Prometheus metrics and implement observe() to
|
||||
|
||||
@@ -32,7 +32,7 @@ Usage:
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -84,7 +84,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
@@ -91,7 +91,7 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from lmcache import utils
|
||||
@@ -274,7 +274,7 @@ class ReqMeta:
|
||||
load_spec: LoadSpec | None = None,
|
||||
discard_partial_chunks: bool = True,
|
||||
save_decode_cache: bool = False,
|
||||
) -> Optional["ReqMeta"]:
|
||||
) -> "ReqMeta | None":
|
||||
"""Create the request metadata from a request tracker.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import enum
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
@@ -385,7 +385,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
@@ -595,7 +595,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
self.worker_adapter.shutdown()
|
||||
return None
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
|
||||
def get_kv_connector_stats(self) -> "KVConnectorStats | None":
|
||||
"""
|
||||
Get the KV connector stats collected during the last interval.
|
||||
"""
|
||||
@@ -810,7 +810,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls, data: dict[str, Any] | None = None
|
||||
) -> Optional["KVConnectorStats"]:
|
||||
) -> "KVConnectorStats | None":
|
||||
"""
|
||||
KVConnectorStats resolution method. This method allows dynamically
|
||||
registered connectors to return their own KVConnectorStats object,
|
||||
@@ -825,7 +825,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> Optional["KVConnectorPromMetrics"]:
|
||||
) -> "KVConnectorPromMetrics | None":
|
||||
"""
|
||||
Create a KVConnectorPromMetrics subclass which should register
|
||||
per-connector Prometheus metrics and implement observe() to
|
||||
|
||||
@@ -6,7 +6,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
@@ -115,7 +115,7 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
@@ -101,7 +101,7 @@ class MoRIIOAgentMetadata(
|
||||
class RoleManager:
|
||||
"""Manages role state across the connector."""
|
||||
|
||||
_instance: Optional["RoleManager"] = None
|
||||
_instance: "RoleManager | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
||||
@@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgpack
|
||||
import msgspec
|
||||
@@ -90,7 +90,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role)
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
@@ -333,7 +333,7 @@ class MoRIIOConnectorScheduler:
|
||||
request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int,
|
||||
connector_worker: Optional["MoRIIOConnectorWorker"] = None,
|
||||
connector_worker: "MoRIIOConnectorWorker | None" = None,
|
||||
):
|
||||
params = request.kv_transfer_params
|
||||
if not params:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from weakref import ref as weakref_ref
|
||||
|
||||
import msgpack
|
||||
@@ -340,7 +340,7 @@ class MoRIIOWrapper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moriio_engine: Optional["IOEngine"] = None,
|
||||
moriio_engine: "IOEngine | None" = None,
|
||||
tp_rank: int = 0,
|
||||
dp_rank: int = 0,
|
||||
):
|
||||
|
||||
@@ -14,7 +14,7 @@ from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
@@ -302,7 +302,7 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(vllm_config, role, kv_cache_config)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -76,7 +76,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
kv_cache_config: "KVCacheConfig | None" = None,
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
@@ -49,7 +49,7 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
|
||||
|
||||
|
||||
def ensure_kv_transfer_initialized(
|
||||
vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
|
||||
vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None" = None
|
||||
) -> None:
|
||||
"""
|
||||
Initialize KV cache transfer parallel group.
|
||||
|
||||
@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -106,7 +106,7 @@ def _get_unique_name(name: str) -> str:
|
||||
return newname
|
||||
|
||||
|
||||
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
|
||||
_groups: dict[str, Callable[[], "GroupCoordinator | None"]] = {}
|
||||
|
||||
|
||||
def _register_group(group: "GroupCoordinator") -> None:
|
||||
@@ -784,7 +784,7 @@ class GroupCoordinator:
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int | None = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
all_gather_group: "GroupCoordinator | None" = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
"""Send the input tensor dictionary.
|
||||
@@ -871,7 +871,7 @@ class GroupCoordinator:
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
src: int | None = None,
|
||||
all_gather_group: Optional["GroupCoordinator"] = None,
|
||||
all_gather_group: "GroupCoordinator | None" = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> dict[str, torch.Tensor | Any] | None:
|
||||
"""Recv the input tensor dictionary.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Pydantic models for Anthropic API protocol"""
|
||||
|
||||
import time
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
@@ -135,7 +135,7 @@ class AnthropicStreamEvent(BaseModel):
|
||||
"ping",
|
||||
"error",
|
||||
]
|
||||
message: Optional["AnthropicMessagesResponse"] = None
|
||||
message: "AnthropicMessagesResponse | None" = None
|
||||
delta: AnthropicDelta | None = None
|
||||
content_block: AnthropicContentBlock | None = None
|
||||
index: int | None = None
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.types
|
||||
@@ -126,7 +125,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
|
||||
|
||||
@classmethod
|
||||
def pack(
|
||||
cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
|
||||
cls, loras: GenericSequence["LoRALayerWeights | None"]
|
||||
) -> "PackedLoRALayerWeights":
|
||||
"""Pack a list of LoRAs into a single LoRA.
|
||||
|
||||
@@ -155,7 +154,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
|
||||
@classmethod
|
||||
def pack_moe(
|
||||
cls,
|
||||
loras: GenericSequence[Optional["LoRALayerWeights"]],
|
||||
loras: GenericSequence["LoRALayerWeights | None"],
|
||||
module_name: str,
|
||||
is_non_gated_moe: bool = False,
|
||||
) -> "PackedLoRALayerWeights":
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
|
||||
@@ -131,7 +131,7 @@ def replace_submodule(
|
||||
|
||||
|
||||
def parse_fine_tuned_lora_name(
|
||||
name: str, weights_mapper: Optional["WeightsMapper"] = None
|
||||
name: str, weights_mapper: "WeightsMapper | None" = None
|
||||
) -> tuple[str, bool]:
|
||||
"""Parse the name of lora weights.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -284,7 +284,7 @@ class FusedMoEQuantConfig:
|
||||
return self._w1.bias
|
||||
|
||||
@property
|
||||
def w1_precision(self) -> Optional["PrecisionConfig"]:
|
||||
def w1_precision(self) -> "PrecisionConfig | None":
|
||||
assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig)
|
||||
return self._w1.scale
|
||||
|
||||
@@ -306,7 +306,7 @@ class FusedMoEQuantConfig:
|
||||
return self._w2.bias
|
||||
|
||||
@property
|
||||
def w2_precision(self) -> Optional["PrecisionConfig"]:
|
||||
def w2_precision(self) -> "PrecisionConfig | None":
|
||||
assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig)
|
||||
return self._w2.scale
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -148,7 +148,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional["QuantizationMethods"]:
|
||||
) -> "QuantizationMethods | None":
|
||||
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
|
||||
@@ -173,7 +173,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
import torch
|
||||
from compressed_tensors.config import (
|
||||
@@ -160,7 +160,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
# collect schemes
|
||||
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
@@ -691,7 +691,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
def get_scheme(
|
||||
self, layer: torch.nn.Module, layer_name: str | None = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
) -> "CompressedTensorsScheme | None":
|
||||
"""
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -105,7 +105,7 @@ class CPUAWQConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional["QuantizationMethods"]:
|
||||
) -> "QuantizationMethods | None":
|
||||
quant_method = hf_quant_cfg.get("quant_method", "").lower()
|
||||
if current_platform.is_cpu() and (quant_method == "awq"):
|
||||
return cls.get_name()
|
||||
@@ -113,7 +113,7 @@ class CPUAWQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -52,7 +52,7 @@ class ExpertsInt8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -78,7 +78,7 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(
|
||||
prefix=prefix,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -182,7 +182,7 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
def get_xpu_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
from vllm.model_executor.layers.quantization.ipex_quant import (
|
||||
XPUFp8LinearMethod,
|
||||
XPUFp8MoEMethod,
|
||||
@@ -218,7 +218,7 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if current_platform.is_xpu():
|
||||
return self.get_xpu_quant_method(layer, prefix)
|
||||
if isinstance(layer, LinearBase):
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from collections.abc import Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
@@ -77,7 +77,7 @@ class GGUFConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_gguf(
|
||||
prefix, self.unquantized_modules, self.packed_modules_mapping
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -240,7 +240,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -456,7 +456,7 @@ class INCConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> Optional["QuantizationMethods"]:
|
||||
) -> "QuantizationMethods | None":
|
||||
"""Override the `auto-round` method to `inc`."""
|
||||
is_auto_round_format = hf_quant_cfg.get("quant_method", None) == "auto-round"
|
||||
if is_auto_round_format:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -144,7 +144,7 @@ class IPEXConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["LinearMethodBase"]:
|
||||
) -> "LinearMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.method == "awq":
|
||||
if is_layer_skipped(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fnmatch import fnmatch
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -181,7 +181,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
# handle kv-cache first so we can focus only on weight quantization thereafter
|
||||
if isinstance(layer, Attention):
|
||||
return self.KVCacheMethodCls(self)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -163,7 +163,7 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
||||
if isinstance(layer, FusedMoE):
|
||||
return UnquantizedFusedMoEMethod(layer.moe_config)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -197,7 +196,7 @@ class Mxfp4Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.ignored_layers and is_layer_skipped(
|
||||
prefix=prefix,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -159,7 +159,7 @@ class PetitNvFp4Config(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
exclude = self.require_exclude_modules()
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -67,7 +67,7 @@ class PTPCFp8Config(Fp8Config):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import fnmatch
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import torch
|
||||
|
||||
@@ -102,7 +102,7 @@ class QuarkConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
# Check if the layer is skipped for quantization.
|
||||
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
|
||||
if should_ignore_layer(
|
||||
|
||||
@@ -4,7 +4,7 @@ import importlib
|
||||
import json
|
||||
import types
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -209,7 +209,7 @@ class TorchAOConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
) -> "QuantizeMethodBase | None":
|
||||
if not isinstance(layer, LinearBase):
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
||||
from types import ModuleType
|
||||
|
||||
# 1. Create a global variable as a placeholder for the module
|
||||
_petit_kernel: Optional["ModuleType"] = None
|
||||
_petit_kernel: "ModuleType | None" = None
|
||||
|
||||
_PETIT_INSTALL_MSG = (
|
||||
"Petit is not installed. Please install it with `pip install petit-kernel`."
|
||||
|
||||
@@ -12,7 +12,7 @@ import threading
|
||||
import time
|
||||
from collections.abc import Generator, MutableMapping
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -323,7 +323,7 @@ class TensorizerConfig(MutableMapping):
|
||||
" is unstable and may lead to errors."
|
||||
)
|
||||
|
||||
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
|
||||
def open_stream(self, tensorizer_args: "TensorizerArgs | None" = None):
|
||||
if tensorizer_args is None:
|
||||
tensorizer_args = self._construct_tensorizer_args()
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
Union,
|
||||
@@ -186,7 +185,7 @@ class PlaceholderRange:
|
||||
length: int
|
||||
"""The length of the placeholder."""
|
||||
|
||||
is_embed: Optional["torch.Tensor"] = None
|
||||
is_embed: "torch.Tensor | None" = None
|
||||
"""
|
||||
A boolean mask of shape `(length,)` indicating which positions
|
||||
between `offset` and `offset + length` to assign embeddings to.
|
||||
@@ -341,7 +340,7 @@ class MultiModalFeatureSpec:
|
||||
`MultiModalFeatureSpec` per item.
|
||||
"""
|
||||
|
||||
data: Optional["MultiModalKwargsItem"]
|
||||
data: "MultiModalKwargsItem | None"
|
||||
"""
|
||||
Represents multimodal data for this feature.
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ pynvml. However, it should not initialize cuda context.
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import cache, wraps
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
from typing_extensions import ParamSpec
|
||||
@@ -382,7 +382,7 @@ class CudaPlatformBase(Platform):
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: Optional["AttentionBackendEnum"] = None,
|
||||
backend: "AttentionBackendEnum | None" = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
if backend is not None:
|
||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
||||
|
||||
@@ -7,7 +7,7 @@ import platform
|
||||
import random
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -243,7 +243,7 @@ class Platform:
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: Optional["AttentionBackendEnum"] = None,
|
||||
backend: "AttentionBackendEnum | None" = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
"""
|
||||
Get the vision attention backend class of a device.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import os
|
||||
from functools import cache, lru_cache, wraps
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -356,7 +356,7 @@ class RocmPlatform(Platform):
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: Optional["AttentionBackendEnum"] = None,
|
||||
backend: "AttentionBackendEnum | None" = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
if backend is not None:
|
||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -88,7 +88,7 @@ class XPUPlatform(Platform):
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: Optional["AttentionBackendEnum"] = None,
|
||||
backend: "AttentionBackendEnum | None" = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
if backend is not None:
|
||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Annotated, Any, Optional
|
||||
from typing import Annotated, Any
|
||||
|
||||
import msgspec
|
||||
|
||||
@@ -80,7 +80,7 @@ class PoolingParams(
|
||||
return deepcopy(self)
|
||||
|
||||
def verify(
|
||||
self, task: PoolingTask, model_config: Optional["ModelConfig"] = None
|
||||
self, task: PoolingTask, model_config: "ModelConfig | None" = None
|
||||
) -> None:
|
||||
if self.task is None:
|
||||
self.task = task
|
||||
@@ -106,7 +106,7 @@ class PoolingParams(
|
||||
self._verify_valid_parameters()
|
||||
|
||||
def _merge_default_parameters(
|
||||
self, model_config: Optional["ModelConfig"] = None
|
||||
self, model_config: "ModelConfig | None" = None
|
||||
) -> None:
|
||||
if model_config is None:
|
||||
return
|
||||
@@ -160,7 +160,7 @@ class PoolingParams(
|
||||
if getattr(self, k, None) is None:
|
||||
setattr(self, k, getattr(pooler_config, k))
|
||||
|
||||
def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
|
||||
def _set_default_parameters(self, model_config: "ModelConfig | None"):
|
||||
if self.task in ["embed", "token_embed"]:
|
||||
if self.use_activation is None:
|
||||
self.use_activation = True
|
||||
|
||||
@@ -5,7 +5,7 @@ import copy
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Optional, TypeAlias
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
|
||||
from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
|
||||
@@ -31,7 +31,7 @@ except ImportError:
|
||||
@dataclass
|
||||
class _ModuleTreeNode:
|
||||
event: _ProfilerEvent
|
||||
parent: Optional["_ModuleTreeNode"] = None
|
||||
parent: "_ModuleTreeNode | None" = None
|
||||
children: list["_ModuleTreeNode"] = field(default_factory=list)
|
||||
trace: str = ""
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import fnmatch
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
@@ -32,7 +32,7 @@ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
|
||||
|
||||
|
||||
def glob(
|
||||
s3: Optional["BaseClient"] = None,
|
||||
s3: "BaseClient | None" = None,
|
||||
path: str = "",
|
||||
allow_pattern: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -707,7 +707,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: Optional["Indexer"] = None,
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -284,7 +284,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: torch.Tensor | None = None,
|
||||
indexer: Optional["Indexer"] = None,
|
||||
indexer: "Indexer | None" = None,
|
||||
**mla_args,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
@@ -87,11 +87,11 @@ class TreeAttentionMetadata:
|
||||
tree_attn_bias: torch.Tensor | None = None
|
||||
|
||||
# Cached Prefill/decode metadata.
|
||||
_cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["TreeAttentionMetadata"] = None
|
||||
_cached_prefill_metadata: "TreeAttentionMetadata | None" = None
|
||||
_cached_decode_metadata: "TreeAttentionMetadata | None" = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
|
||||
def prefill_metadata(self) -> "TreeAttentionMetadata | None":
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
@@ -116,7 +116,7 @@ class TreeAttentionMetadata:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
|
||||
def decode_metadata(self) -> "TreeAttentionMetadata | None":
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
|
||||
@@ -189,7 +189,7 @@ class SchedulerInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def make_stats(self) -> Optional["SchedulerStats"]:
|
||||
def make_stats(self) -> "SchedulerStats | None":
|
||||
"""Make a SchedulerStats object for logging.
|
||||
|
||||
The SchedulerStats object is created for every scheduling step.
|
||||
@@ -201,5 +201,5 @@ class SchedulerInterface(ABC):
|
||||
"""Shutdown the scheduler."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
|
||||
def get_kv_connector(self) -> "KVConnectorBase_V1 | None":
|
||||
return None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import copy
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
@@ -133,7 +133,7 @@ class ParentRequest:
|
||||
|
||||
@staticmethod
|
||||
def observe_finished_request(
|
||||
parent_req: Optional["ParentRequest"],
|
||||
parent_req: "ParentRequest | None",
|
||||
iteration_stats: IterationStats,
|
||||
num_generation_tokens: int,
|
||||
):
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections import deque
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -68,7 +68,7 @@ class Request:
|
||||
arrival_time: float | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
mm_features: list[MultiModalFeatureSpec] | None = None,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
lora_request: "LoRARequest | None" = None,
|
||||
cache_salt: str | None = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
|
||||
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -94,7 +94,7 @@ class LogitsProcessor(ABC):
|
||||
@abstractmethod
|
||||
def update_state(
|
||||
self,
|
||||
batch_update: Optional["BatchUpdate"],
|
||||
batch_update: "BatchUpdate | None",
|
||||
) -> None:
|
||||
"""Called when there are new output tokens, prior
|
||||
to each forward pass.
|
||||
|
||||
@@ -14,7 +14,6 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
@@ -229,7 +228,7 @@ def wait_for_completion_or_failure(
|
||||
api_server_manager: APIServerProcessManager,
|
||||
engine_manager: Union["CoreEngineProcManager", "CoreEngineActorManager"]
|
||||
| None = None,
|
||||
coordinator: Optional["DPCoordinator"] = None,
|
||||
coordinator: "DPCoordinator | None" = None,
|
||||
) -> None:
|
||||
"""Wait for all processes to complete or detect if any fail.
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -15,7 +14,7 @@ logger = init_logger(__name__)
|
||||
_THREAD_ID_TO_CONTEXT: dict = {}
|
||||
# Here we hardcode the number of microbatches to 2 for default.
|
||||
_NUM_UBATCHES: int = 2
|
||||
_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = []
|
||||
_CURRENT_CONTEXTS: list["UBatchContext | None"] = []
|
||||
|
||||
|
||||
class UBatchContext:
|
||||
|
||||
@@ -5,7 +5,6 @@ import inspect
|
||||
import os
|
||||
from itertools import accumulate
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,7 +25,7 @@ _MB = 1024**2
|
||||
_GiB = 1024**3
|
||||
|
||||
# Global workspace manager instance
|
||||
_manager: Optional["WorkspaceManager"] = None
|
||||
_manager: "WorkspaceManager | None" = None
|
||||
|
||||
|
||||
class WorkspaceManager:
|
||||
|
||||
Reference in New Issue
Block a user