Update deprecated type hinting in model_executor/layers (#18056)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -25,7 +26,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_partition_sizes: list[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for embedding layer."""
|
||||
@@ -141,7 +142,7 @@ def get_masked_input_and_mask(
|
||||
input_: torch.Tensor, org_vocab_start_index: int,
|
||||
org_vocab_end_index: int, num_org_vocab_padding: int,
|
||||
added_vocab_start_index: int,
|
||||
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# torch.compile will fuse all of the pointwise ops below
|
||||
# into a single kernel, making it very fast
|
||||
org_vocab_mask = (input_ >= org_vocab_start_index) & (
|
||||
@@ -298,7 +299,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
org_vocab_start_index, org_vocab_end_index,
|
||||
added_vocab_start_index, added_vocab_end_index)
|
||||
|
||||
def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
|
||||
def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
|
||||
"""Get a mapping that can be used to reindex the gathered
|
||||
logits for sampling.
|
||||
|
||||
@@ -312,9 +313,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
if self.tp_size < 2:
|
||||
return None
|
||||
|
||||
base_embeddings: List[int] = []
|
||||
added_embeddings: List[int] = []
|
||||
padding: List[int] = []
|
||||
base_embeddings: list[int] = []
|
||||
added_embeddings: list[int] = []
|
||||
padding: list[int] = []
|
||||
for tp_rank in range(self.tp_size):
|
||||
shard_indices = self._get_indices(self.num_embeddings_padded,
|
||||
self.org_vocab_size_padded,
|
||||
|
||||
Reference in New Issue
Block a user