Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,14 +9,18 @@ import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tpu import TPUModelLoader
def _setup_environment(model):
engine_args = EngineArgs(model=model, )
engine_args = EngineArgs(
model=model,
)
vllm_config = engine_args.create_engine_config()
with set_current_vllm_config(vllm_config):
temp_file = tempfile.mkstemp()[1]
@@ -25,7 +29,8 @@ def _setup_environment(model):
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
backend="gloo",
)
# Under single worker mode, full model is init first and then
# partitioned using GSPMD.
ensure_model_parallel_initialized(1, 1)
@@ -42,7 +47,7 @@ def _get_spmd_mesh():
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
return MESH
@@ -53,15 +58,17 @@ def _get_spmd_mesh():
# Skip large models due to CI runner disk space limitations
# "meta-llama/Llama-3.1-8B-Instruct",
# "meta-llama/Llama-3.1-70B-Instruct",
])
],
)
def test_tpu_model_loader(model):
# Skip the 70B test if there are less than 8 chips
# TODO: Query using torch xla API, the query API is not working
# with SPMD now. However, This test is running under SPMD mode.
if '70B' in model and xr.global_runtime_device_count() < 8:
if "70B" in model and xr.global_runtime_device_count() < 8:
pytest.skip(
"Skipping 70B model if the TPU VM has less than 8 chips to \
avoid OOM.")
avoid OOM."
)
vllm_config = _setup_environment(model)
loader = TPUModelLoader(load_config=vllm_config.load_config)