Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,7 +12,7 @@ class DummyLoRAManager:
|
||||
|
||||
def __init__(self, device: torch.device = "cuda:0"):
|
||||
super().__init__()
|
||||
self._loras: Dict[str, LoRALayerWeights] = {}
|
||||
self._loras: dict[str, LoRALayerWeights] = {}
|
||||
self._device = device
|
||||
|
||||
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
|
||||
@@ -77,11 +77,11 @@ class DummyLoRAManager:
|
||||
self,
|
||||
module_name: str,
|
||||
input_dim: int,
|
||||
output_dims: List[int],
|
||||
noop_lora_index: Optional[List[int]] = None,
|
||||
output_dims: list[int],
|
||||
noop_lora_index: Optional[list[int]] = None,
|
||||
rank: int = 8,
|
||||
):
|
||||
base_loras: List[LoRALayerWeights] = []
|
||||
base_loras: list[LoRALayerWeights] = []
|
||||
noop_lora_index_set = set(noop_lora_index or [])
|
||||
|
||||
for i, out_dim in enumerate(output_dims):
|
||||
@@ -110,7 +110,7 @@ def assert_close(a, b):
|
||||
@dataclass
|
||||
class PunicaTensors:
|
||||
inputs_tensor: torch.Tensor
|
||||
lora_weights: Union[torch.Tensor, List[torch.Tensor]]
|
||||
lora_weights: Union[torch.Tensor, list[torch.Tensor]]
|
||||
our_out_tensor: torch.Tensor
|
||||
ref_out_tensor: torch.Tensor
|
||||
b_seq_start_loc: torch.Tensor
|
||||
@@ -118,7 +118,7 @@ class PunicaTensors:
|
||||
seq_len_tensor: torch.Tensor
|
||||
token_lora_mapping: torch.Tensor
|
||||
|
||||
def meta(self) -> Tuple[int, int]:
|
||||
def meta(self) -> tuple[int, int]:
|
||||
"""
|
||||
Infer max_seq_length and token_nums from the tensors
|
||||
and return them.
|
||||
|
||||
Reference in New Issue
Block a user