Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -34,9 +34,9 @@ def run_radio_test(
|
||||
# Using `self.get_nearest_supported_resolution`, for assets 432x642 the
|
||||
# nearest supported resolution is 432x640.
|
||||
pixel_values = [
|
||||
img_processor(
|
||||
image,
|
||||
return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640]
|
||||
img_processor(image, return_tensors="pt").pixel_values.to(torch_dtype)[
|
||||
:, :, :, :640
|
||||
]
|
||||
for image in images
|
||||
]
|
||||
|
||||
@@ -51,32 +51,33 @@ def run_radio_test(
|
||||
hf_model.eval()
|
||||
|
||||
hf_outputs_per_image = [
|
||||
hf_model(pixel_value.to("cuda")).features
|
||||
for pixel_value in pixel_values
|
||||
hf_model(pixel_value.to("cuda")).features for pixel_value in pixel_values
|
||||
]
|
||||
|
||||
radio_config = RadioConfig(model_name=config.args["model"],
|
||||
reg_tokens=config.args["register_multiple"])
|
||||
radio_config = RadioConfig(
|
||||
model_name=config.args["model"], reg_tokens=config.args["register_multiple"]
|
||||
)
|
||||
vllm_model = RadioModel(radio_config)
|
||||
vllm_model.load_weights(hf_model.state_dict())
|
||||
vllm_model = vllm_model.to("cuda", torch_dtype)
|
||||
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model(pixel_values=pixel_value.to("cuda"))
|
||||
for pixel_value in pixel_values
|
||||
vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values
|
||||
]
|
||||
del vllm_model, hf_model
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
cos_similar = nn.CosineSimilarity(dim=-1)
|
||||
for vllm_output, hf_output in zip(vllm_outputs_per_image,
|
||||
hf_outputs_per_image):
|
||||
for vllm_output, hf_output in zip(vllm_outputs_per_image, hf_outputs_per_image):
|
||||
assert cos_similar(vllm_output, hf_output).mean() > 0.99
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"nvidia/C-RADIOv2-H",
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"model_id",
|
||||
[
|
||||
"nvidia/C-RADIOv2-H",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_radio(dist_init, image_assets, model_id, dtype: str) -> None:
|
||||
run_radio_test(
|
||||
|
||||
Reference in New Issue
Block a user