[Core] Add Flashinfer TRTLLM Backend for Flashinfer decode path (SM100). (#19825)
Signed-off-by: Pavani Majety <pmajety@nvidia.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: shuw <shuw@nvidia.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -10,11 +10,13 @@ import torch
|
||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
MultiLevelCascadeAttentionWrapper)
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
@@ -38,6 +40,7 @@ logger = init_logger(__name__)
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
cached_sm100a_supported: Optional[bool] = None
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
@@ -93,6 +96,57 @@ class FlashInferBackend(AttentionBackend):
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return stride_order
|
||||
|
||||
@staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: int,
    num_kv_heads: int,
    attn_head_size: int,
) -> bool:
    """Decide whether decode should go through the TRTLLM kernel.

    The decision is made in three stages:

    1. The platform must report device capability 100; the answer is
       looked up once and cached on ``FlashInferBackend``.
    2. The head configuration must be supported: ``num_qo_heads`` is a
       multiple of ``num_kv_heads``, the ratio is at most 8, and the
       head size is exactly 128.
    3. If ``VLLM_USE_TRTLLM_DECODE_ATTENTION`` is set, it decides: the
       literal value ``"0"`` disables the path and any other value
       enables it. If it is unset, the path is auto-enabled when
       ``batch_size <= 256``, ``max_seq_len < 131072`` and
       ``kv_cache_dtype == "auto"``.

    Args:
        batch_size: Number of decode requests in the batch.
        max_seq_len: Longest sequence length in the batch.
        kv_cache_dtype: KV-cache dtype; only compared against ``"auto"``.
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        attn_head_size: Size of each attention head.

    Returns:
        True if the TRTLLM decode path should be used, False otherwise.
    """
    backend = FlashInferBackend

    # Query the platform only once; later calls reuse the cached answer.
    if backend.cached_sm100a_supported is None:
        backend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not backend.cached_sm100a_supported:
        return False

    # Only supports attention head size of 128, with the query heads
    # evenly grouped over the KV heads and at most 8 per KV head.
    heads_supported = (num_qo_heads // num_kv_heads <= 8
                       and num_qo_heads % num_kv_heads == 0
                       and attn_head_size == 128)
    if not heads_supported:
        return False

    override = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if override is None:
        # Environment variable not set - use auto-detection.
        # Device support was already established above, so only the
        # batch-dependent conditions remain to be checked.
        auto_enabled = (batch_size <= 256 and max_seq_len < 131072
                        and kv_cache_dtype == "auto")
        if auto_enabled:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return auto_enabled

    # Environment variable is set - respect it. Only an explicit "0"
    # turns the path off; without the variable the path would already be
    # enabled automatically whenever the batch conditions are met.
    logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                     override)
    forced_on = override != "0"
    if forced_on:
        logger.info_once(
            "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
            "using TRTLLM decode attention.")
    return forced_on
