Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -87,8 +87,8 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert residual is None, "RMSNorm does not support residual connection."
|
||||
return self._forward(x)
|
||||
|
||||
@@ -102,7 +102,7 @@ class MiniMaxText01LinearKernel:
|
||||
kv_caches: torch.Tensor,
|
||||
slope_rate: torch.Tensor,
|
||||
block_size: int,
|
||||
layer_idx: Optional[int] = None,
|
||||
layer_idx: int | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
slope_rate = slope_rate.to(torch.float32)
|
||||
@@ -154,9 +154,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
max_position: int,
|
||||
block_size: int,
|
||||
num_hidden_layer: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
layer_idx: int = 0,
|
||||
linear_layer_idx: int = 0,
|
||||
prefix: str = "linear_attn",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -68,8 +68,8 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation="silu",
|
||||
is_lora_enabled: bool = False,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -410,7 +410,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
|
||||
return Mamba1AttentionBackend
|
||||
|
||||
def _time_proj_bias(self) -> Optional[torch.Tensor]:
|
||||
def _time_proj_bias(self) -> torch.Tensor | None:
|
||||
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
|
||||
return self.dt_proj.bias.float()
|
||||
return None
|
||||
@@ -423,8 +423,8 @@ class PrefillDecodeSplit(NamedTuple):
|
||||
gate_d: torch.Tensor
|
||||
state_indices_tensor_p: torch.Tensor
|
||||
state_indices_tensor_d: torch.Tensor
|
||||
query_start_loc_p: Optional[torch.Tensor]
|
||||
has_initial_states_p: Optional[torch.Tensor]
|
||||
query_start_loc_p: torch.Tensor | None
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
|
||||
|
||||
def split_batch_to_prefill_and_decode(
|
||||
@@ -432,7 +432,7 @@ def split_batch_to_prefill_and_decode(
|
||||
gate: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
has_initial_states: Optional[torch.Tensor],
|
||||
has_initial_states: torch.Tensor | None,
|
||||
num_prefill_tokens: int,
|
||||
num_decode_tokens: int,
|
||||
num_prefills: int,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -138,7 +138,7 @@ class Mixer2RMSNormGated(CustomOp):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
@@ -244,9 +244,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -474,7 +474,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -482,7 +482,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
hidden_states,
|
||||
@@ -495,7 +495,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# attn_metadata contains metadata necessary for the mamba2 triton
|
||||
@@ -904,7 +904,7 @@ def mamba_mixer2(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
@@ -915,7 +915,7 @@ def mamba_mixer2_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -14,7 +13,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def linear_attention_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
# TODO (tdoublep) requires testing
|
||||
@@ -26,7 +25,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def mamba1_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
@@ -37,7 +36,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def mamba2_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
@@ -48,7 +47,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def _mamba_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
@@ -63,7 +62,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def short_conv_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
@@ -72,7 +71,7 @@ class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def gated_delta_net_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -469,17 +468,17 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
def causal_conv1d_fn(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Union[torch.Tensor, None],
|
||||
bias: torch.Tensor | None,
|
||||
conv_states: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
cache_indices: Optional[torch.Tensor] = None,
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
cache_indices: torch.Tensor | None = None,
|
||||
has_initial_state: torch.Tensor | None = None,
|
||||
activation: str | None = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
|
||||
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
num_computed_tokens: Optional[torch.Tensor] = None,
|
||||
block_idx_first_scheduled_token: torch.Tensor | None = None,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
initial_state_idx: torch.Tensor | None = None,
|
||||
num_computed_tokens: torch.Tensor | None = None,
|
||||
block_size_to_align=0,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
@@ -1071,15 +1070,15 @@ def causal_conv1d_update(
|
||||
x: torch.Tensor,
|
||||
conv_state: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Union[bool, str, None] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
activation: bool | str | None = None,
|
||||
conv_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
query_start_loc: torch.Tensor | None = None,
|
||||
max_query_len: int = -1,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
block_idx_last_scheduled_token: torch.Tensor | None = None,
|
||||
initial_state_idx: torch.Tensor | None = None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
@@ -38,8 +38,8 @@ class ShortConv(MambaBase, CustomOp):
|
||||
config,
|
||||
dim: int,
|
||||
layer_idx: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user