Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Commit: d6953beb91 (parent 17edd8a807)
Author: Harry Mellor
Committed by: GitHub
Date: 2025-10-05 15:06:22 +01:00
1508 changed files with 115244 additions and 94146 deletions
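For orientation before the diff: ruff's formatter produces Black-compatible output, and its isort-compatible lint rules (the `I` family) cover import sorting, which is how a single tool can stand in for both yapf and isort here. The sketch below is a minimal before/after illustration of that style shift, using placeholder stdlib imports and a hypothetical `make_config` helper; none of these names come from the vLLM test shown further down.

```python
# Illustration only: placeholder names, not taken from the vLLM diff below.

# --- Before (yapf + isort) ----------------------------------------------
# Continuation lines are aligned with the opening parenthesis, and long
# assert messages are split with backslash continuations.
from collections import (OrderedDict, UserDict, UserList, UserString,
                         defaultdict, namedtuple)


def make_config(model, dtype, gpu_memory_utilization, max_model_len):
    """Stand-in for a call site whose arguments no longer fit on one line."""
    return dict(model=model, dtype=dtype,
                gpu_memory_utilization=gpu_memory_utilization,
                max_model_len=max_model_len)


cfg = make_config("example/model",
                  dtype="bfloat16",
                  gpu_memory_utilization=0.5,
                  max_model_len=2048)
assert cfg["max_model_len"] == 2048, \
    "expected max_model_len to be 2048 in this illustrative configuration sketch"

# --- After (ruff format) ------------------------------------------------
# One element per line with a trailing ("magic") comma, a dedented closing
# parenthesis, and over-long assert messages wrapped in parentheses instead
# of backslash continuations.
from collections import (  # noqa: E402 (repeated only to show the reflow)
    OrderedDict,
    UserDict,
    UserList,
    UserString,
    defaultdict,
    namedtuple,
)

cfg = make_config(
    "example/model",
    dtype="bfloat16",
    gpu_memory_utilization=0.5,
    max_model_len=2048,
)
assert cfg["max_model_len"] == 2048, (
    "expected max_model_len to be 2048 in this illustrative configuration sketch"
)
```

The hunks that follow are this mechanical reflow applied throughout the test file; the code's behavior is unchanged.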

tests/compile/test_fusion_attn.py

@@ -19,14 +19,23 @@ from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
from vllm.config import (
CacheConfig,
CompilationConfig,
CompilationLevel,
ModelConfig,
PassConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -40,14 +49,16 @@ backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
)
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="V0 attn quant fusion only on ROCm")
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
)
def test_attention_fusion_v0(
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
@@ -69,22 +80,24 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048)
llm = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=10,
top_p=0.95)
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released
@@ -97,21 +110,25 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048)
llm2 = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
# check support
attn_fusion_supported = [
@@ -132,9 +149,9 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], \
f"Node {i} {'' if fused else 'not '} expected " \
f"to have fused output quant"
assert fused == attn_fusion_supported[i], (
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
)
# check outputs
fused_output = llm2.generate(prompts, sampling_params)
@@ -160,9 +177,16 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig, **kwargs):
def __init__(
self,
num_qo_heads: int,
num_kv_heads: int,
head_size: int,
kv_cache_dtype: torch.dtype,
device: torch.device,
vllm_config: VllmConfig,
**kwargs,
):
super().__init__()
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
@@ -197,33 +221,30 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device,
)
def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
-> AttentionMetadata:
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
"""Initialize attention metadata."""
# Create common attn metadata
batch_spec = BatchSpec(seq_lens=[1] * batch_size,
query_lens=[1] * batch_size)
batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
self.block_size,
self.device,
arange_block_indices=True)
batch_spec, self.block_size, self.device, arange_block_indices=True
)
max_blocks = (max(batch_spec.seq_lens) + self.block_size -
1) // self.block_size
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device)
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
if current_platform.is_rocm():
# k/v as 1st dimension
if use_hnd:
@@ -239,7 +260,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
# Build attn metadata
self.attn_metadata = self.builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata)
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
return self.attn_metadata
@@ -254,27 +276,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape)
act_quant_group_shape=self.quant_key.scale.group_shape,
)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randn(hidden_size, hidden_size).to(
dtype=FP8_DTYPE, device=self.device).t(),
"wscale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
})
"w",
{
"weight": torch.randn(hidden_size, hidden_size)
.to(dtype=FP8_DTYPE, device=self.device)
.t(),
"wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
},
)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output,
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"])
return self.fp8_linear.apply(
input=attn_output,
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"],
)
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
@@ -287,42 +312,54 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randint(256, (hidden_size, hidden_size // 2),
dtype=FP4_DTYPE,
device=self.device),
"wscale_swizzled":
torch.randn(hidden_size, hidden_size // 16).to(
dtype=FP8_DTYPE, device=self.device),
"wscale":
torch.tensor([500], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([0.002], dtype=torch.float32, device=self.device),
})
"w",
{
"weight": torch.randint(
256,
(hidden_size, hidden_size // 2),
dtype=FP4_DTYPE,
device=self.device,
),
"wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to(
dtype=FP8_DTYPE, device=self.device
),
"wscale": torch.tensor([500], dtype=torch.float32, device=self.device),
"scale": torch.tensor([0.002], dtype=torch.float32, device=self.device),
},
)
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"])
return cutlass_scaled_fp4_mm(a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype)
attn_output, 1 / self.w["scale"]
)
return cutlass_scaled_fp4_mm(
a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype,
)
if current_platform.is_cuda():
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)]
MODELS = [
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel,
),
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel,
),
]
HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm():
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
TestAttentionFp8StaticQuantPatternModel)]
MODELS = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
HEADS = [(32, 8), (40, 8)]
else:
MODELS = []
@@ -331,41 +368,53 @@ else:
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size",
[7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize(
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize("backend",
[_Backend.FLASHINFER] if current_platform.is_cuda()
else [_Backend.TRITON_ATTN])
@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])
"backend",
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
)
@pytest.mark.parametrize(
"split_attention", [False, True] if current_platform.is_rocm() else [False]
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
[False] if current_platform.is_rocm() else [False, True],
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(current_platform.is_cuda()
and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend, split_attention: bool,
use_inductor_graph_partition: bool,
monkeypatch, dist_init, caplog_vllm):
@pytest.mark.skipif(
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)",
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
def test_attention_quant_pattern(
num_qo_heads: int,
num_kv_heads: int,
head_size: int,
batch_size: int,
dtype: torch.dtype,
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
split_attention: bool,
use_inductor_graph_partition: bool,
monkeypatch,
dist_init,
caplog_vllm,
):
"""Test AttentionStaticQuantPattern fusion pass"""
if use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
@@ -386,21 +435,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
),
cache_config=CacheConfig(cache_dtype="fp8"))
cache_config=CacheConfig(cache_dtype="fp8"),
)
# Create test inputs
q = torch.randn(batch_size,
num_qo_heads * head_size,
dtype=dtype,
device=device)
k = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
v = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
device=device)
q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device)
k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device)
# Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0)
@@ -409,42 +450,53 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
# Run model directly without compilation and fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend):
model_unfused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused)
with (
set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
global_force_attn_backend_context_manager(backend),
):
model_unfused = model_class(
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused,
)
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size, use_hnd=split_attention)
batch_size, use_hnd=split_attention
)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
enable_attn_fusion=True, enable_noop=True)
with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend):
model_fused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w)
enable_attn_fusion=True, enable_noop=True
)
with (
set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
global_force_attn_backend_context_manager(backend),
):
model_fused = model_class(
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w,
)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention)
batch_size, use_hnd=split_attention
)
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
@@ -454,9 +506,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
backend=test_backend,
fullgraph=True)
model_compiled = torch.compile(
model_fused, backend=test_backend, fullgraph=True
)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v)
@@ -471,49 +523,49 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert model_compiled.attn._o_scale_float is not None
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(
result_unfused, result_fused_2, atol=1e-2, rtol=1e-2
)
# Check attn fusion support
quant_key = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
vllm_config.compilation_config.static_forward_context.items()
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]],
fully_replaced=True)
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion
attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP,
test_backend.graph_post_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass))
assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion"
assert len(attn_nodes_pre) == len(attn_nodes_post), \
assert len(attn_nodes_pre) == len(attn_nodes_post), (
"Should have same number of attention nodes before and after fusion"
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \
)
assert attn_nodes_pre[0].kwargs.get("output_scale") is None, (
"Attention should not have output_scale before fusion"
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
)
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, (
"Attention should have output_scale after fusion"
)
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale before fusion"
)
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale after FP8 fusion"
)
elif quant_key.dtype == FP4_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, (
"Attention should have output_block_scale after FP4 fusion"
) # noqa: E501
# Check that results are close
torch.testing.assert_close(result_unfused,
result_fused_1,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2)