[bugfix] Fix Llama3/4 issues caused by FlashInfer 0.2.10 (#22426)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
commit af473f0a85 (parent 157f9c1368)
@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         head_dim = self.kv_cache_spec.head_size

         # currently prefill trtllm attention does not support fp8 kv cache
-        prefill_use_trtllm = use_trtllm_attention(
+        prefill_use_trtllm = not cache_dtype.startswith("fp8") \
+            and use_trtllm_attention(
             num_prefill_tokens, max_seq_len, cache_dtype,
             num_qo_heads, num_kv_heads, head_dim)
         decode_use_trtllm = use_trtllm_attention(
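For context, a minimal runnable sketch of the guard this hunk introduces. The stub use_trtllm_attention and the helper name should_use_trtllm_prefill below are illustrative assumptions, not vLLM's real implementation; only the call signature and the fp8 dtype check are taken from the diff itself.

# Sketch of the new guard, under the assumptions stated above.

def use_trtllm_attention(num_tokens: int, max_seq_len: int,
                         kv_cache_dtype: str, num_qo_heads: int,
                         num_kv_heads: int, head_dim: int) -> bool:
    # Stub standing in for vLLM's real eligibility check: pretend the
    # TRT-LLM kernels are always applicable, so the fp8 guard alone
    # decides the outcome in this sketch.
    return True

def should_use_trtllm_prefill(cache_dtype: str, num_prefill_tokens: int,
                              max_seq_len: int, num_qo_heads: int,
                              num_kv_heads: int, head_dim: int) -> bool:
    # Prefill TRT-LLM attention does not support fp8 KV cache, so an
    # fp8 cache dtype must short-circuit before the generic check runs.
    return not cache_dtype.startswith("fp8") and use_trtllm_attention(
        num_prefill_tokens, max_seq_len, cache_dtype,
        num_qo_heads, num_kv_heads, head_dim)

# fp8 KV cache dtypes ("fp8", "fp8_e4m3", ...) are steered off the
# TRT-LLM prefill path; other dtypes keep using it.
assert not should_use_trtllm_prefill("fp8_e4m3", 1024, 8192, 64, 8, 128)
assert should_use_trtllm_prefill("auto", 1024, 8192, 64, 8, 128)

Note that only the prefill path gets this guard; decode_use_trtllm is left unchanged, matching the diff's comment that the fp8 limitation is specific to prefill.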