Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,14 +6,12 @@ import torch
|
||||
|
||||
import vllm
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||
VllmConfig)
|
||||
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float16, torch.bfloat16, torch.float32])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
|
||||
@pytest.mark.parametrize("num_tokens", [256, 1024])
|
||||
@pytest.mark.parametrize("hidden_size", [64, 4096])
|
||||
def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||
@@ -22,7 +20,6 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||
torch.manual_seed(1)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
# Chain of reshapes
|
||||
y = x.reshape(-1, 128, 32)
|
||||
@@ -32,7 +29,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||
# Final reshape that should remain
|
||||
b = a.reshape(-1, 128, 32)
|
||||
# No-op slice
|
||||
c = b[0:b.shape[0]]
|
||||
c = b[0 : b.shape[0]]
|
||||
# The pass should replace the result of this op with `c`.
|
||||
d = torch.slice_scatter(
|
||||
torch.ones_like(c), # Dummy tensor to be scattered into
|
||||
@@ -43,10 +40,12 @@ def test_noop_elimination(dtype, num_tokens, hidden_size):
|
||||
)
|
||||
return d
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
|
||||
@@ -82,17 +81,18 @@ def test_non_noop_slice_preserved():
|
||||
x = torch.randn(16, 16)
|
||||
|
||||
class SliceModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
base = x.clone()
|
||||
src = torch.ones(15, 16)
|
||||
y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
|
||||
return x[0:-1, :], y
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
pass_config=PassConfig(enable_noop=True),
|
||||
)
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
backend = TestBackend(noop_pass)
|
||||
|
||||
Reference in New Issue
Block a user