Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -48,7 +49,7 @@ class Llama4MoE(nn.Module):
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
|
||||
# psuedo-standard is that the router scores are floats
|
||||
router_scores = torch.sigmoid(router_scores.float())
|
||||
@@ -115,7 +116,7 @@ class Llama4Attention(nn.Module):
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
@@ -300,7 +301,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
@@ -335,9 +336,9 @@ class Llama4Model(LlamaModel):
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
params_dict: Dict[str, nn.Parameter],
|
||||
loaded_params: Set[str],
|
||||
expert_params_mapping: List[Tuple[str, str, int, str]],
|
||||
params_dict: dict[str, nn.Parameter],
|
||||
loaded_params: set[str],
|
||||
expert_params_mapping: list[tuple[str, str, int, str]],
|
||||
fused: bool = True,
|
||||
) -> bool:
|
||||
expert_param_loaded = False
|
||||
@@ -390,8 +391,8 @@ class Llama4Model(LlamaModel):
|
||||
expert_param_loaded = True
|
||||
return expert_param_loaded
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -412,7 +413,7 @@ class Llama4Model(LlamaModel):
|
||||
ckpt_up_proj_name="gate_up_proj",
|
||||
num_experts=1)
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
||||
fused_experts_params = True
|
||||
@@ -489,8 +490,8 @@ class Llama4ForCausalLM(LlamaForCausalLM):
|
||||
prefix=prefix,
|
||||
layer_type=layer_type)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
@@ -506,7 +507,7 @@ class Llama4ForCausalLM(LlamaForCausalLM):
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> Tuple[str, torch.Tensor]:
|
||||
) -> tuple[str, torch.Tensor]:
|
||||
|
||||
def permute(w: torch.Tensor, n_heads: int):
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
|
||||
Reference in New Issue
Block a user