Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Inference-only Jamba model."""
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -442,7 +443,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
self) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
hidden_size = self.config.hidden_size
|
||||
conv_state_shape = (
|
||||
@@ -464,8 +465,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -482,7 +483,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@@ -583,7 +584,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
||||
logits = self.score(hidden_states)
|
||||
return self._pooler(logits, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
# TODO: The reward weights themselves have float32 accuracy data, we
|
||||
# would like to load them in fp32 to get that extra precision.
|
||||
super().load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user