[NVIDIA] Support Flashinfer TRTLLM FP8-q/kv/out Attention Kernel (#21716)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
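This change routes FP8 query, KV cache, and output through FlashInfer's TRTLLM attention kernels by folding the per-tensor quantization scales into the kernel's two scalar multipliers (`bmm1_scale`, `bmm2_scale`). A minimal sketch of that scale algebra, assuming per-tensor scales named after the layer attributes used in the diff (illustrative only, not vLLM code):

# Illustrative sketch (not vLLM code): how the quantization scales are folded
# into the two scalar multipliers consumed by the TRTLLM attention kernel.
def fold_attention_scales(q_scale: float, k_scale: float, v_scale: float,
                          o_scale: float, sm_scale: float) -> tuple[float, float]:
    # bmm1 scales Q @ K^T before softmax: undo the FP8 quantization of Q and K
    # and apply the usual 1/sqrt(head_dim) softmax scaling.
    bmm1_scale = q_scale * k_scale * sm_scale
    # bmm2 scales softmax(...) @ V: undo V's quantization and, when the output
    # quant is fused, re-quantize into the FP8 output domain (divide by o_scale).
    bmm2_scale = v_scale / o_scale
    return bmm1_scale, bmm2_scale

# With unit scales and head_dim = 128, bmm1 == 1/sqrt(128) and bmm2 == 1.0.
print(fold_attention_scales(1.0, 1.0, 1.0, 1.0, 128 ** -0.5))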
@@ -15,12 +15,17 @@ from flashinfer.decode import (_get_range_buf, get_seq_lens,
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache

+import vllm.envs as envs
+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import use_trtllm_attention
+from vllm.utils.flashinfer import (supports_trtllm_attention,
+                                   use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -35,6 +40,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec

 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024

+FP8_DTYPE = current_platform.fp8_dtype()
+
 logger = init_logger(__name__)

@@ -519,22 +526,27 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             kv_cache_dtype = self.kv_cache_spec.dtype

-        num_qo_heads = self.vllm_config.model_config.get_num_attention_heads(
-            self.vllm_config.parallel_config)
+        config = self.vllm_config
+        num_qo_heads = config.model_config.get_num_attention_heads(
+            config.parallel_config)
         num_kv_heads = self.kv_cache_spec.num_kv_heads
         head_dim = self.kv_cache_spec.head_size

         # Check if any layer uses sinks (requires TRTLLM attention)
         has_sinks = self.global_hyperparameters.has_sinks

-        # currently prefill trtllm attention does not support fp8 kv cache
-        prefill_use_trtllm = not cache_dtype.startswith("fp8") \
-            and use_trtllm_attention(
-                num_prefill_tokens, max_seq_len, cache_dtype,
-                num_qo_heads, num_kv_heads, head_dim, has_sinks)
+        # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
+        q_dtype = config.model_config.dtype
+        enable_fusion = config.compilation_config.pass_config.enable_attn_fusion
+        if cache_dtype.startswith("fp8") and enable_fusion:
+            q_dtype = kv_cache_dtype
+
+        prefill_use_trtllm = use_trtllm_attention(
+            num_qo_heads, num_kv_heads, num_prefill_tokens, max_seq_len,
+            cache_dtype, q_dtype, is_prefill=True, has_sinks=has_sinks)
         decode_use_trtllm = use_trtllm_attention(
-            num_decode_tokens, max_seq_len, cache_dtype,
-            num_qo_heads, num_kv_heads, head_dim, has_sinks)
+            num_qo_heads, num_kv_heads, num_decode_tokens, max_seq_len,
+            cache_dtype, q_dtype, is_prefill=False, has_sinks=has_sinks)

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
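The hunk above passes the query dtype into `use_trtllm_attention` and only switches the query to the FP8 KV-cache dtype when the attention+quant fusion pass is enabled. A small standalone sketch of that decision, assuming string dtype names (the variable roles mirror the hunk; everything else is illustrative):

# Illustrative only: mirrors the q_dtype selection added in this hunk.
def select_query_dtype(model_dtype: str, kv_cache_dtype: str,
                       cache_dtype: str, enable_attn_fusion: bool) -> str:
    # By default the query stays in the model's compute dtype (fp16/bf16).
    q_dtype = model_dtype
    # Only an FP8 KV cache *and* the attn+quant fusion pass enable an FP8 query.
    if cache_dtype.startswith("fp8") and enable_attn_fusion:
        q_dtype = kv_cache_dtype
    return q_dtype

assert select_query_dtype("bfloat16", "fp8_e4m3", "fp8", False) == "bfloat16"
assert select_query_dtype("bfloat16", "fp8_e4m3", "fp8_e4m3", True) == "fp8_e4m3"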
@@ -548,7 +560,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             head_dim=head_dim,
             page_size=page_size,
             kv_data_type=kv_cache_dtype,
-            q_data_type=self.vllm_config.model_config.dtype,
+            q_data_type=q_dtype,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_seq_len=max_seq_len,
@@ -622,6 +634,8 @@ class FlashInferImpl(AttentionImpl):
             self.sliding_window = (-1, -1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
+        self.window_left = (self.sliding_window[0]
+                            if self.sliding_window is not None else -1)
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
@@ -644,6 +658,19 @@ class FlashInferImpl(AttentionImpl):
             )
         self.sinks = sinks

+        self.support_trtllm_attn = (supports_trtllm_attention() and
+                                    num_heads % num_kv_heads == 0)
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None
+
+    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
+                                     group_shape: GroupShape):
+        supported_quant_type = (dtype == FP8_DTYPE and static and
+                                group_shape == GroupShape.PER_TENSOR)
+        return (self.support_trtllm_attn
+                and self.kv_cache_dtype.startswith("fp8")
+                and supported_quant_type)
+
     def forward(
         self,
         layer: torch.nn.Module,
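`fused_output_quant_supported` is the hook the attn+quant fusion pass can consult: it only accepts static, per-tensor FP8 output quantization, and only when TRTLLM attention is usable and the KV cache is FP8. A rough stand-in for the same check, assuming `torch.float8_e4m3fn` as the FP8 dtype (illustrative, not the vLLM API):

# Illustrative stand-in for the new fused_output_quant_supported() predicate.
import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumption: CUDA FP8 dtype

def can_fuse_output_quant(support_trtllm_attn: bool, kv_cache_dtype: str,
                          out_dtype: torch.dtype, static: bool,
                          per_tensor: bool) -> bool:
    # Only static, per-tensor FP8 output quant qualifies, and only on the
    # TRTLLM kernel path with an FP8 KV cache.
    supported_quant = (out_dtype == FP8_DTYPE) and static and per_tensor
    return (support_trtllm_attn and kv_cache_dtype.startswith("fp8")
            and supported_quant)

# Dynamic output quant is rejected; static per-tensor FP8 passes.
assert not can_fuse_output_quant(True, "fp8", FP8_DTYPE, static=False, per_tensor=True)
assert can_fuse_output_quant(True, "fp8_e4m3", FP8_DTYPE, static=True, per_tensor=True)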
@@ -672,15 +699,42 @@ class FlashInferImpl(AttentionImpl):
         """
         assert output is not None, "Output tensor must be provided."

-        if output_scale is not None:
-            raise NotImplementedError(
-                "fused output quantization is not yet supported"
-                " for FlashInferImpl")
-
         if attn_metadata is None:
             # Profiling run.
             return output

+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
+        # The attn+quant fusion happens when output_scale is provided.
+        if output_scale is None:
+            assert attn_metadata.q_data_type != FP8_DTYPE, \
+                "Query can only be FP8 if output fusion happened."
+        else:
+            assert attn_metadata.q_data_type == FP8_DTYPE, \
+                "Query must be FP8 when attn+quant fusion happened."
+            assert (attn_metadata.prefill_use_trtllm and
+                    attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn"
+            assert output.dtype == FP8_DTYPE, \
+                "Output must be FP8 when attn+quant fusion happened."
+
+            # TRTLLM attn kernel requires o scale as a host scalar, store the
+            # o scale to host scalar in warmup run with cuda graph not enabled
+            if layer._o_scale_float is None:
+                layer._o_scale_float = output_scale.cpu().item()
+                self.bmm2_scale = self.bmm2_scale / layer._o_scale_float
+
+            # Insert FP8 quant for query
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
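In the forward path above, the folded scales are cached on first use, `output_scale` is pulled to a host scalar once during the warmup run (so CUDA-graph replays avoid a device-to-host sync), and the query is flattened to 2D, FP8-quantized, and reshaped back. A hedged sketch of just the query quantization step, using a plain per-tensor reference quant in place of `ops.scaled_fp8_quant` (which, as the reshape in the diff suggests, operates on a 2D input):

# Illustrative only: reshape-to-2D -> per-tensor FP8 quant -> reshape back,
# mimicking the query quantization added in this hunk.
import torch

def quantize_query_fp8(query: torch.Tensor, q_scale: float) -> torch.Tensor:
    num_tokens, num_heads, head_size = query.shape
    q2d = query.reshape(num_tokens, num_heads * head_size).contiguous()
    # Reference per-tensor quant standing in for ops.scaled_fp8_quant:
    # divide by the static scale, clamp to the FP8 range, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q_fp8 = (q2d / q_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q_fp8.reshape(num_tokens, num_heads, head_size)

q = torch.randn(4, 8, 128)
print(quantize_query_fp8(q, 0.05).dtype)  # torch.float8_e4m3fn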
@@ -718,9 +772,6 @@ class FlashInferImpl(AttentionImpl):
                 self.kv_cache_dtype)
             kv_cache = kv_cache.view(torch_dtype)

-        window_left = (self.sliding_window[0]
-                       if self.sliding_window is not None else -1)
-
         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
         output_padded = output
@@ -748,7 +799,7 @@ class FlashInferImpl(AttentionImpl):

             if not attn_metadata.prefill_use_trtllm:
                 assert prefill_wrapper._causal
-                assert prefill_wrapper._window_left == window_left
+                assert prefill_wrapper._window_left == self.window_left
                 assert prefill_wrapper._logits_soft_cap == (
                     self.logits_soft_cap or 0.0)
                 assert prefill_wrapper._sm_scale == self.scale
@@ -783,12 +834,12 @@ class FlashInferImpl(AttentionImpl):
                    seq_lens=seq_lens_prefill,
                    max_q_len=attn_metadata.max_q_len,
                    max_kv_len=attn_metadata.max_seq_len,
-                   bmm1_scale=layer._k_scale_float * self.scale,
-                   bmm2_scale=layer._v_scale_float,
+                   bmm1_scale=self.bmm1_scale,
+                   bmm2_scale=self.bmm2_scale,
                    batch_size=attn_metadata.num_prefills,
                    cum_seq_lens_q=attn_metadata.qo_indptr_gpu,
                    cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu,
-                   window_left=window_left,
+                   window_left=self.window_left,
                    sinks=self.sinks,
                    out=output[num_decode_tokens:],
                )
@@ -800,7 +851,7 @@ class FlashInferImpl(AttentionImpl):
            assert decode_wrapper is not None

            if not attn_metadata.decode_use_trtllm:
-               assert decode_wrapper._window_left == window_left
+               assert decode_wrapper._window_left == self.window_left
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                           or 0.0)
                assert decode_wrapper._sm_scale == self.scale
@@ -815,8 +866,8 @@ class FlashInferImpl(AttentionImpl):
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
                workspace_buffer = decode_wrapper._float_workspace_buffer
-               block_tables_decode = attn_metadata.block_table_tensor[:
-                                                                      num_decode_tokens]
+               block_tables_decode = attn_metadata.\
+                   block_table_tensor[:num_decode_tokens]
                seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]

                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
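The decode path above depends on the HND KV-cache layout, as the in-code comment notes. A minimal usage sketch, assuming the environment variable is set before the engine is created:

# Illustrative: request the HND KV-cache layout expected by the TRTLLM decode
# path before constructing the engine (assumption: set prior to engine init).
import os
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"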
@@ -834,9 +885,9 @@ class FlashInferImpl(AttentionImpl):
                    block_tables=block_tables_decode,
                    seq_lens=seq_lens_decode,
                    max_seq_len=attn_metadata.max_seq_len,
-                   bmm1_scale=layer._k_scale_float * self.scale,
-                   bmm2_scale=layer._v_scale_float,
-                   window_left=window_left,
+                   bmm1_scale=self.bmm1_scale,
+                   bmm2_scale=self.bmm2_scale,
+                   window_left=self.window_left,
                    sinks=self.sinks,
                    out=output[:num_decode_tokens],
                )