Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -20,14 +20,15 @@ from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.protocol import (IOProcessorRequest,
|
||||
IOProcessorResponse)
|
||||
from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors.interface import (IOProcessor,
|
||||
IOProcessorInput,
|
||||
IOProcessorOutput)
|
||||
from vllm.plugins.io_processors.interface import (
|
||||
IOProcessor,
|
||||
IOProcessorInput,
|
||||
IOProcessorOutput,
|
||||
)
|
||||
|
||||
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
|
||||
|
||||
@@ -42,35 +43,25 @@ DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
|
||||
|
||||
datamodule_config: DataModuleConfig = {
|
||||
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
|
||||
"batch_size":
|
||||
16,
|
||||
"constant_scale":
|
||||
0.0001,
|
||||
"data_root":
|
||||
"/dccstor/geofm-finetuning/datasets/sen1floods11",
|
||||
"drop_last":
|
||||
True,
|
||||
"no_data_replace":
|
||||
0.0,
|
||||
"no_label_replace":
|
||||
-1,
|
||||
"num_workers":
|
||||
8,
|
||||
"batch_size": 16,
|
||||
"constant_scale": 0.0001,
|
||||
"data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
|
||||
"drop_last": True,
|
||||
"no_data_replace": 0.0,
|
||||
"no_label_replace": -1,
|
||||
"num_workers": 8,
|
||||
"test_transform": [
|
||||
albumentations.Resize(always_apply=False,
|
||||
height=448,
|
||||
interpolation=1,
|
||||
p=1,
|
||||
width=448),
|
||||
albumentations.pytorch.ToTensorV2(transpose_mask=False,
|
||||
always_apply=True,
|
||||
p=1.0),
|
||||
albumentations.Resize(
|
||||
always_apply=False, height=448, interpolation=1, p=1, width=448
|
||||
),
|
||||
albumentations.pytorch.ToTensorV2(
|
||||
transpose_mask=False, always_apply=True, p=1.0
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def save_geotiff(image: torch.Tensor, meta: dict,
|
||||
out_format: str) -> str | bytes:
|
||||
def save_geotiff(image: torch.Tensor, meta: dict, out_format: str) -> str | bytes:
|
||||
"""Save multi-band image in Geotiff file.
|
||||
|
||||
Args:
|
||||
@@ -219,8 +210,11 @@ def load_image(
|
||||
if len(julian_day) == 3:
|
||||
julian_day = int(julian_day)
|
||||
else:
|
||||
julian_day = (datetime.datetime.strptime(
|
||||
julian_day, "%m%d").timetuple().tm_yday)
|
||||
julian_day = (
|
||||
datetime.datetime.strptime(julian_day, "%m%d")
|
||||
.timetuple()
|
||||
.tm_yday
|
||||
)
|
||||
temporal_coords.append([year, julian_day])
|
||||
except Exception:
|
||||
logger.exception("Could not extract timestamp for %s", file)
|
||||
@@ -233,11 +227,9 @@ def load_image(
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
|
||||
indices = [0, 1, 2, 3, 4, 5]
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.datamodule = Sen1Floods11NonGeoDataModule(
|
||||
@@ -264,8 +256,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
return image_prompt
|
||||
if isinstance(request, IOProcessorRequest):
|
||||
if not hasattr(request, "data"):
|
||||
raise ValueError(
|
||||
"missing 'data' field in OpenAIBaseModel Request")
|
||||
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
|
||||
|
||||
request_data = request.data
|
||||
|
||||
@@ -277,7 +268,8 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
raise ValueError("Unable to parse request")
|
||||
|
||||
def output_to_response(
|
||||
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
|
||||
self, plugin_output: IOProcessorOutput
|
||||
) -> IOProcessorResponse:
|
||||
return IOProcessorResponse(
|
||||
request_id=plugin_output.request_id,
|
||||
data=plugin_output,
|
||||
@@ -289,7 +281,6 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
|
||||
image_data = dict(prompt)
|
||||
|
||||
if request_id:
|
||||
@@ -309,10 +300,8 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
input_data = input_data / 10000 # Convert to range 0-1
|
||||
|
||||
self.original_h, self.original_w = input_data.shape[-2:]
|
||||
pad_h = (self.img_size -
|
||||
(self.original_h % self.img_size)) % self.img_size
|
||||
pad_w = (self.img_size -
|
||||
(self.original_w % self.img_size)) % self.img_size
|
||||
pad_h = (self.img_size - (self.original_h % self.img_size)) % self.img_size
|
||||
pad_w = (self.img_size - (self.original_w % self.img_size)) % self.img_size
|
||||
input_data = np.pad(
|
||||
input_data,
|
||||
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
|
||||
@@ -320,9 +309,9 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
)
|
||||
|
||||
batch = torch.tensor(input_data)
|
||||
windows = batch.unfold(3, self.img_size,
|
||||
self.img_size).unfold(4, self.img_size,
|
||||
self.img_size)
|
||||
windows = batch.unfold(3, self.img_size, self.img_size).unfold(
|
||||
4, self.img_size, self.img_size
|
||||
)
|
||||
self.h1, self.w1 = windows.shape[3:5]
|
||||
windows = rearrange(
|
||||
windows,
|
||||
@@ -332,8 +321,11 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
)
|
||||
|
||||
# Split into batches if number of windows > batch_size
|
||||
num_batches = (windows.shape[0] // self.batch_size
|
||||
if windows.shape[0] > self.batch_size else 1)
|
||||
num_batches = (
|
||||
windows.shape[0] // self.batch_size
|
||||
if windows.shape[0] > self.batch_size
|
||||
else 1
|
||||
)
|
||||
windows = torch.tensor_split(windows, num_batches, dim=0)
|
||||
|
||||
if temporal_coords:
|
||||
@@ -349,15 +341,18 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
for window in windows:
|
||||
# Apply standardization
|
||||
window = self.datamodule.test_transform(
|
||||
image=window.squeeze().numpy().transpose(1, 2, 0))
|
||||
image=window.squeeze().numpy().transpose(1, 2, 0)
|
||||
)
|
||||
window = self.datamodule.aug(window)["image"]
|
||||
prompts.append({
|
||||
"prompt_token_ids": [1],
|
||||
"multi_modal_data": {
|
||||
"pixel_values": window.to(torch.float16)[0],
|
||||
"location_coords": location_coords.to(torch.float16),
|
||||
},
|
||||
})
|
||||
prompts.append(
|
||||
{
|
||||
"prompt_token_ids": [1],
|
||||
"multi_modal_data": {
|
||||
"pixel_values": window.to(torch.float16)[0],
|
||||
"location_coords": location_coords.to(torch.float16),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return prompts
|
||||
|
||||
@@ -367,7 +362,6 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
|
||||
pred_imgs_list = []
|
||||
|
||||
if request_id and (request_id in self.requests_cache):
|
||||
@@ -399,7 +393,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
)
|
||||
|
||||
# Cut padded area back to original size
|
||||
pred_imgs = pred_imgs[..., :self.original_h, :self.original_w]
|
||||
pred_imgs = pred_imgs[..., : self.original_h, : self.original_w]
|
||||
|
||||
# Squeeze (batch size 1)
|
||||
pred_imgs = pred_imgs[0]
|
||||
@@ -407,10 +401,10 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
if not self.meta_data:
|
||||
raise ValueError("No metadata available for the current task")
|
||||
self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
|
||||
out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data,
|
||||
out_format)
|
||||
out_data = save_geotiff(
|
||||
_convert_np_uint8(pred_imgs), self.meta_data, out_format
|
||||
)
|
||||
|
||||
return ImageRequestOutput(type=out_format,
|
||||
format="tiff",
|
||||
data=out_data,
|
||||
request_id=request_id)
|
||||
return ImageRequestOutput(
|
||||
type=out_format, format="tiff", data=out_data, request_id=request_id
|
||||
)
|
||||
|
||||
@@ -16,12 +16,10 @@ class DataModuleConfig(TypedDict):
|
||||
no_data_replace: float
|
||||
no_label_replace: int
|
||||
num_workers: int
|
||||
test_transform: list[
|
||||
albumentations.core.transforms_interface.BasicTransform]
|
||||
test_transform: list[albumentations.core.transforms_interface.BasicTransform]
|
||||
|
||||
|
||||
class ImagePrompt(BaseModel):
|
||||
|
||||
data_format: Literal["b64_json", "bytes", "url", "path"]
|
||||
"""
|
||||
This is the data type for the input image
|
||||
@@ -45,7 +43,7 @@ MultiModalPromptType = Union[ImagePrompt]
|
||||
|
||||
class ImageRequestOutput(BaseModel):
|
||||
"""
|
||||
The output data of an image request to vLLM.
|
||||
The output data of an image request to vLLM.
|
||||
|
||||
Args:
|
||||
type (str): The data content type [path, object]
|
||||
|
||||
Reference in New Issue
Block a user