Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,8 +23,12 @@ from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
|
||||
InputDefinition, InputTypeEnum)
|
||||
from terratorch.vllm import (
|
||||
DummyDataGenerator,
|
||||
InferenceRunner,
|
||||
InputDefinition,
|
||||
InputTypeEnum,
|
||||
)
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
@@ -35,19 +39,31 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import (ImageItem, ModalityData,
|
||||
MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargsItems,
|
||||
MultiModalUUIDDict, PlaceholderRange)
|
||||
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
|
||||
MultiModalDataItems, MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptUpdate)
|
||||
from vllm.multimodal.inputs import (
|
||||
ImageItem,
|
||||
ModalityData,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalUUIDDict,
|
||||
PlaceholderRange,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
DictEmbeddingItems,
|
||||
ModalityDataItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
|
||||
SupportsMultiModal)
|
||||
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -59,12 +75,11 @@ def _terratorch_field_names(pretrained_cfg: dict):
|
||||
|
||||
|
||||
def _terratorch_field_factory(
|
||||
pretrained_cfg: dict
|
||||
pretrained_cfg: dict,
|
||||
) -> Callable[
|
||||
[Mapping[str, torch.Tensor]],
|
||||
Mapping[str, MultiModalFieldConfig],
|
||||
Mapping[str, MultiModalFieldConfig],
|
||||
]:
|
||||
|
||||
def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
input_definition = InputDefinition(**pretrained_cfg["input"])
|
||||
fields = {}
|
||||
@@ -75,24 +90,24 @@ def _terratorch_field_factory(
|
||||
mm_fields_config = {}
|
||||
for field_name, field_modality in fields.items():
|
||||
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
|
||||
batch_size=1, modality=field_modality)
|
||||
batch_size=1, modality=field_modality
|
||||
)
|
||||
return mm_fields_config
|
||||
|
||||
return _terratorch_field_config
|
||||
|
||||
|
||||
class TerratorchProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
|
||||
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
|
||||
|
||||
def __init__(self, info: TerratorchProcessingInfo):
|
||||
super().__init__(info)
|
||||
self.dummy_data_generator = DummyDataGenerator(
|
||||
self.info.get_hf_config().to_dict()["pretrained_cfg"])
|
||||
self.info.get_hf_config().to_dict()["pretrained_cfg"]
|
||||
)
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
@@ -107,15 +122,16 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
|
||||
# defined in the HF configuration file
|
||||
|
||||
if mm_options:
|
||||
logger.warning("Configurable multimodal profiling "
|
||||
"options are not supported for Terratorch. "
|
||||
"They are ignored for now.")
|
||||
logger.warning(
|
||||
"Configurable multimodal profiling "
|
||||
"options are not supported for Terratorch. "
|
||||
"They are ignored for now."
|
||||
)
|
||||
|
||||
return self.dummy_data_generator.get_dummy_mm_data()
|
||||
|
||||
|
||||
class TerratorchMultiModalDataParser(MultiModalDataParser):
|
||||
|
||||
def __init__(self, pretrained_cfg: dict, *args, **kwargs):
|
||||
self._pretrained_cfg = pretrained_cfg
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -125,7 +141,6 @@ class TerratorchMultiModalDataParser(MultiModalDataParser):
|
||||
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
|
||||
) -> Optional[ModalityDataItems[Any, Any]]:
|
||||
if isinstance(data, dict):
|
||||
|
||||
terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
|
||||
|
||||
return DictEmbeddingItems(
|
||||
@@ -139,20 +154,18 @@ class TerratorchMultiModalDataParser(MultiModalDataParser):
|
||||
|
||||
|
||||
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
info: TerratorchProcessingInfo,
|
||||
dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
|
||||
*,
|
||||
cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
|
||||
|
||||
self,
|
||||
info: TerratorchProcessingInfo,
|
||||
dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
|
||||
*,
|
||||
cache: Optional[MultiModalProcessorOnlyCache] = None,
|
||||
) -> None:
|
||||
self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
|
||||
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return TerratorchMultiModalDataParser(
|
||||
pretrained_cfg=self.pretrained_cfg)
|
||||
return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
@@ -185,18 +198,16 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
mm_items = self._to_mm_items(mm_data)
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
mm_hashes = self._hash_mm_items(mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids)
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
|
||||
)
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
mm_processed_data = BatchFeature(image_data)
|
||||
|
||||
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
mm_processed_data,
|
||||
self._get_mm_fields_config(mm_processed_data,
|
||||
hf_processor_mm_kwargs),
|
||||
self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
|
||||
)
|
||||
|
||||
return MultiModalInputs(
|
||||
@@ -237,7 +248,8 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"encode": Pooler.for_encode(pooler_config)}, )
|
||||
{"encode": Pooler.for_encode(pooler_config)},
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -265,8 +277,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
|
||||
return model_output.output
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_list = []
|
||||
model_buffers = dict(self.named_buffers())
|
||||
loaded_buffers = []
|
||||
@@ -289,8 +300,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
if "_timm_module." in name:
|
||||
name = name.replace("_timm_module.", "")
|
||||
buffer = model_buffers[name]
|
||||
weight_loader = getattr(buffer, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(
|
||||
buffer, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(buffer, weight)
|
||||
loaded_buffers.append(name)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user