Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-12 17:51:31 +01:00
Committed by: GitHub
Parent: 9bb38130cb
Commit: 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions
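
The change is mechanical throughout: every Optional[X] becomes X | None, every Union[X, Y] becomes X | Y (PEP 604), and the now-unused typing imports are dropped. A minimal before/after sketch of the pattern, assuming a Python 3.10+ baseline (illustrative names, not taken from this diff):

# Before: typing-module spellings
from typing import Optional, Union

def scale(x: Union[int, float], factor: Optional[float] = None) -> float:
    return float(x) * (1.0 if factor is None else factor)

# After: PEP 604 operator syntax; no typing imports needed
def scale(x: int | float, factor: float | None = None) -> float:
    return float(x) * (1.0 if factor is None else factor)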

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -87,8 +87,8 @@ class MiniMaxText01RMSNormTP(CustomOp):
     def forward(
         self,
         x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        residual: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert residual is None, "RMSNorm does not support residual connection."
         return self._forward(x)
 
@@ -102,7 +102,7 @@ class MiniMaxText01LinearKernel:
         kv_caches: torch.Tensor,
         slope_rate: torch.Tensor,
         block_size: int,
-        layer_idx: Optional[int] = None,
+        layer_idx: int | None = None,
         **kwargs,
     ) -> torch.Tensor:
         slope_rate = slope_rate.to(torch.float32)
@@ -154,9 +154,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
         max_position: int,
         block_size: int,
         num_hidden_layer: int,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        model_config: ModelConfig | None = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
         layer_idx: int = 0,
         linear_layer_idx: int = 0,
         prefix: str = "linear_attn",

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING, NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -68,8 +68,8 @@ class MambaMixer(MambaBase, CustomOp):
         rms_norm_eps: float = 1e-5,
         activation="silu",
         is_lora_enabled: bool = False,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        model_config: ModelConfig | None = None,
+        cache_config: CacheConfig | None = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -410,7 +410,7 @@ class MambaMixer(MambaBase, CustomOp):
         return Mamba1AttentionBackend
 
-    def _time_proj_bias(self) -> Optional[torch.Tensor]:
+    def _time_proj_bias(self) -> torch.Tensor | None:
         if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
             return self.dt_proj.bias.float()
         return None
@@ -423,8 +423,8 @@ class PrefillDecodeSplit(NamedTuple):
     gate_d: torch.Tensor
     state_indices_tensor_p: torch.Tensor
     state_indices_tensor_d: torch.Tensor
-    query_start_loc_p: Optional[torch.Tensor]
-    has_initial_states_p: Optional[torch.Tensor]
+    query_start_loc_p: torch.Tensor | None
+    has_initial_states_p: torch.Tensor | None
 
 
 def split_batch_to_prefill_and_decode(
@@ -432,7 +432,7 @@ def split_batch_to_prefill_and_decode(
     gate: torch.Tensor,
     state_indices_tensor: torch.Tensor,
    query_start_loc: torch.Tensor,
-    has_initial_states: Optional[torch.Tensor],
+    has_initial_states: torch.Tensor | None,
     num_prefill_tokens: int,
     num_decode_tokens: int,
     num_prefills: int,

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -138,7 +138,7 @@ class Mixer2RMSNormGated(CustomOp):
         self,
         x: torch.Tensor,
         gate: torch.Tensor,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         input_dtype = x.dtype
         if not self.use_rms_norm:
             # Keep gate in float32 for numerical stability during silu
@@ -244,9 +244,9 @@ class MambaMixer2(MambaBase, CustomOp):
         rms_norm_eps: float = 1e-5,
         activation: str = "silu",
         use_rms_norm: bool = True,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        model_config: ModelConfig | None = None,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -474,7 +474,7 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mup_vector: Optional[torch.Tensor] = None,
+        mup_vector: torch.Tensor | None = None,
     ):
         pass
 
@@ -482,7 +482,7 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mup_vector: Optional[torch.Tensor] = None,
+        mup_vector: torch.Tensor | None = None,
     ):
         torch.ops.vllm.mamba_mixer2(
             hidden_states,
@@ -495,7 +495,7 @@ class MambaMixer2(MambaBase, CustomOp):
         self,
         hidden_states: torch.Tensor,
         output: torch.Tensor,
-        mup_vector: Optional[torch.Tensor] = None,
+        mup_vector: torch.Tensor | None = None,
     ):
         forward_context = get_forward_context()
         # attn_metadata contains metadata necessary for the mamba2 triton
@@ -904,7 +904,7 @@ def mamba_mixer2(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
-    mup_vector: Optional[torch.Tensor] = None,
+    mup_vector: torch.Tensor | None = None,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
@@ -915,7 +915,7 @@ def mamba_mixer2_fake(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
-    mup_vector: Optional[torch.Tensor] = None,
+    mup_vector: torch.Tensor | None = None,
 ) -> None:
     return
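
The mamba_mixer2 / mamba_mixer2_fake pair above follows the custom-op pattern: the real op writes into a preallocated output buffer, while the fake (meta) implementation gives torch.compile something traceable without running the kernel. A minimal sketch of the same pattern using PyTorch's public torch.library API (PyTorch 2.4+), with a hypothetical op name rather than vLLM's actual registration helper:

import torch

@torch.library.custom_op("demo::inplace_scale", mutates_args=("output",))
def inplace_scale(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Real implementation: mutate the preallocated buffer in place.
    output.copy_(hidden_states * 2.0)

@inplace_scale.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake impl: no computation; only shape/dtype semantics for tracing.
    return

x = torch.ones(4)
out = torch.empty(4)
torch.ops.demo.inplace_scale(x, out)  # out now holds x * 2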

View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Union
 
 import torch
 
@@ -14,7 +13,7 @@ class MambaStateDtypeCalculator:
     @classmethod
     def linear_attention_state_dtype(
         cls,
-        model_dtype: Union[ModelDType, torch.dtype],
+        model_dtype: ModelDType | torch.dtype,
         mamba_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
         # TODO (tdoublep) requires testing
@@ -26,7 +25,7 @@ class MambaStateDtypeCalculator:
     @classmethod
     def mamba1_state_dtype(
         cls,
-        model_dtype: Union[ModelDType, torch.dtype],
+        model_dtype: ModelDType | torch.dtype,
         mamba_cache_dtype: MambaDType,
         mamba_ssm_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
@@ -37,7 +36,7 @@ class MambaStateDtypeCalculator:
     @classmethod
     def mamba2_state_dtype(
         cls,
-        model_dtype: Union[ModelDType, torch.dtype],
+        model_dtype: ModelDType | torch.dtype,
         mamba_cache_dtype: MambaDType,
         mamba_ssm_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
@@ -48,7 +47,7 @@ class MambaStateDtypeCalculator:
     @classmethod
     def _mamba_state_dtype(
         cls,
-        model_dtype: Union[ModelDType, torch.dtype],
+        model_dtype: ModelDType | torch.dtype,
         mamba_cache_dtype: MambaDType,
         mamba_ssm_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
@@ -63,7 +62,7 @@ class MambaStateDtypeCalculator:
     @classmethod
     def short_conv_state_dtype(
         cls,
-        model_dtype: Union[ModelDType, torch.dtype],
+        model_dtype: ModelDType | torch.dtype,
         mamba_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
         conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
@@ -72,7 +71,7 @@ class MambaStateDtypeCalculator:
     @classmethod
     def gated_delta_net_state_dtype(
         cls,
-        model_dtype: Union[ModelDType, torch.dtype],
+        model_dtype: ModelDType | torch.dtype,
         mamba_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, torch.dtype]:
         state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
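
ModelDType | torch.dtype is a union that a helper like get_kv_cache_torch_dtype has to narrow at runtime; isinstance handles PEP 604 unions just as it did typing.Union. A hedged sketch of such narrowing, assuming ModelDType is a Literal of dtype strings (normalize_dtype is illustrative, not vLLM's implementation):

import torch

_STR_TO_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32,
    "float32": torch.float32,
}

def normalize_dtype(model_dtype: str | torch.dtype) -> torch.dtype:
    # Narrow the union: pass torch.dtype through, map strings via the table.
    if isinstance(model_dtype, torch.dtype):
        return model_dtype
    return _STR_TO_DTYPE[model_dtype]

assert normalize_dtype("bfloat16") is torch.bfloat16
assert normalize_dtype(torch.float16) is torch.float16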

View File

@@ -4,7 +4,6 @@
 # Copyright (c) 2024, Tri Dao.
 # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
-from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -469,17 +468,17 @@ def _causal_conv1d_fwd_kernel( # continuous batching
 def causal_conv1d_fn(
     x: torch.Tensor,
     weight: torch.Tensor,
-    bias: Union[torch.Tensor, None],
+    bias: torch.Tensor | None,
     conv_states: torch.Tensor,
     query_start_loc: torch.Tensor,
-    cache_indices: Optional[torch.Tensor] = None,
-    has_initial_state: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
+    cache_indices: torch.Tensor | None = None,
+    has_initial_state: torch.Tensor | None = None,
+    activation: str | None = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
-    block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
-    block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
-    initial_state_idx: Optional[torch.Tensor] = None,
-    num_computed_tokens: Optional[torch.Tensor] = None,
+    block_idx_first_scheduled_token: torch.Tensor | None = None,
+    block_idx_last_scheduled_token: torch.Tensor | None = None,
+    initial_state_idx: torch.Tensor | None = None,
+    num_computed_tokens: torch.Tensor | None = None,
     block_size_to_align=0,
     metadata=None,
     validate_data=False,
@@ -1071,15 +1070,15 @@ def causal_conv1d_update(
     x: torch.Tensor,
     conv_state: torch.Tensor,
     weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    activation: Union[bool, str, None] = None,
-    conv_state_indices: Optional[torch.Tensor] = None,
-    num_accepted_tokens: Optional[torch.Tensor] = None,
-    query_start_loc: Optional[torch.Tensor] = None,
+    bias: torch.Tensor | None = None,
+    activation: bool | str | None = None,
+    conv_state_indices: torch.Tensor | None = None,
+    num_accepted_tokens: torch.Tensor | None = None,
+    query_start_loc: torch.Tensor | None = None,
     max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
-    block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
-    initial_state_idx: Optional[torch.Tensor] = None,
+    block_idx_last_scheduled_token: torch.Tensor | None = None,
+    initial_state_idx: torch.Tensor | None = None,
     validate_data=False,
 ):
     """

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -38,8 +38,8 @@ class ShortConv(MambaBase, CustomOp):
         config,
         dim: int,
         layer_idx: int,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
+        model_config: ModelConfig | None = None,
+        cache_config: CacheConfig | None = None,
         prefix: str = "",
     ):
         super().__init__()
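
One behavioral footnote on the whole migration: X | None evaluates to a types.UnionType instance while Optional[X] builds a typing.Union object, but the two compare equal and typing.get_args normalizes both, so introspection-driven code keeps working. A quick self-contained check:

import types
import typing

old = typing.Optional[int]  # typing.Union[int, None]
new = int | None            # types.UnionType

assert old == new  # the two spellings compare equal
assert typing.get_args(old) == typing.get_args(new) == (int, type(None))
assert isinstance(new, types.UnionType)
assert isinstance(3, int | None)  # PEP 604 unions also work with isinstance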