[Bugfix] Refactor Flashinfer TRTLLM attention kernel selection logic (#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: elvischenv
Date: 2025-09-18 06:36:29 +08:00 (committed by GitHub)
parent 9f882d8791
commit e67a79db03
3 changed files with 65 additions and 29 deletions

@@ -282,7 +282,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         assert self.kv_cache_spec.dtype == self.model_config.dtype
         self.kv_cache_dtype = self.kv_cache_spec.dtype
-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
                 not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
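
The new comment block describes a two-stage decision: q is optimistically set to the kv cache dtype (possibly fp8) here, and only demoted later if the TRTLLM kernel is not actually chosen when attention metadata is built. A minimal sketch of that first stage, using hypothetical stand-ins for the two helper functions (illustration only, not vLLM's actual API):

    import torch

    def select_q_dtype(kv_cache_dtype: torch.dtype,
                       model_dtype: torch.dtype,
                       trtllm_supported: bool,   # stand-in for supports_trtllm_attention()
                       q_quant_disabled: bool) -> torch.dtype:
        # Optimistically pick the (possibly fp8) kv cache dtype for q;
        # the builder later falls back to model_dtype if the TRTLLM
        # kernel is not selected for both prefill and decode.
        if trtllm_supported and not q_quant_disabled:
            return kv_cache_dtype
        return model_dtype

    # e.g. fp8 kv cache with q quantization enabled keeps q in fp8:
    assert select_q_dtype(torch.float8_e4m3fn, torch.bfloat16,
                          True, False) == torch.float8_e4m3fn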
@@ -298,7 +302,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
@@ -477,14 +481,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_last_page_len_np,
         )
-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
@@ -492,13 +494,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
                                                  has_sinks=self.has_sinks)
+        if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
+            raise NotImplementedError(
+                "FlashInfer backend currently does not support attention "
+                "sinks, please use trtllm on blackwell or flash attention on "
+                "earlier GPUs.")
         # If TRTLLM attention is not used, the q quantization is not supported.
         # Fall back to use model dtype.
         if not (prefill_use_trtllm and decode_use_trtllm):
             self.q_data_type = self.model_config.dtype

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,
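
Taken together, the added checks make the per-build kernel selection authoritative for both decisions: attention sinks hard-require the TRTLLM path for prefill and decode, and q quantization quietly degrades to the model dtype otherwise. A condensed sketch of that flow as a hypothetical free function (vLLM performs this inline in the metadata builder, not through such a helper):

    import torch

    def resolve_q_dtype(prefill_use_trtllm: bool,
                        decode_use_trtllm: bool,
                        has_sinks: bool,
                        q_data_type: torch.dtype,
                        model_dtype: torch.dtype) -> torch.dtype:
        both_trtllm = prefill_use_trtllm and decode_use_trtllm
        if has_sinks and not both_trtllm:
            # sinks are only implemented on the TRTLLM kernels
            raise NotImplementedError("attention sinks need TRTLLM attention")
        if not both_trtllm:
            # q quantization is unsupported outside TRTLLM: undo the
            # optimistic fp8 choice made when the builder was constructed
            return model_dtype
        return q_data_type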