Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -14,10 +14,8 @@ from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -28,7 +26,6 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class TestSiluMul(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int = 128):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@@ -36,8 +33,7 @@ class TestSiluMul(torch.nn.Module):
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
if TEST_FP8:
|
||||
self.w = torch.rand(hidden_size,
|
||||
hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
@@ -46,17 +42,14 @@ class TestSiluMul(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
if TEST_FP8:
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.wscale)
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
else:
|
||||
return y
|
||||
|
||||
def example_inputs(self, num_tokens=32, hidden_size=128):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
|
||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
if TEST_FP8 and do_fusion:
|
||||
@@ -69,7 +62,6 @@ class TestSiluMul(torch.nn.Module):
|
||||
|
||||
|
||||
class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -78,10 +70,12 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size), dtype=dtype))
|
||||
torch.empty((intermediate_size, hidden_size), dtype=dtype)
|
||||
)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
self.norm.weight = torch.nn.Parameter(
|
||||
torch.ones(intermediate_size, dtype=dtype))
|
||||
torch.ones(intermediate_size, dtype=dtype)
|
||||
)
|
||||
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
@@ -89,8 +83,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.w = torch.rand(hidden_size,
|
||||
intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
@@ -120,10 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
||||
dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
return (hidden_states, residual)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@@ -137,12 +128,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
head_dim=64,
|
||||
rotary_dim=None,
|
||||
max_position=2048,
|
||||
base=10000):
|
||||
def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.rotary_dim = rotary_dim or head_dim
|
||||
@@ -173,21 +159,15 @@ class TestRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
|
||||
class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
head_dim=64,
|
||||
num_heads=4,
|
||||
max_position=2048,
|
||||
base=10000):
|
||||
def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = head_dim * num_heads
|
||||
|
||||
self.qkv_proj = torch.nn.Linear(self.hidden_size,
|
||||
self.hidden_size * 3,
|
||||
bias=False,
|
||||
dtype=torch.float16)
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -233,21 +213,24 @@ MODELS = [
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODELS)
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
|
||||
)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion else [noop_pass, cleanup_pass])
|
||||
passes = (
|
||||
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion
|
||||
else [noop_pass, cleanup_pass]
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
@@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
|
||||
is None) # noqa: E501
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # noqa: E501
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
|
||||
Reference in New Issue
Block a user