[Misc] Enable yapf for FlashInfer backend (#23193)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-08-19 10:33:47 -07:00
Committed by: GitHub
Parent: f7cf5b512e
Commit: 5b5f350d67


@@ -36,6 +36,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                                get_per_layer_parameters,
                                                infer_global_hyperparameters,
                                                split_decodes_and_prefills)
+# yapf: enable
 from vllm.v1.kv_cache_interface import AttentionSpec

 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
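
yapf honors comment directives: everything between `# yapf: disable` and the next `# yapf: enable` is left exactly as written, and formatting resumes after the enable marker. The line added above closes out the hand-aligned import block so the rest of the file is formatted. A minimal, self-contained sketch of the directive pair (module and names are illustrative, not from this file):

    # yapf: disable
    from os.path import (join,
                         split)   # hand-aligned; yapf leaves this untouched
    # yapf: enable


    def formatted_normally():
        # Code below the enable marker is formatted by yapf as usual.
        return join(*split("a/b"))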
@@ -541,12 +542,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if cache_dtype.startswith("fp8") and enable_fusion:
             q_dtype = kv_cache_dtype
-        prefill_use_trtllm = use_trtllm_attention(
-            num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
-            cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
-        decode_use_trtllm = use_trtllm_attention(
-            num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
-            cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)
+        prefill_use_trtllm = use_trtllm_attention(num_qo_heads,
+                                                  num_kv_heads,
+                                                  num_prefill_tokens,
+                                                  max_seq_len,
+                                                  cache_dtype,
+                                                  q_dtype,
+                                                  is_prefill=True,
+                                                  has_sinks=has_sinks)
+        decode_use_trtllm = use_trtllm_attention(num_qo_heads,
+                                                 num_kv_heads,
+                                                 num_decode_tokens,
+                                                 max_seq_len,
+                                                 cache_dtype,
+                                                 q_dtype,
+                                                 is_prefill=False,
+                                                 has_sinks=has_sinks)

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
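
This hunk is a pure reformat: yapf rewraps a call whose arguments overflow the column limit as one argument per line, visually aligned with the opening parenthesis. A generic before/after sketch with made-up names (the short arguments would really fit on one line; they only stand in for the long ones above):

    def combine(a, b, c, d):
        return (a, b, c, d)

    # Hand-wrapped form (what yapf rewrites):
    result = combine(
        1, 2, 3, 4)

    # yapf's layout: one argument per line, aligned with the open parenthesis.
    result = combine(1,
                     2,
                     3,
                     4)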
@@ -654,19 +665,18 @@ class FlashInferImpl(AttentionImpl):
             raise ValueError(
                 "Sinks must have the same number of heads as the number of "
                 f"heads in the layer. Expected {num_heads}, but got "
-                f"{sinks.shape[0]}."
-            )
+                f"{sinks.shape[0]}.")
             self.sinks = sinks

-        self.support_trtllm_attn = (supports_trtllm_attention() and
-                                    num_heads % num_kv_heads == 0)
+        self.support_trtllm_attn = (supports_trtllm_attention()
+                                    and num_heads % num_kv_heads == 0)
         self.bmm1_scale: Optional[float] = None
         self.bmm2_scale: Optional[float] = None

     def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                      group_shape: GroupShape):
-        supported_quant_type = (dtype == FP8_DTYPE and static and
-                                group_shape == GroupShape.PER_TENSOR)
+        supported_quant_type = (dtype == FP8_DTYPE and static
+                                and group_shape == GroupShape.PER_TENSOR)
         return (self.support_trtllm_attn
                 and self.kv_cache_dtype.startswith("fp8")
                 and supported_quant_type)
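
Both boolean rewrites in this hunk apply the same rule: break before a logical operator rather than after it, so each continuation line starts with `and`. This matches PEP 8's recommendation for binary operators and corresponds to yapf's `SPLIT_BEFORE_LOGICAL_OPERATOR` knob (assumed enabled in this repo's yapf config). A small sketch with made-up predicates:

    def supports_fast_path(heads: int, kv_heads: int, fp8_cache: bool) -> bool:
        # Continuation lines start with the operator, mirroring the diff.
        return (fp8_cache
                and kv_heads > 0
                and heads % kv_heads == 0)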
@@ -731,7 +741,8 @@ class FlashInferImpl(AttentionImpl):
             # Insert FP8 quant for query
             num_tokens, num_heads, head_size = query.shape
             query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
                 layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
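
The last hunk only re-wraps the flatten-quantize-restore sequence; behavior is unchanged. As a rough torch-only sketch of that shape round trip (this stands in for vLLM's fused `ops.scaled_fp8_quant` kernel; the helper name and the e4m3 format choice here are assumptions):

    import torch

    FP8_DTYPE = torch.float8_e4m3fn  # assumed fp8 format

    def quant_query_fp8(query: torch.Tensor, q_scale: float) -> torch.Tensor:
        # Flatten (num_tokens, num_heads, head_size) to 2-D, quantize
        # per-tensor, then restore the 3-D shape, as in the hunk above.
        num_tokens, num_heads, head_size = query.shape
        flat = query.reshape((num_tokens, num_heads * head_size)).contiguous()
        info = torch.finfo(FP8_DTYPE)
        quant = (flat / q_scale).clamp(info.min, info.max).to(FP8_DTYPE)
        return quant.reshape((num_tokens, num_heads, head_size))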