Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -14,10 +14,8 @@ from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
@@ -28,7 +26,6 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module):
def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
@@ -36,8 +33,7 @@ class TestSiluMul(torch.nn.Module):
self.scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8:
self.w = torch.rand(hidden_size,
hidden_size).to(dtype=FP8_DTYPE).t()
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
@@ -46,17 +42,14 @@ class TestSiluMul(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y,
self.w,
self.wscale,
input_scale=self.wscale)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
else:
return y
def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
@@ -69,7 +62,6 @@ class TestSiluMul(torch.nn.Module):
class TestFusedAddRMSNorm(torch.nn.Module):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
@@ -78,10 +70,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
dtype = torch.float16 if TEST_FP8 else torch.float32
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype))
torch.empty((intermediate_size, hidden_size), dtype=dtype)
)
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(
torch.ones(intermediate_size, dtype=dtype))
torch.ones(intermediate_size, dtype=dtype)
)
torch.nn.init.normal_(self.gate_proj, std=0.02)
@@ -89,8 +83,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size,
intermediate_size).to(dtype=FP8_DTYPE).t()
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
def forward(self, hidden_states, residual):
@@ -120,10 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
@@ -137,12 +128,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
class TestRotaryEmbedding(torch.nn.Module):
def __init__(self,
head_dim=64,
rotary_dim=None,
max_position=2048,
base=10000):
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.rotary_dim = rotary_dim or head_dim
@@ -173,21 +159,15 @@ class TestRotaryEmbedding(torch.nn.Module):
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
def __init__(self,
head_dim=64,
num_heads=4,
max_position=2048,
base=10000):
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
super().__init__()
self.head_dim = head_dim
self.num_heads = num_heads
self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(self.hidden_size,
self.hidden_size * 3,
bias=False,
dtype=torch.float16)
self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -233,21 +213,24 @@ MODELS = [
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
torch.set_default_device("cuda")
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion else [noop_pass, cleanup_pass])
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
@@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
is None) # noqa: E501
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
# make sure the ops were all de-functionalized
found = dict()