Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-05-13 12:17:23 +01:00
committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions
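
For context, every change in this commit follows the same pattern: the deprecated `typing.List`/`typing.Tuple` aliases are replaced by the built-in `list`/`tuple` generics (PEP 585, available from Python 3.9), and abstract collection types such as `Sequence` move to `collections.abc`. `Optional` is kept, since the `X | None` spelling requires Python 3.10. A minimal before/after sketch of the pattern (`shard_sizes` is a hypothetical helper, not from the diff):

# Before (deprecated typing aliases):
#
#   from typing import List, Optional, Tuple
#
#   def shard_sizes(total: int, ranks: int) -> Tuple[List[int], Optional[int]]:
#       ...
#
# After, as applied throughout this commit (built-in generics, PEP 585):
from typing import Optional

def shard_sizes(total: int, ranks: int) -> tuple[list[int], Optional[int]]:
    """Split `total` rows as evenly as possible across `ranks` shards."""
    base, rem = divmod(total, ranks)
    sizes = [base + (1 if r < rem else 0) for r in range(ranks)]
    return sizes, rem or None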

@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import List, Optional, Sequence, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -25,7 +26,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for embedding layer."""
@@ -141,7 +142,7 @@ def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,
         added_vocab_start_index: int,
-        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
     # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (
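
The hunk above only touches the return annotation of the token-masking helper, but the surrounding comment explains why the body is written the way it is: everything is a pointwise tensor op, so torch.compile can fuse it into one kernel. A simplified sketch of that pattern, reduced to the original-vocab case (this is an illustration, not vLLM's full implementation, which also handles the added-vocab and padding ranges):

import torch

def get_masked_input_and_mask_sketch(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
    # All ops below are pointwise (comparisons, subtraction, multiply),
    # which is what lets torch.compile fuse them into a single kernel.
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    # Re-base in-shard token ids to 0; ids outside the shard collapse to 0.
    masked_input = (input_ - org_vocab_start_index) * org_vocab_mask
    return masked_input, ~org_vocab_mask

For example, with `input_ = torch.tensor([3, 7, 12])` and a shard covering ids [4, 10), the sketch yields `masked_input = [0, 3, 0]` and an out-of-shard mask of `[True, False, True]`.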
@@ -298,7 +299,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                                          org_vocab_start_index, org_vocab_end_index,
                                          added_vocab_start_index, added_vocab_end_index)
 
-    def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
+    def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
         """Get a mapping that can be used to reindex the gathered
         logits for sampling.
@@ -312,9 +313,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         if self.tp_size < 2:
             return None
 
-        base_embeddings: List[int] = []
-        added_embeddings: List[int] = []
-        padding: List[int] = []
+        base_embeddings: list[int] = []
+        added_embeddings: list[int] = []
+        padding: list[int] = []
         for tp_rank in range(self.tp_size):
             shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
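
The loop above collects, for each tensor-parallel rank, the index ranges of base-vocab rows, added-vocab rows, and padding rows; per the docstring, the returned mapping reindexes the all-gathered logits back into original vocab order. A toy illustration of that idea with hypothetical shard sizes (not vLLM's actual index math, which derives the ranges from `_get_indices`):

# Toy version: 2 ranks, each shard laid out as [base | added | pad].
# After an all-gather the rows are ordered rank0 then rank1, so the
# sharded-to-full mapping lists all base rows first, then added, then pad.
tp_size, base_per_rank, added_per_rank, pad_per_rank = 2, 4, 1, 1
shard = base_per_rank + added_per_rank + pad_per_rank

base_embeddings: list[int] = []
added_embeddings: list[int] = []
padding: list[int] = []
for tp_rank in range(tp_size):
    offset = tp_rank * shard
    base_embeddings.extend(range(offset, offset + base_per_rank))
    added_embeddings.extend(range(offset + base_per_rank,
                                  offset + base_per_rank + added_per_rank))
    padding.extend(range(offset + base_per_rank + added_per_rank,
                         offset + shard))

sharded_to_full = base_embeddings + added_embeddings + padding
# Every gathered row appears exactly once in the mapping.
assert sorted(sharded_to_full) == list(range(tp_size * shard))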