Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Union
import torch
@@ -12,7 +12,7 @@ class DummyLoRAManager:
def __init__(self, device: torch.device = "cuda:0"):
super().__init__()
self._loras: Dict[str, LoRALayerWeights] = {}
self._loras: dict[str, LoRALayerWeights] = {}
self._device = device
def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
@@ -77,11 +77,11 @@ class DummyLoRAManager:
self,
module_name: str,
input_dim: int,
output_dims: List[int],
noop_lora_index: Optional[List[int]] = None,
output_dims: list[int],
noop_lora_index: Optional[list[int]] = None,
rank: int = 8,
):
base_loras: List[LoRALayerWeights] = []
base_loras: list[LoRALayerWeights] = []
noop_lora_index_set = set(noop_lora_index or [])
for i, out_dim in enumerate(output_dims):
@@ -110,7 +110,7 @@ def assert_close(a, b):
@dataclass
class PunicaTensors:
inputs_tensor: torch.Tensor
lora_weights: Union[torch.Tensor, List[torch.Tensor]]
lora_weights: Union[torch.Tensor, list[torch.Tensor]]
our_out_tensor: torch.Tensor
ref_out_tensor: torch.Tensor
b_seq_start_loc: torch.Tensor
@@ -118,7 +118,7 @@ class PunicaTensors:
seq_len_tensor: torch.Tensor
token_lora_mapping: torch.Tensor
def meta(self) -> Tuple[int, int]:
def meta(self) -> tuple[int, int]:
"""
Infer max_seq_length and token_nums from the tensors
and return them.