[Bugfix] Refactor Flashinfer TRTLLM attention kernel selection logic (#24600)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -282,7 +282,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         assert self.kv_cache_spec.dtype == self.model_config.dtype
         self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0] and \
+        # Use model dtype as q dtype when TRTLLM attn is not supported, or
+        # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to
+        # use fp8 q if kv cache is fp8, and will fall back to model dtype
+        # if TRTLLM attention kernel is not used when building attn metadata
+        if supports_trtllm_attention() and \
                 not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
@@ -298,7 +302,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not supports_trtllm_attention()[0]:
+        if self.has_sinks and not supports_trtllm_attention():
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
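
The dropped [0] here and in the previous hunk implies that supports_trtllm_attention used to return a tuple whose first element was the support flag, and after this refactor returns the flag directly, so every call site reads as a plain predicate. A hypothetical before/after sketch (the old tuple's second element is assumed, not shown by this diff):

from typing import Optional

# Before: callers indexed into the result, as the removed lines do.
def supports_trtllm_attention_before() -> tuple[bool, Optional[str]]:
    return True, None  # (support flag, assumed auxiliary value)

# After: the support flag is the whole return value.
def supports_trtllm_attention_after() -> bool:
    return True
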
@@ -477,14 +481,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_last_page_len_np,
         )
 
-        # Check if any layer uses sinks (requires TRTLLM attention)
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
                                                   max_seq_len,
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
                                                   has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
@@ -492,13 +494,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  max_seq_len,
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
                                                  has_sinks=self.has_sinks)
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
                 "sinks, please use trtllm on blackwell or flash attention on "
                 "earlier GPUs.")
 
+        # If TRTLLM attention is not used, the q quantization is not supported.
+        # Fall back to use model dtype.
+        if not (prefill_use_trtllm and decode_use_trtllm):
+            self.q_data_type = self.model_config.dtype
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             q_data_type=self.q_data_type,