[bugfix] Fix Llama3/4 issues caused by FlashInfer 0.2.10 (#22426)

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
Po-Han Huang (NVIDIA)
2025-08-08 11:25:01 +08:00
committed by GitHub
parent 157f9c1368
commit af473f0a85
2 changed files with 17 additions and 8 deletions

View File

@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim = self.kv_cache_spec.head_size
# currently prefill trtllm attention does not support fp8 kv cache
prefill_use_trtllm = use_trtllm_attention(
prefill_use_trtllm = not cache_dtype.startswith("fp8") \
and use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim)
decode_use_trtllm = use_trtllm_attention(