Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Hashable
|
||||
from collections.abc import Callable, Hashable
|
||||
from fractions import Fraction
|
||||
from typing import Callable, Optional, Union
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
@@ -36,7 +35,7 @@ class BasevLLMParameter(Parameter):
|
||||
into the parameter when the provided weight loader is called.
|
||||
"""
|
||||
|
||||
def __new__(cls, data: Optional[torch.Tensor], **kwargs):
|
||||
def __new__(cls, data: torch.Tensor | None, **kwargs):
|
||||
return super().__new__(cls, data=data, requires_grad=False)
|
||||
|
||||
def __init__(self, data: torch.Tensor, weight_loader: Callable):
|
||||
@@ -109,7 +108,7 @@ class BasevLLMParameter(Parameter):
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
self._assert_and_load(loaded_weight)
|
||||
|
||||
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
|
||||
def _shard_id_as_int(self, shard_id: str | int) -> int:
|
||||
if isinstance(shard_id, int):
|
||||
return shard_id
|
||||
|
||||
@@ -290,7 +289,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
|
||||
super().load_row_parallel_weight(*args, **kwargs)
|
||||
|
||||
def _load_into_shard_id(
|
||||
self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs
|
||||
self, loaded_weight: torch.Tensor, shard_id: str | int, **kwargs
|
||||
):
|
||||
"""
|
||||
Slice the parameter data based on the shard id for
|
||||
@@ -320,10 +319,10 @@ class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_factor: int | Fraction,
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
marlin_tile_size: int | None = None,
|
||||
bitblas_tile_size: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
@@ -371,10 +370,10 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_factor: int | Fraction,
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
marlin_tile_size: int | None = None,
|
||||
bitblas_tile_size: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
@@ -437,7 +436,7 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
local_tensors: set[torch.Tensor]
|
||||
|
||||
# dictionary mapping partition indices to associated parameters
|
||||
partitions: dict[int, Union[ModelWeightParameter, Parameter]]
|
||||
partitions: dict[int, ModelWeightParameter | Parameter]
|
||||
|
||||
def __new__(cls, **kwargs):
|
||||
return super().__new__(cls, data=None, **kwargs)
|
||||
@@ -547,7 +546,7 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
self,
|
||||
param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_weight_shard_id: Optional[Union[str, int]],
|
||||
loaded_weight_shard_id: str | int | None,
|
||||
):
|
||||
raise ValueError(
|
||||
"When loading partition weights of "
|
||||
|
||||
Reference in New Issue
Block a user