Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Set
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -66,7 +65,7 @@ class BertEmbedding(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
token_type_ids = _decode_token_type_ids(input_ids)
|
||||
|
||||
@@ -103,9 +102,9 @@ class BertPooler(Pooler):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
||||
|
||||
if isinstance(pooled_output, list):
|
||||
@@ -147,8 +146,8 @@ class BertLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: BertConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -191,8 +190,8 @@ class BertAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
layer_norm_eps: float,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -225,8 +224,8 @@ class BertSelfAttention(nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -281,7 +280,7 @@ class BertSelfOutput(nn.Module):
|
||||
self,
|
||||
hidden_size: int,
|
||||
layer_norm_eps: float,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -308,7 +307,7 @@ class BertIntermediate(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -333,7 +332,7 @@ class BertOutput(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
layer_norm_eps: float,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -383,8 +382,8 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
@@ -494,8 +493,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model(
|
||||
input_ids=input_ids,
|
||||
@@ -636,11 +635,11 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if token_type_ids is not None:
|
||||
assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
|
||||
@@ -692,11 +691,11 @@ class BertForTokenClassification(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if token_type_ids is not None:
|
||||
assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
|
||||
|
||||
Reference in New Issue
Block a user