Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,8 +9,10 @@ import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -36,7 +38,8 @@ def setup_environment():
0,
local_rank=0,
distributed_init_method=f"file://{temp_file}",
backend="gloo")
backend="gloo",
)
ensure_model_parallel_initialized(1, 1)
yield
@@ -51,7 +54,7 @@ def _get_spmd_mesh():
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
return MESH
@@ -59,7 +62,7 @@ def _get_spmd_mesh():
# `xr.use_spmd()` will set a global state, and this state is not reversible.
# Therefore, non-SPMD tests should be run before SPMD tests.
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
@pytest.mark.parametrize("device", ['cpu', 'xla'])
@pytest.mark.parametrize("device", ["cpu", "xla"])
@torch.no_grad()
def test_xla_qkv_linear(bias, mesh, device):
torch.manual_seed(123)