Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
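For context on what "use ruff instead of yapf + isort" means in practice: ruff's formatter takes over the code-reformatting role of yapf, and its "I" lint rule family takes over import sorting from isort, so both older tools can be dropped from the toolchain. The snippet below is a minimal illustrative sketch of the kind of pyproject.toml configuration such a switch typically involves; the exact rule selection, line length, and per-file ignores used in this PR are not reproduced here and the values shown are placeholders.

[tool.ruff]
# Placeholder values; the project's real line length and target version may differ.
line-length = 88
target-version = "py39"

[tool.ruff.lint]
# "I" enables the isort-compatible import-sorting rules, replacing standalone isort.
select = ["E", "F", "I"]

Formatting is then applied with `ruff format` and import sorting with `ruff check --select I --fix` (typically wired into pre-commit), replacing the separate yapf and isort invocations.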


@@ -1,24 +1,34 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """High-Performance Triton-only Attention layer."""
 from dataclasses import dataclass
 from typing import ClassVar, Optional
 import torch
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+from vllm.attention.backends.abstract import (
+    AttentionBackend,
+    AttentionImpl,
+    AttentionMetadata,
+    AttentionType,
+)
 from vllm.attention.ops.triton_reshape_and_cache_flash import (
-    triton_reshape_and_cache_flash)
+    triton_reshape_and_cache_flash,
+)
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8StaticTensorSym)
+    QuantKey,
+    kFp8StaticTensorSym,
+)
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import (AttentionCGSupport,
-                                              AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+from vllm.v1.attention.backends.utils import (
+    AttentionCGSupport,
+    AttentionMetadataBuilder,
+    CommonAttentionMetadata,
+)
 from vllm.v1.kv_cache_interface import AttentionSpec
 if current_platform.is_cuda_alike():
@@ -59,21 +69,25 @@ class TritonAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
-class TritonAttentionMetadataBuilder(
-        AttentionMetadataBuilder[TritonAttentionMetadata]):
+class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
     cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
-    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
-                 vllm_config: VllmConfig, device: torch.device):
+    def __init__(
+        self,
+        kv_cache_spec: AttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.block_size = kv_cache_spec.block_size
         model_config = vllm_config.model_config
         self.num_heads_q = model_config.get_num_attention_heads(
-            vllm_config.parallel_config)
-        self.num_heads_kv = model_config.get_num_kv_heads(
-            vllm_config.parallel_config)
+            vllm_config.parallel_config
+        )
+        self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()
     def build_for_cudagraph_capture(
@@ -86,10 +100,12 @@ class TritonAttentionMetadataBuilder(
         attn_metadata.seq_lens.fill_(1)
         return attn_metadata
-    def build(self,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False) -> TritonAttentionMetadata:
+    def build(
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        fast_build: bool = False,
+    ) -> TritonAttentionMetadata:
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -102,14 +118,13 @@ class TritonAttentionMetadataBuilder(
         use_cascade = common_prefix_len > 0
         if use_cascade:
-            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
-                                                dtype=torch.int32,
-                                                device=self.device)
-            prefix_kv_lens = torch.tensor([common_prefix_len],
-                                          dtype=torch.int32,
-                                          device=self.device)
-            suffix_kv_lens = (common_attn_metadata.seq_lens_cpu -
-                              common_prefix_len)
+            cu_prefix_query_lens = torch.tensor(
+                [0, num_actual_tokens], dtype=torch.int32, device=self.device
+            )
+            prefix_kv_lens = torch.tensor(
+                [common_prefix_len], dtype=torch.int32, device=self.device
+            )
+            suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
             suffix_kv_lens = suffix_kv_lens.to(self.device)
         else:
             cu_prefix_query_lens = None
@@ -136,7 +151,6 @@ class TritonAttentionMetadataBuilder(
 class TritonAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     @classmethod
@@ -151,7 +165,8 @@ class TritonAttentionBackend(AttentionBackend):
f"Head size {head_size} is not supported by TritonAttention."
f"Head sizes need to be larger or equal 32 for this backend. "
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes.")
"FlexAttention backend which supports all head sizes."
)
@staticmethod
def get_name() -> str:
@@ -187,7 +202,6 @@ class TritonAttentionBackend(AttentionBackend):
 class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym
@@ -228,10 +242,12 @@ class TritonAttentionImpl(AttentionImpl):
         TritonAttentionBackend.validate_head_size(head_size)
         if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "TritonAttentionImpl")
+            raise NotImplementedError(
+                "Encoder self-attention and "
+                "encoder/decoder cross-attention "
+                "are not implemented for "
+                "TritonAttentionImpl"
+            )
         self.fp8_dtype = current_platform.fp8_dtype()
@@ -240,7 +256,8 @@ class TritonAttentionImpl(AttentionImpl):
             assert sinks.shape[0] == num_heads, (
                 "Sinks must have the same number of heads as the number of "
                 f"heads in the layer. Sinks shape: {sinks.shape}, "
-                f"num_heads: {num_heads}.")
+                f"num_heads: {num_heads}."
+            )
     def forward(
         self,
@@ -271,7 +288,8 @@ class TritonAttentionImpl(AttentionImpl):
         if output_block_scale is not None:
             raise NotImplementedError(
                 "fused block_scale output quantization is not yet supported"
-                " for TritonAttentionImpl")
+                " for TritonAttentionImpl"
+            )
         if attn_metadata is None:
             # Profiling run.
@@ -316,16 +334,17 @@ class TritonAttentionImpl(AttentionImpl):
             key_cache = key_cache.view(self.fp8_dtype)
             value_cache = value_cache.view(self.fp8_dtype)
             num_tokens, num_heads, head_size = query.shape
-            assert layer._q_scale_float == 1.0, \
+            assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
+            )
             if current_platform.is_cuda():
                 # Skip Q quantization on ROCm and XPU, enable this on cuda
                 # only, since dequantizing back to f32 in the attention kernel
                 # is not supported.
                 query, _ = ops.scaled_fp8_quant(
-                    query.reshape(
-                        (num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale)
+                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
+                    layer._q_scale,
+                )
                 query = query.reshape((num_tokens, num_heads, head_size))
         cu_seqlens_q = attn_metadata.query_start_loc