Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
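
Background for the change: PEP 585 (Python 3.9) deprecated the typing-module aliases Dict, List, Set, Tuple, and typing.Iterable in favor of the builtin generics dict, list, set, tuple and collections.abc.Iterable; Optional and Union still come from typing. A minimal before/after sketch of the pattern applied throughout this diff (hypothetical function, not taken from the patch):

# Deprecated spelling (still runs, but the typing aliases are
# deprecated since Python 3.9):
from typing import Dict, List, Optional

def lookup_old(table: Dict[str, List[int]], key: str) -> Optional[List[int]]:
    return table.get(key)

# Modern spelling (PEP 585 builtin generics; Optional stays in typing):
from typing import Optional

def lookup_new(table: dict[str, list[int]], key: str) -> Optional[list[int]]:
    return table.get(key)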
@@ -6,8 +6,9 @@ https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer
 architectures in a hybrid model optimized for efficient sequence modeling. The
 model alternates between state space model layers and attention-based layers.
 """
+from collections.abc import Iterable
 from itertools import cycle
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
@@ -54,7 +55,7 @@ class Zamba2LoRA(nn.Module):
         self,
         input_dim: int,
         rank: int,
-        output_dim: Union[int, List[int]],
+        output_dim: Union[int, list[int]],
         quant_config: Optional[QuantizationConfig] = None,
     ):
         """Initialize the attention layer.
@@ -279,7 +280,7 @@ class Zamba2MLP(nn.Module):
         self,
         config: Zamba2Config,
         bare_block_idx: int,
-        num_hybrid_layers: Dict[int, int],
+        num_hybrid_layers: dict[int, int],
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         """Initialize the MLP layer.
@@ -769,8 +770,8 @@ class Zamba2Model(nn.Module):
         hidden_states = self.final_layernorm(hidden_states)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -779,7 +780,7 @@ class Zamba2Model(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
+        loaded_params: set[str] = set()
         for chkpt_weight_name, loaded_weight in weights:
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in chkpt_weight_name:
@@ -914,9 +915,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
 
         return hidden_states
 
-    def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str,
+    def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
                                                                  torch.Tensor],
-                                       **kwargs) -> Dict[str, torch.Tensor]:
+                                       **kwargs) -> dict[str, torch.Tensor]:
         """Copy inputs before CUDA graph capture.
 
         Args:
@@ -930,7 +931,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
             input_buffers, **kwargs)
 
     def get_seqlen_agnostic_capture_inputs(
-            self, batch_size: int) -> Dict[str, torch.Tensor]:
+            self, batch_size: int) -> dict[str, torch.Tensor]:
         """Get inputs for sequence-length-agnostic graph capture.
 
         Args:
@@ -941,7 +942,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
     def _get_mamba_cache_shape(
-            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+            self) -> tuple[tuple[int, int], tuple[int, int]]:
         """Calculate shapes for Mamba's convolutional and state caches.
 
         Returns:
@@ -1001,7 +1002,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
                                        sampling_metadata)
         return logits
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
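
Every hunk above is the same mechanical rewrite, so it is cheap to sanity-check at runtime; migrations like this are also commonly automated with a tool such as pyupgrade (its --py39-plus mode rewrites the deprecated aliases). A small sketch, assuming Python >= 3.9 where builtin generics became subscriptable, showing that the new spellings resolve to the same origin type and type arguments as the deprecated aliases:

from typing import Dict, get_args, get_origin

old_style = Dict[str, int]   # deprecated typing alias
new_style = dict[str, int]   # PEP 585 builtin generic

# Both annotations report the same origin class and argument tuple,
# so the diff changes spelling only, not behavior.
assert get_origin(old_style) is dict and get_origin(new_style) is dict
assert get_args(old_style) == (str, int) == get_args(new_style)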