Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Inference-only Snowflake Arctic model."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -458,8 +459,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -467,8 +468,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
mlp_params_mapping: List[Tuple[str, str, int]] = []
|
||||
expert_params_mapping: List[Tuple[str, str, int]] = []
|
||||
mlp_params_mapping: list[tuple[str, str, int]] = []
|
||||
expert_params_mapping: list[tuple[str, str, int]] = []
|
||||
num_layers = self.config.num_hidden_layers
|
||||
|
||||
for layer in range(num_layers):
|
||||
@@ -497,7 +498,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
|
||||
("ws", f"experts.{expert_id}.w3.weight", expert_id))
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
logger.info(
|
||||
"It will take ~10 minutes loading from the 16-bit weights. "
|
||||
|
||||
Reference in New Issue
Block a user