[V1] Enable V1 Fp8 cache for FA3 in the oracle (#15191)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson
2025-03-23 18:07:04 -04:00
committed by GitHub
parent 9c5c81b0da
commit dccf535f8e
9 changed files with 45 additions and 23 deletions
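
With this change, FlashAttentionImpl in the V1 engine no longer rejects a quantized KV cache outright; it only raises when flash_attn_supports_fp8() reports that the current device cannot run fp8 with FA3. A minimal usage sketch of the now-enabled path, assuming the standard offline vllm.LLM API and the VLLM_USE_V1 opt-in flag (the model name is illustrative):

# Minimal sketch: request an fp8 KV cache from the V1 engine.
# Assumes the standard vllm.LLM offline API; the model name is illustrative,
# and the fp8 path only engages when the FA3 backend supports it on this GPU.
import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    kv_cache_dtype="fp8",                      # quantized KV cache
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)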

View File

@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashAttention V1 with FP8 KV cache not yet supported")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl):
                                       "are not implemented for "
                                       "FlashAttentionImpl")
         self.vllm_flash_attn_version = get_flash_attn_version()
+        if is_quantized_kv_cache(self.kv_cache_dtype) \
+            and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support fp8 kv-cache on this device.")
 
     def forward(
         self,
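The gate added above calls flash_attn_supports_fp8(), whose body is not part of this diff. A plausible sketch of the check, assuming fp8 KV cache requires FlashAttention 3 on a Hopper-class (compute capability 9.x) GPU:

# Plausible sketch only; the real helper lives in vllm.vllm_flash_attn.fa_utils
# and is not shown in this diff. Assumption: fp8 KV cache needs FA3 on a
# Hopper-class (compute capability 9.x) device.
from vllm.platforms import current_platform
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version


def flash_attn_supports_fp8() -> bool:
    capability = current_platform.get_device_capability()
    return (get_flash_attn_version() == 3
            and capability is not None
            and capability.major == 9)
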
View File

@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func