Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -43,10 +43,9 @@
import copy
import math
from collections.abc import Mapping
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple,
TypedDict, Union)
from typing import Any, Literal, Optional, TypedDict, Union
import torch
from torch import nn
@@ -120,7 +119,7 @@ class KimiVLMultiModalProjector(nn.Module):
class KimiVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, List[torch.Tensor]]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:`(num_patches, num_channels, patch_size, patch_size)`
"""
@@ -447,7 +446,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata, **kwargs)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",