|
||||
|
||||
@staticmethod
|
||||
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
|
||||
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
|
||||
return torch.float8_e4m3fn
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
return torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
@@ -127,12 +181,18 @@ class FlashInferMetadata:
|
||||
# Block size of vllm
|
||||
page_size: int
|
||||
# The data type of the paged kv cache
|
||||
data_type: torch.dtype
|
||||
kv_data_type: torch.dtype
|
||||
# The data type of the query
|
||||
q_data_type: torch.dtype
|
||||
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For flashinfer trtllm batch decode
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table_tensor: torch.Tensor
|
||||
workspace_buffer: torch.Tensor
|
||||
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
@@ -299,6 +359,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
window_left=self.global_hyperparameters.window_left,
|
||||
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
|
||||
q_data_type=attn_metadata.q_data_type,
|
||||
kv_data_type=attn_metadata.kv_data_type,
|
||||
)
|
||||
else:
|
||||
# Regular attention (common case).
|
||||
@@ -334,28 +395,33 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
logits_soft_cap=self.global_hyperparameters.
|
||||
logits_soft_cap,
|
||||
q_data_type=attn_metadata.q_data_type,
|
||||
kv_data_type=attn_metadata.data_type,
|
||||
kv_data_type=attn_metadata.kv_data_type,
|
||||
)
|
||||
|
||||
if self._num_decodes > 0:
|
||||
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
||||
attn_metadata.decode_wrapper.plan(
|
||||
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[:self._num_decodes],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
attn_metadata.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
sm_scale=self.global_hyperparameters.sm_scale,
|
||||
window_left=self.global_hyperparameters.window_left,
|
||||
logits_soft_cap=self.global_hyperparameters.
|
||||
logits_soft_cap,
|
||||
q_data_type=attn_metadata.q_data_type,
|
||||
kv_data_type=attn_metadata.data_type,
|
||||
)
|
||||
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||
self._num_decodes, attn_metadata.max_seq_len,
|
||||
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||
attn_metadata.decode_wrapper.plan(
|
||||
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[:self.
|
||||
_num_decodes],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
attn_metadata.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
sm_scale=self.global_hyperparameters.sm_scale,
|
||||
window_left=self.global_hyperparameters.window_left,
|
||||
logits_soft_cap=self.global_hyperparameters.
|
||||
logits_soft_cap,
|
||||
q_data_type=attn_metadata.q_data_type,
|
||||
kv_data_type=attn_metadata.kv_data_type,
|
||||
)
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
@@ -368,6 +434,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.runner.device
|
||||
qo_indptr = common_attn_metadata.query_start_loc
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
@@ -416,7 +483,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
paged_kv_last_page_len = seq_lens % page_size
|
||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||
page_size, paged_kv_last_page_len)
|
||||
|
||||
cache_dtype = self.runner.cache_config.cache_dtype
|
||||
if cache_dtype.startswith("fp8"):
|
||||
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
cache_dtype)
|
||||
else:
|
||||
kv_cache_dtype = self.kv_cache_spec.dtype
|
||||
attn_metadata = FlashInferMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
qo_indptr=qo_indptr,
|
||||
@@ -427,7 +499,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
||||
head_dim=self.kv_cache_spec.head_size,
|
||||
page_size=page_size,
|
||||
data_type=self.kv_cache_spec.dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
q_data_type=self.runner.dtype,
|
||||
slot_mapping=slot_mapping,
|
||||
num_decodes=self._num_decodes,
|
||||
@@ -439,6 +511,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
shared_kv_page_indptr=shared_kv_page_indptr,
|
||||
shared_kv_page_indices=shared_kv_page_indices,
|
||||
shared_kv_last_page_len=shared_kv_last_page_len,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table_tensor=block_table_tensor,
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
)
|
||||
|
||||
self._plan(attn_metadata)
|
||||
@@ -514,7 +590,11 @@ class FlashInferImpl(AttentionImpl):
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
kv_cache: shape -
|
||||
# NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
# HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
|
||||
|
||||
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@@ -560,6 +640,13 @@ class FlashInferImpl(AttentionImpl):
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||
# to process the cache when the kv_cache_dtype is fp8
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
|
||||
self.kv_cache_dtype)
|
||||
kv_cache = kv_cache.view(torch_dtype)
|
||||
|
||||
window_left = (self.sliding_window[0]
|
||||
if self.sliding_window is not None else -1)
|
||||
|
||||
@@ -597,21 +684,45 @@ class FlashInferImpl(AttentionImpl):
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[num_decode_tokens:],
|
||||
)
|
||||
|
||||
if decode_wrapper := attn_metadata.decode_wrapper:
|
||||
decode_query = query[:num_decode_tokens]
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
assert decode_wrapper is not None
|
||||
assert decode_wrapper._window_left == window_left
|
||||
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
|
||||
or 0.0)
|
||||
assert decode_wrapper._sm_scale == self.scale
|
||||
decode_wrapper.run(
|
||||
decode_query,
|
||||
kv_cache.permute(*stride_order),
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[:num_decode_tokens],
|
||||
)
|
||||
|
||||
if not FlashInferBackend.use_trtllm_decode_attention(
|
||||
attn_metadata.num_decodes, attn_metadata.max_seq_len,
|
||||
self.kv_cache_dtype, attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads, attn_metadata.head_dim):
|
||||
assert decode_wrapper is not None
|
||||
assert decode_wrapper._window_left == window_left
|
||||
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
|
||||
or 0.0)
|
||||
assert decode_wrapper._sm_scale == self.scale
|
||||
decode_wrapper.run(
|
||||
decode_query,
|
||||
kv_cache.permute(*stride_order),
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[:num_decode_tokens],
|
||||
)
|
||||
else:
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
if num_decode_tokens > 0:
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
output[:num_decode_tokens] = (
|
||||
trtllm_batch_decode_with_kv_cache(
|
||||
query=decode_query,
|
||||
kv_cache=kv_cache.permute(*stride_order),
|
||||
workspace_buffer=attn_metadata.workspace_buffer,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
scale=self.scale,
|
||||
block_tables=attn_metadata.
|
||||
block_table_tensor[:num_decode_tokens],
|
||||
seq_lens=attn_metadata.
|
||||
seq_lens[:num_decode_tokens],
|
||||
block_size=attn_metadata.page_size,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
))
|
||||
return output_padded
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_KV_CACHE_LAYOUT_OVERRIDE = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -103,6 +104,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
global _KV_CACHE_LAYOUT_OVERRIDE
|
||||
# Override with format specified by the user.
|
||||
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
|
||||
if cache_layout is None:
|
||||
@@ -110,10 +112,16 @@ def get_kv_cache_layout():
|
||||
else:
|
||||
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||
|
||||
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
|
||||
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
|
||||
return cache_layout
|
||||
|
||||
|
||||
def set_kv_cache_layout(cache_layout: str):
    """Override the KV cache layout reported by ``get_kv_cache_layout``.

    Args:
        cache_layout: Layout name to store in the module-level override
            (for example ``"HND"``).

    ``get_kv_cache_layout`` is wrapped in ``functools.lru_cache`` and takes
    no arguments, so its first result is memoized. The cache is cleared
    here so that an override installed after the layout has already been
    queried still takes effect on the next query.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    # Drop any memoized layout; otherwise the new override is ignored.
    get_kv_cache_layout.cache_clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerLayerParameters:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user