[flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention (#24197)
Signed-off-by: Xiaozhu <mxz297@gmail.com>
This commit is contained in:
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv, is_pin_memory_available
|
||||
from vllm.utils.flashinfer import (supports_trtllm_attention,
|
||||
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
|
||||
supports_trtllm_attention,
|
||||
use_trtllm_attention)
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
# yapf conflicts with isort for this block
|
||||
@@ -48,8 +49,89 @@ FP4_DTYPE = torch.uint8
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
@triton.jit
|
||||
def _trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache_ptr,
|
||||
block_tables_prefill_ptr,
|
||||
block_table_stride,
|
||||
mock_kv_cache_ptr,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
K_CACHE_STRIDE: tl.constexpr,
|
||||
KV_CACHE_STRIDE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0).to(tl.int64)
|
||||
mock_block_table_idx = tl.program_id(1).to(tl.int64)
|
||||
orig_page_num = tl.load(block_tables_prefill_ptr +
|
||||
batch_idx * block_table_stride +
|
||||
mock_block_table_idx).to(tl.int64)
|
||||
if orig_page_num <= 0:
|
||||
return
|
||||
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
|
||||
|
||||
# Dequantize K
|
||||
k_scale_val = tl.load(k_scale_ptr)
|
||||
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
|
||||
fp8_vals = tl.load(kv_cache_ptr + offset)
|
||||
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
|
||||
mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
|
||||
+ 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
|
||||
dequantized_vals = dequantized_vals.to(dequant_dtype)
|
||||
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
|
||||
|
||||
# Dequantize V
|
||||
v_scale_val = tl.load(v_scale_ptr)
|
||||
offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
|
||||
tl.arange(0, K_CACHE_STRIDE))
|
||||
fp8_vals = tl.load(kv_cache_ptr + offset)
|
||||
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
|
||||
mock_cache_offset = (
|
||||
(batch_idx * block_table_stride + mock_block_table_idx + 1) *
|
||||
KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
|
||||
dequantized_vals = dequantized_vals.to(dequant_dtype)
|
||||
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
|
||||
|
||||
|
||||
def trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache: torch.Tensor,
|
||||
block_tables_prefill: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
dequant_dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_of_page_per_token = block_tables_prefill.shape
|
||||
s = kv_cache.shape
|
||||
assert s[1] == 2
|
||||
assert dequant_dtype in (torch.bfloat16, torch.float16)
|
||||
k_cache_stride = s[2] * s[3] * s[4]
|
||||
kv_cache_stride = k_cache_stride * s[1]
|
||||
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
|
||||
# mock kv cache contains just the pages needed by this prefill
|
||||
mock_kv_cache = torch.empty(new_s,
|
||||
dtype=dequant_dtype,
|
||||
device=kv_cache.device)
|
||||
# we simply sequentially index the pages needed by this prefill
|
||||
mock_block_table = torch.arange(
|
||||
start=1,
|
||||
end=batch_size * num_of_page_per_token + 1,
|
||||
dtype=torch.int32,
|
||||
device=block_tables_prefill.device,
|
||||
).reshape(batch_size, num_of_page_per_token)
|
||||
grid = (batch_size, num_of_page_per_token)
|
||||
_trtllm_prefill_attn_kvfp8_dequant[grid](
|
||||
kv_cache,
|
||||
block_tables_prefill,
|
||||
num_of_page_per_token,
|
||||
mock_kv_cache,
|
||||
k_scale,
|
||||
v_scale,
|
||||
k_cache_stride,
|
||||
kv_cache_stride,
|
||||
)
|
||||
return mock_kv_cache, mock_block_table
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
@@ -122,7 +204,6 @@ class FlashInferBackend(AttentionBackend):
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
|
||||
# The data type of the query
|
||||
@@ -175,8 +256,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.kv_cache_spec.block_size)
|
||||
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_pages = max_num_reqs * max_num_pages_per_req
|
||||
self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
|
||||
decode_mode() == CUDAGraphMode.FULL
|
||||
self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
|
||||
decode_mode() == CUDAGraphMode.FULL)
|
||||
if self.enable_cuda_graph:
|
||||
# For full cudagraph capture, one `decode_wrapper` for each batch
|
||||
# size is needed for FlashInfer.
|
||||
@@ -201,7 +282,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
assert self.kv_cache_spec.dtype == self.model_config.dtype
|
||||
self.kv_cache_dtype = self.kv_cache_spec.dtype
|
||||
|
||||
if supports_trtllm_attention()[0]:
|
||||
if supports_trtllm_attention()[0] and \
|
||||
not flashinfer_disable_q_quantization():
|
||||
self.q_data_type = self.kv_cache_dtype
|
||||
else:
|
||||
self.q_data_type = self.model_config.dtype
|
||||
@@ -795,11 +877,29 @@ class FlashInferImpl(AttentionImpl):
|
||||
assert self.o_sf_scale is None
|
||||
out = output[num_decode_tokens:]
|
||||
|
||||
if attn_metadata.q_data_type != FP8_DTYPE \
|
||||
and self.kv_cache_dtype.startswith("fp8"):
|
||||
# TRTLLM prefill attention does not support BF16 Q
|
||||
# and fp8 kv cache. So to enable prefill attention
|
||||
# with fp8 kv cache, we can construct a mock block
|
||||
# and mock kv cache with BF16 KV involved in the prefill
|
||||
mock_kv_cache, mock_block_table = (
|
||||
trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache_permute,
|
||||
block_tables_prefill,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
attn_metadata.q_data_type,
|
||||
))
|
||||
else:
|
||||
mock_kv_cache = kv_cache_permute
|
||||
mock_block_table = block_tables_prefill
|
||||
|
||||
trtllm_batch_context_with_kv_cache(
|
||||
query=prefill_query,
|
||||
kv_cache=kv_cache_permute,
|
||||
kv_cache=mock_kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables_prefill,
|
||||
block_tables=mock_block_table,
|
||||
seq_lens=seq_lens_prefill,
|
||||
max_q_len=attn_metadata.max_q_len,
|
||||
max_kv_len=attn_metadata.max_seq_len,
|
||||
@@ -837,7 +937,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
decode_query = decode_query.contiguous()
|
||||
workspace_buffer = decode_wrapper._float_workspace_buffer
|
||||
block_tables_decode = attn_metadata.\
|
||||
block_table_tensor[:num_decode_tokens]
|
||||
block_table_tensor[:num_decode_tokens]
|
||||
seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
|
||||
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
|
||||
Reference in New Issue
Block a user