Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,10 +8,15 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
|
||||
FUSED_OPS,
|
||||
SILU_MUL_OP,
|
||||
ActivationQuantFusionPass,
|
||||
)
|
||||
|
||||
# yapf: enable
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
@@ -19,9 +24,14 @@ from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
|
||||
GroupShape,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, cutlass_fp8_supported)
|
||||
Fp8LinearOp,
|
||||
cutlass_fp8_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
@@ -36,7 +46,6 @@ def is_nvfp4_supported():
|
||||
|
||||
|
||||
class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@@ -53,10 +62,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.fp8_linear.apply(y,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.wscale)
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@@ -67,11 +73,12 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||
super().__init__()
|
||||
from vllm.compilation.activation_quant_fusion import (
|
||||
silu_and_mul_nvfp4_quant_supported)
|
||||
silu_and_mul_nvfp4_quant_supported,
|
||||
)
|
||||
|
||||
assert silu_and_mul_nvfp4_quant_supported
|
||||
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
@@ -88,12 +95,14 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
||||
out = cutlass_scaled_fp4_mm(a=y_quant,
|
||||
b=self.w,
|
||||
block_scale_a=y_block_scale,
|
||||
block_scale_b=self.w_block_scale,
|
||||
alpha=self.alpha,
|
||||
out_dtype=y.dtype)
|
||||
out = cutlass_scaled_fp4_mm(
|
||||
a=y_quant,
|
||||
b=self.w,
|
||||
block_scale_a=y_block_scale,
|
||||
block_scale_b=self.w_block_scale,
|
||||
alpha=self.alpha,
|
||||
out_dtype=y.dtype,
|
||||
)
|
||||
return out
|
||||
|
||||
def ops_in_model_before(self):
|
||||
@@ -108,16 +117,24 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"model_class",
|
||||
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]))
|
||||
cast(
|
||||
list[type],
|
||||
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||
if is_nvfp4_supported()
|
||||
else [TestSiluMulFp8QuantModel],
|
||||
),
|
||||
)
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize("cuda_force_torch",
|
||||
[True, False] if cutlass_fp8_supported() else [True])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
|
||||
reason="Only test on CUDA and ROCm")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
cuda_force_torch):
|
||||
@pytest.mark.parametrize(
|
||||
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
def test_fusion_silu_and_mul_quant(
|
||||
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
|
||||
):
|
||||
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
|
||||
pytest.skip("Duplicate tests for NVFP4")
|
||||
|
||||
@@ -129,17 +146,13 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
config = VllmConfig()
|
||||
config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
|
||||
)
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
passes = [
|
||||
NoOpEliminationPass(config), fusion_pass,
|
||||
PostCleanupPass(config)
|
||||
]
|
||||
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
|
||||
backend = TestBackend(*passes)
|
||||
model = model_class(hidden_size=hidden_size,
|
||||
cuda_force_torch=cuda_force_torch,
|
||||
x=x)
|
||||
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
|
||||
|
||||
# First dimension dynamic
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
@@ -155,10 +168,9 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
|
||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
|
||||
torch.testing.assert_close(result[0].to(dtype=dtype),
|
||||
result2[0].to(dtype=dtype),
|
||||
atol=atol,
|
||||
rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user