Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -13,11 +13,10 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def adapter_cache(request, tmpdir_factory):
# Create dir that mimics the structure of the adapter cache
adapter_cache = tmpdir_factory.mktemp(
request.module.__name__) / "adapter_cache"
adapter_cache = tmpdir_factory.mktemp(request.module.__name__) / "adapter_cache"
return adapter_cache

View File

@@ -20,14 +20,15 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import (IOProcessorRequest,
IOProcessorResponse)
from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (IOProcessor,
IOProcessorInput,
IOProcessorOutput)
from vllm.plugins.io_processors.interface import (
IOProcessor,
IOProcessorInput,
IOProcessorOutput,
)
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
@@ -42,35 +43,25 @@ DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
datamodule_config: DataModuleConfig = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size":
16,
"constant_scale":
0.0001,
"data_root":
"/dccstor/geofm-finetuning/datasets/sen1floods11",
"drop_last":
True,
"no_data_replace":
0.0,
"no_label_replace":
-1,
"num_workers":
8,
"batch_size": 16,
"constant_scale": 0.0001,
"data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
"drop_last": True,
"no_data_replace": 0.0,
"no_label_replace": -1,
"num_workers": 8,
"test_transform": [
albumentations.Resize(always_apply=False,
height=448,
interpolation=1,
p=1,
width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False,
always_apply=True,
p=1.0),
albumentations.Resize(
always_apply=False, height=448, interpolation=1, p=1, width=448
),
albumentations.pytorch.ToTensorV2(
transpose_mask=False, always_apply=True, p=1.0
),
],
}
def save_geotiff(image: torch.Tensor, meta: dict,
out_format: str) -> str | bytes:
def save_geotiff(image: torch.Tensor, meta: dict, out_format: str) -> str | bytes:
"""Save multi-band image in Geotiff file.
Args:
@@ -219,8 +210,11 @@ def load_image(
if len(julian_day) == 3:
julian_day = int(julian_day)
else:
julian_day = (datetime.datetime.strptime(
julian_day, "%m%d").timetuple().tm_yday)
julian_day = (
datetime.datetime.strptime(julian_day, "%m%d")
.timetuple()
.tm_yday
)
temporal_coords.append([year, julian_day])
except Exception:
logger.exception("Could not extract timestamp for %s", file)
@@ -233,11 +227,9 @@ def load_image(
class PrithviMultimodalDataProcessor(IOProcessor):
indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.datamodule = Sen1Floods11NonGeoDataModule(
@@ -264,8 +256,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
return image_prompt
if isinstance(request, IOProcessorRequest):
if not hasattr(request, "data"):
raise ValueError(
"missing 'data' field in OpenAIBaseModel Request")
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
request_data = request.data
@@ -277,7 +268,8 @@ class PrithviMultimodalDataProcessor(IOProcessor):
raise ValueError("Unable to parse request")
def output_to_response(
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
@@ -289,7 +281,6 @@ class PrithviMultimodalDataProcessor(IOProcessor):
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
image_data = dict(prompt)
if request_id:
@@ -309,10 +300,8 @@ class PrithviMultimodalDataProcessor(IOProcessor):
input_data = input_data / 10000 # Convert to range 0-1
self.original_h, self.original_w = input_data.shape[-2:]
pad_h = (self.img_size -
(self.original_h % self.img_size)) % self.img_size
pad_w = (self.img_size -
(self.original_w % self.img_size)) % self.img_size
pad_h = (self.img_size - (self.original_h % self.img_size)) % self.img_size
pad_w = (self.img_size - (self.original_w % self.img_size)) % self.img_size
input_data = np.pad(
input_data,
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
@@ -320,9 +309,9 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
batch = torch.tensor(input_data)
windows = batch.unfold(3, self.img_size,
self.img_size).unfold(4, self.img_size,
self.img_size)
windows = batch.unfold(3, self.img_size, self.img_size).unfold(
4, self.img_size, self.img_size
)
self.h1, self.w1 = windows.shape[3:5]
windows = rearrange(
windows,
@@ -332,8 +321,11 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
# Split into batches if number of windows > batch_size
num_batches = (windows.shape[0] // self.batch_size
if windows.shape[0] > self.batch_size else 1)
num_batches = (
windows.shape[0] // self.batch_size
if windows.shape[0] > self.batch_size
else 1
)
windows = torch.tensor_split(windows, num_batches, dim=0)
if temporal_coords:
@@ -349,15 +341,18 @@ class PrithviMultimodalDataProcessor(IOProcessor):
for window in windows:
# Apply standardization
window = self.datamodule.test_transform(
image=window.squeeze().numpy().transpose(1, 2, 0))
image=window.squeeze().numpy().transpose(1, 2, 0)
)
window = self.datamodule.aug(window)["image"]
prompts.append({
"prompt_token_ids": [1],
"multi_modal_data": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
},
})
prompts.append(
{
"prompt_token_ids": [1],
"multi_modal_data": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
},
}
)
return prompts
@@ -367,7 +362,6 @@ class PrithviMultimodalDataProcessor(IOProcessor):
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
pred_imgs_list = []
if request_id and (request_id in self.requests_cache):
@@ -399,7 +393,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
# Cut padded area back to original size
pred_imgs = pred_imgs[..., :self.original_h, :self.original_w]
pred_imgs = pred_imgs[..., : self.original_h, : self.original_w]
# Squeeze (batch size 1)
pred_imgs = pred_imgs[0]
@@ -407,10 +401,10 @@ class PrithviMultimodalDataProcessor(IOProcessor):
if not self.meta_data:
raise ValueError("No metadata available for the current task")
self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data,
out_format)
out_data = save_geotiff(
_convert_np_uint8(pred_imgs), self.meta_data, out_format
)
return ImageRequestOutput(type=out_format,
format="tiff",
data=out_data,
request_id=request_id)
return ImageRequestOutput(
type=out_format, format="tiff", data=out_data, request_id=request_id
)

View File

@@ -16,12 +16,10 @@ class DataModuleConfig(TypedDict):
no_data_replace: float
no_label_replace: int
num_workers: int
test_transform: list[
albumentations.core.transforms_interface.BasicTransform]
test_transform: list[albumentations.core.transforms_interface.BasicTransform]
class ImagePrompt(BaseModel):
data_format: Literal["b64_json", "bytes", "url", "path"]
"""
This is the data type for the input image
@@ -45,7 +43,7 @@ MultiModalPromptType = Union[ImagePrompt]
class ImageRequestOutput(BaseModel):
"""
The output data of an image request to vLLM.
The output data of an image request to vLLM.
Args:
type (str): The data content type [path, object]

View File

@@ -3,10 +3,11 @@
from setuptools import setup
setup(name='vllm_add_dummy_model',
version='0.1',
packages=['vllm_add_dummy_model'],
entry_points={
'vllm.general_plugins':
["register_dummy_model = vllm_add_dummy_model:register"]
})
setup(
name="vllm_add_dummy_model",
version="0.1",
packages=["vllm_add_dummy_model"],
entry_points={
"vllm.general_plugins": ["register_dummy_model = vllm_add_dummy_model:register"]
},
)

View File

@@ -19,5 +19,4 @@ def register():
)
if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava")
ModelRegistry.register_model("MyLlava", "vllm_add_dummy_model.my_llava:MyLlava")

View File

@@ -15,7 +15,6 @@ from vllm.sequence import IntermediateTensors
class MyGemma2Embedding(nn.Module):
is_pooling_model = True
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
@@ -23,19 +22,23 @@ class MyGemma2Embedding(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = Gemma2Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
})
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.model.make_empty_intermediate_tensors
)
def forward(
self,
@@ -58,8 +61,8 @@ class MyGemma2Embedding(nn.Module):
return torch.zeros_like(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
weights = (
(name, data) for name, data in weights if not name.startswith("lm_head.")
)
return self.model.load_weights(weights)

View File

@@ -5,20 +5,22 @@ from typing import Optional
import torch
from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
LlavaForConditionalGeneration,
LlavaMultiModalProcessor,
LlavaProcessingInfo)
from vllm.model_executor.models.llava import (
LlavaDummyInputsBuilder,
LlavaForConditionalGeneration,
LlavaMultiModalProcessor,
LlavaProcessingInfo,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
@MULTIMODAL_REGISTRY.register_processor(
LlavaMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder,
)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states)
if logits is not None:

View File

@@ -9,9 +9,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states)
if logits is not None:

View File

@@ -4,13 +4,15 @@
from setuptools import setup
setup(
name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
name="vllm_add_dummy_platform",
version="0.1",
packages=["vllm_add_dummy_platform"],
entry_points={
'vllm.platform_plugins': [
"vllm.platform_plugins": [
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
],
"vllm.general_plugins":
["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
})
"vllm.general_plugins": [
"dummy_custom_ops = vllm_add_dummy_platform:register_ops"
],
},
)

View File

@@ -1,12 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionBackend)
from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend
class DummyAttentionBackend(PlaceholderAttentionBackend):
@staticmethod
def get_name() -> str:
return "Dummy_Backend"

View File

@@ -15,6 +15,5 @@ class DummyRotaryEmbedding(RotaryEmbedding):
super().__init__(*args, **kwargs)
self.addition_config = True
def forward_oot(self, *args,
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
def forward_oot(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
return super().forward_oot(*args, **kwargs)

View File

@@ -24,7 +24,16 @@ class DummyPlatform(Platform):
# Activate custom ops for v1.
compilation_config.custom_ops = ["all"]
def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink, use_sparse):
def get_attn_backend_cls(
self,
backend_name,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_v1,
use_mla,
has_sink,
use_sparse,
):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501