[Misc] Enable yapf for FlashInfer backend (#23193)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -36,6 +36,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
|||||||
get_per_layer_parameters,
|
get_per_layer_parameters,
|
||||||
infer_global_hyperparameters,
|
infer_global_hyperparameters,
|
||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
|
# yapf: enable
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||||
@@ -541,12 +542,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
if cache_dtype.startswith("fp8") and enable_fusion:
|
if cache_dtype.startswith("fp8") and enable_fusion:
|
||||||
q_dtype = kv_cache_dtype
|
q_dtype = kv_cache_dtype
|
||||||
|
|
||||||
prefill_use_trtllm = use_trtllm_attention(
|
prefill_use_trtllm = use_trtllm_attention(num_qo_heads,
|
||||||
num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
|
num_kv_heads,
|
||||||
cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
|
num_prefill_tokens,
|
||||||
decode_use_trtllm = use_trtllm_attention(
|
max_seq_len,
|
||||||
num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
|
cache_dtype,
|
||||||
cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)
|
q_dtype,
|
||||||
|
is_prefill=True,
|
||||||
|
has_sinks=has_sinks)
|
||||||
|
decode_use_trtllm = use_trtllm_attention(num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_decode_tokens,
|
||||||
|
max_seq_len,
|
||||||
|
cache_dtype,
|
||||||
|
q_dtype,
|
||||||
|
is_prefill=False,
|
||||||
|
has_sinks=has_sinks)
|
||||||
|
|
||||||
attn_metadata = FlashInferMetadata(
|
attn_metadata = FlashInferMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
@@ -654,19 +665,18 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Sinks must have the same number of heads as the number of "
|
"Sinks must have the same number of heads as the number of "
|
||||||
f"heads in the layer. Expected {num_heads}, but got "
|
f"heads in the layer. Expected {num_heads}, but got "
|
||||||
f"{sinks.shape[0]}."
|
f"{sinks.shape[0]}.")
|
||||||
)
|
|
||||||
self.sinks = sinks
|
self.sinks = sinks
|
||||||
|
|
||||||
self.support_trtllm_attn = (supports_trtllm_attention() and
|
self.support_trtllm_attn = (supports_trtllm_attention()
|
||||||
num_heads % num_kv_heads == 0)
|
and num_heads % num_kv_heads == 0)
|
||||||
self.bmm1_scale: Optional[float] = None
|
self.bmm1_scale: Optional[float] = None
|
||||||
self.bmm2_scale: Optional[float] = None
|
self.bmm2_scale: Optional[float] = None
|
||||||
|
|
||||||
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
|
||||||
group_shape: GroupShape):
|
group_shape: GroupShape):
|
||||||
supported_quant_type = (dtype == FP8_DTYPE and static and
|
supported_quant_type = (dtype == FP8_DTYPE and static
|
||||||
group_shape == GroupShape.PER_TENSOR)
|
and group_shape == GroupShape.PER_TENSOR)
|
||||||
return (self.support_trtllm_attn
|
return (self.support_trtllm_attn
|
||||||
and self.kv_cache_dtype.startswith("fp8")
|
and self.kv_cache_dtype.startswith("fp8")
|
||||||
and supported_quant_type)
|
and supported_quant_type)
|
||||||
@@ -731,7 +741,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
# Insert FP8 quant for query
|
# Insert FP8 quant for query
|
||||||
num_tokens, num_heads, head_size = query.shape
|
num_tokens, num_heads, head_size = query.shape
|
||||||
query, _ = ops.scaled_fp8_quant(
|
query, _ = ops.scaled_fp8_quant(
|
||||||
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
|
query.reshape(
|
||||||
|
(num_tokens, num_heads * head_size)).contiguous(),
|
||||||
layer._q_scale)
|
layer._q_scale)
|
||||||
query = query.reshape((num_tokens, num_heads, head_size))
|
query = query.reshape((num_tokens, num_heads, head_size))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user