Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -6,8 +6,9 @@ https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer
architectures in a hybrid model optimized for efficient sequence modeling. The
model alternates between state space model layers and attention-based layers.
"""
from collections.abc import Iterable
from itertools import cycle
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
@@ -54,7 +55,7 @@ class Zamba2LoRA(nn.Module):
self,
input_dim: int,
rank: int,
output_dim: Union[int, List[int]],
output_dim: Union[int, list[int]],
quant_config: Optional[QuantizationConfig] = None,
):
"""Initialize the attention layer.
@@ -279,7 +280,7 @@ class Zamba2MLP(nn.Module):
self,
config: Zamba2Config,
bare_block_idx: int,
num_hybrid_layers: Dict[int, int],
num_hybrid_layers: dict[int, int],
quant_config: Optional[QuantizationConfig] = None,
) -> None:
"""Initialize the MLP layer.
@@ -769,8 +770,8 @@ class Zamba2Model(nn.Module):
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -779,7 +780,7 @@ class Zamba2Model(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for chkpt_weight_name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in chkpt_weight_name:
@@ -914,9 +915,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str,
def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
torch.Tensor],
**kwargs) -> Dict[str, torch.Tensor]:
**kwargs) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture.
Args:
@@ -930,7 +931,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(
self, batch_size: int) -> Dict[str, torch.Tensor]:
self, batch_size: int) -> dict[str, torch.Tensor]:
"""Get inputs for sequence-length-agnostic graph capture.
Args:
@@ -941,7 +942,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Returns:
@@ -1001,7 +1002,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)