Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit d6953beb91
parent 17edd8a807
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00 (committed by GitHub)

1508 changed files with 115244 additions and 94146 deletions
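
For context, ruff subsumes both of the replaced tools: "ruff format" is a Black-compatible code formatter, and ruff's "I" lint rules reproduce isort-style import sorting. The sketch below shows the general shape of such a pyproject.toml setup; the option names are real ruff settings, but the exact values adopted here are an assumption, since the configuration change itself is not part of this excerpt.

    [tool.ruff]
    # Black's default width; the reformatted lines in this diff fit within 88 columns
    line-length = 88

    [tool.ruff.lint]
    # "I" enables the isort-compatible import-sorting rules
    select = ["I"]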


@@ -23,8 +23,12 @@ from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn
-from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
-                             InputDefinition, InputTypeEnum)
+from terratorch.vllm import (
+    DummyDataGenerator,
+    InferenceRunner,
+    InputDefinition,
+    InputTypeEnum,
+)
 from transformers import BatchFeature

 from vllm.config import VllmConfig
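
The exploded import above, and the similar rewrites throughout this diff, follow the formatter's magic-trailing-comma rule. A minimal Python sketch with hypothetical module names:

    # A trailing comma inside the parentheses keeps one item per line,
    # even when the whole import would fit on a single line:
    from package import (
        alpha,
        beta,
    )

    # Without the trailing comma, items collapse onto one line when they fit:
    from package import alpha, beta
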
@@ -35,19 +39,31 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import MultiModalProcessorOnlyCache
-from vllm.multimodal.inputs import (ImageItem, ModalityData,
-                                    MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargsItems,
-                                    MultiModalUUIDDict, PlaceholderRange)
-from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
-                                   MultiModalDataItems, MultiModalDataParser)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptUpdate)
+from vllm.multimodal.inputs import (
+    ImageItem,
+    ModalityData,
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalInputs,
+    MultiModalKwargsItems,
+    MultiModalUUIDDict,
+    PlaceholderRange,
+)
+from vllm.multimodal.parse import (
+    DictEmbeddingItems,
+    ModalityDataItems,
+    MultiModalDataItems,
+    MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptUpdate,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors

-from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
-                         SupportsMultiModal)
+from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
 from .interfaces_base import default_pooling_type

 logger = init_logger(__name__)
@@ -59,12 +75,11 @@ def _terratorch_field_names(pretrained_cfg: dict):


 def _terratorch_field_factory(
-        pretrained_cfg: dict
+    pretrained_cfg: dict,
 ) -> Callable[
     [Mapping[str, torch.Tensor]],
-        Mapping[str, MultiModalFieldConfig],
+    Mapping[str, MultiModalFieldConfig],
 ]:
-
     def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         input_definition = InputDefinition(**pretrained_cfg["input"])
         fields = {}
@@ -75,24 +90,24 @@ def _terratorch_field_factory(
         mm_fields_config = {}
         for field_name, field_modality in fields.items():
             mm_fields_config[field_name] = MultiModalFieldConfig.shared(
-                batch_size=1, modality=field_modality)
+                batch_size=1, modality=field_modality
+            )
         return mm_fields_config

     return _terratorch_field_config


 class TerratorchProcessingInfo(BaseProcessingInfo):
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}


 class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
-
     def __init__(self, info: TerratorchProcessingInfo):
         super().__init__(info)
         self.dummy_data_generator = DummyDataGenerator(
-            self.info.get_hf_config().to_dict()["pretrained_cfg"])
+            self.info.get_hf_config().to_dict()["pretrained_cfg"]
+        )

     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         return ""
@@ -107,15 +122,16 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
         # defined in the HF configuration file
         if mm_options:
-            logger.warning("Configurable multimodal profiling "
-                           "options are not supported for Terratorch. "
-                           "They are ignored for now.")
+            logger.warning(
+                "Configurable multimodal profiling "
+                "options are not supported for Terratorch. "
+                "They are ignored for now."
+            )

         return self.dummy_data_generator.get_dummy_mm_data()


 class TerratorchMultiModalDataParser(MultiModalDataParser):
-
     def __init__(self, pretrained_cfg: dict, *args, **kwargs):
         self._pretrained_cfg = pretrained_cfg
         super().__init__(*args, **kwargs)
@@ -125,7 +141,6 @@ class TerratorchMultiModalDataParser(MultiModalDataParser):
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
     ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
-
             terratorch_fields = _terratorch_field_names(self._pretrained_cfg)

             return DictEmbeddingItems(
@@ -139,20 +154,18 @@ class TerratorchMultiModalDataParser(MultiModalDataParser):


 class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
-
     def __init__(
-            self,
-            info: TerratorchProcessingInfo,
-            dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
-            *,
-            cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
+        self,
+        info: TerratorchProcessingInfo,
+        dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
+        *,
+        cache: Optional[MultiModalProcessorOnlyCache] = None,
+    ) -> None:
         self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
         super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

     def _get_data_parser(self) -> MultiModalDataParser:
-        return TerratorchMultiModalDataParser(
-            pretrained_cfg=self.pretrained_cfg)
+        return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg)

     def _get_mm_fields_config(
         self,
@@ -185,18 +198,16 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
         mm_items = self._to_mm_items(mm_data)
         tokenization_kwargs = tokenization_kwargs or {}
-        mm_hashes = self._hash_mm_items(mm_items,
-                                        hf_processor_mm_kwargs,
-                                        tokenization_kwargs,
-                                        mm_uuids=mm_uuids)
+        mm_hashes = self._hash_mm_items(
+            mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
+        )

         mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

         mm_processed_data = BatchFeature(image_data)

         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             mm_processed_data,
-            self._get_mm_fields_config(mm_processed_data,
-                                       hf_processor_mm_kwargs),
+            self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
         )

         return MultiModalInputs(
@@ -237,7 +248,8 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
         assert pooler_config is not None

         self.pooler = DispatchPooler(
-            {"encode": Pooler.for_encode(pooler_config)}, )
+            {"encode": Pooler.for_encode(pooler_config)},
+        )

     def get_input_embeddings(
         self,
@@ -265,8 +277,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):

         return model_output.output

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         params_list = []
         model_buffers = dict(self.named_buffers())
         loaded_buffers = []
@@ -289,8 +300,9 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
                 if "_timm_module." in name:
                     name = name.replace("_timm_module.", "")
                 buffer = model_buffers[name]
-                weight_loader = getattr(buffer, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(
+                    buffer, "weight_loader", default_weight_loader
+                )
                 weight_loader(buffer, weight)
                 loaded_buffers.append(name)
             else:
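
For reference, the reformatted side of these hunks is what ruff's two entry points emit; a sketch of the repo-wide workflow (the exact invocation used for this commit is not shown here):

    ruff check --fix .   # lint, including the isort-style "I" import fixes
    ruff format .        # Black-compatible reformatting, replacing yapf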