Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -20,14 +20,15 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import (IOProcessorRequest,
IOProcessorResponse)
from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (IOProcessor,
IOProcessorInput,
IOProcessorOutput)
from vllm.plugins.io_processors.interface import (
IOProcessor,
IOProcessorInput,
IOProcessorOutput,
)
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
@@ -42,35 +43,25 @@ DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
datamodule_config: DataModuleConfig = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size":
16,
"constant_scale":
0.0001,
"data_root":
"/dccstor/geofm-finetuning/datasets/sen1floods11",
"drop_last":
True,
"no_data_replace":
0.0,
"no_label_replace":
-1,
"num_workers":
8,
"batch_size": 16,
"constant_scale": 0.0001,
"data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
"drop_last": True,
"no_data_replace": 0.0,
"no_label_replace": -1,
"num_workers": 8,
"test_transform": [
albumentations.Resize(always_apply=False,
height=448,
interpolation=1,
p=1,
width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False,
always_apply=True,
p=1.0),
albumentations.Resize(
always_apply=False, height=448, interpolation=1, p=1, width=448
),
albumentations.pytorch.ToTensorV2(
transpose_mask=False, always_apply=True, p=1.0
),
],
}
def save_geotiff(image: torch.Tensor, meta: dict,
out_format: str) -> str | bytes:
def save_geotiff(image: torch.Tensor, meta: dict, out_format: str) -> str | bytes:
"""Save multi-band image in Geotiff file.
Args:
@@ -219,8 +210,11 @@ def load_image(
if len(julian_day) == 3:
julian_day = int(julian_day)
else:
julian_day = (datetime.datetime.strptime(
julian_day, "%m%d").timetuple().tm_yday)
julian_day = (
datetime.datetime.strptime(julian_day, "%m%d")
.timetuple()
.tm_yday
)
temporal_coords.append([year, julian_day])
except Exception:
logger.exception("Could not extract timestamp for %s", file)
@@ -233,11 +227,9 @@ def load_image(
class PrithviMultimodalDataProcessor(IOProcessor):
indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.datamodule = Sen1Floods11NonGeoDataModule(
@@ -264,8 +256,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
return image_prompt
if isinstance(request, IOProcessorRequest):
if not hasattr(request, "data"):
raise ValueError(
"missing 'data' field in OpenAIBaseModel Request")
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
request_data = request.data
@@ -277,7 +268,8 @@ class PrithviMultimodalDataProcessor(IOProcessor):
raise ValueError("Unable to parse request")
def output_to_response(
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
@@ -289,7 +281,6 @@ class PrithviMultimodalDataProcessor(IOProcessor):
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
image_data = dict(prompt)
if request_id:
@@ -309,10 +300,8 @@ class PrithviMultimodalDataProcessor(IOProcessor):
input_data = input_data / 10000 # Convert to range 0-1
self.original_h, self.original_w = input_data.shape[-2:]
pad_h = (self.img_size -
(self.original_h % self.img_size)) % self.img_size
pad_w = (self.img_size -
(self.original_w % self.img_size)) % self.img_size
pad_h = (self.img_size - (self.original_h % self.img_size)) % self.img_size
pad_w = (self.img_size - (self.original_w % self.img_size)) % self.img_size
input_data = np.pad(
input_data,
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
@@ -320,9 +309,9 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
batch = torch.tensor(input_data)
windows = batch.unfold(3, self.img_size,
self.img_size).unfold(4, self.img_size,
self.img_size)
windows = batch.unfold(3, self.img_size, self.img_size).unfold(
4, self.img_size, self.img_size
)
self.h1, self.w1 = windows.shape[3:5]
windows = rearrange(
windows,
@@ -332,8 +321,11 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
# Split into batches if number of windows > batch_size
num_batches = (windows.shape[0] // self.batch_size
if windows.shape[0] > self.batch_size else 1)
num_batches = (
windows.shape[0] // self.batch_size
if windows.shape[0] > self.batch_size
else 1
)
windows = torch.tensor_split(windows, num_batches, dim=0)
if temporal_coords:
@@ -349,15 +341,18 @@ class PrithviMultimodalDataProcessor(IOProcessor):
for window in windows:
# Apply standardization
window = self.datamodule.test_transform(
image=window.squeeze().numpy().transpose(1, 2, 0))
image=window.squeeze().numpy().transpose(1, 2, 0)
)
window = self.datamodule.aug(window)["image"]
prompts.append({
"prompt_token_ids": [1],
"multi_modal_data": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
},
})
prompts.append(
{
"prompt_token_ids": [1],
"multi_modal_data": {
"pixel_values": window.to(torch.float16)[0],
"location_coords": location_coords.to(torch.float16),
},
}
)
return prompts
@@ -367,7 +362,6 @@ class PrithviMultimodalDataProcessor(IOProcessor):
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
pred_imgs_list = []
if request_id and (request_id in self.requests_cache):
@@ -399,7 +393,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
# Cut padded area back to original size
pred_imgs = pred_imgs[..., :self.original_h, :self.original_w]
pred_imgs = pred_imgs[..., : self.original_h, : self.original_w]
# Squeeze (batch size 1)
pred_imgs = pred_imgs[0]
@@ -407,10 +401,10 @@ class PrithviMultimodalDataProcessor(IOProcessor):
if not self.meta_data:
raise ValueError("No metadata available for the current task")
self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data,
out_format)
out_data = save_geotiff(
_convert_np_uint8(pred_imgs), self.meta_data, out_format
)
return ImageRequestOutput(type=out_format,
format="tiff",
data=out_data,
request_id=request_id)
return ImageRequestOutput(
type=out_format, format="tiff", data=out_data, request_id=request_id
)

View File

@@ -16,12 +16,10 @@ class DataModuleConfig(TypedDict):
no_data_replace: float
no_label_replace: int
num_workers: int
test_transform: list[
albumentations.core.transforms_interface.BasicTransform]
test_transform: list[albumentations.core.transforms_interface.BasicTransform]
class ImagePrompt(BaseModel):
data_format: Literal["b64_json", "bytes", "url", "path"]
"""
This is the data type for the input image
@@ -45,7 +43,7 @@ MultiModalPromptType = Union[ImagePrompt]
class ImageRequestOutput(BaseModel):
"""
The output data of an image request to vLLM.
The output data of an image request to vLLM.
Args:
type (str): The data content type [path, object]