Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -135,7 +136,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
prefix=prefix,
|
||||
embedding_class=RobertaEmbedding)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
# Separate weights in "roberta"-prefixed and all else (not in memory).
|
||||
# For use with models like FacebookAI/roberta-base.
|
||||
@@ -187,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
self.classifier = RobertaClassificationHead(config)
|
||||
self._pooler = CrossEncodingPooler(config, self.classifier)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
bert_weights, task_weights = roberta_task_weights_filter(weights)
|
||||
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
|
||||
|
||||
@@ -249,8 +250,8 @@ def create_position_ids_from_input_ids(input_ids,
|
||||
|
||||
|
||||
def roberta_task_weights_filter(
|
||||
all_weights: Iterable[Tuple[str, torch.Tensor]]
|
||||
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
|
||||
all_weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
|
||||
torch.Tensor]]]:
|
||||
"""
|
||||
Separate task-specific weights that are applied on top
|
||||
|
||||
Reference in New Issue
Block a user