Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,8 +9,10 @@ import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.distributed.parallel_state import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
)
|
||||
from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
@@ -36,7 +38,8 @@ def setup_environment():
|
||||
0,
|
||||
local_rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
backend="gloo")
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(1, 1)
|
||||
yield
|
||||
|
||||
@@ -51,7 +54,7 @@ def _get_spmd_mesh():
|
||||
num_devices = xr.global_runtime_device_count()
|
||||
mesh_shape = (num_devices, 1)
|
||||
device_ids = np.array(range(num_devices))
|
||||
MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
|
||||
MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
||||
return MESH
|
||||
|
||||
|
||||
@@ -59,7 +62,7 @@ def _get_spmd_mesh():
|
||||
# `xr.use_spmd()` will set a global state, and this state is not reversible.
|
||||
# Therefore, non-SPMD tests should be run before SPMD tests.
|
||||
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
|
||||
@pytest.mark.parametrize("device", ['cpu', 'xla'])
|
||||
@pytest.mark.parametrize("device", ["cpu", "xla"])
|
||||
@torch.no_grad()
|
||||
def test_xla_qkv_linear(bias, mesh, device):
|
||||
torch.manual_seed(123)
|
||||
|
||||
Reference in New Issue
Block a user