[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -9,8 +9,11 @@ from typing import Optional
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from vllm.utils import round_up
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@@ -61,13 +64,13 @@ def benchmark_decode(
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
q_scale = 1.0
|
||||
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
query, _ = to_float8(ref_query)
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
query = ref_query
|
||||
|
||||
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_seq_len
|
||||
@@ -75,14 +78,13 @@ def benchmark_decode(
|
||||
seq_lens = kv_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
k_scale = v_scale = 1.0
|
||||
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
kv_cache, _ = to_float8(ref_kv_cache)
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_cache = ref_kv_cache
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
@@ -142,11 +144,31 @@ def benchmark_decode(
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = 500.0
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
ref_query,
|
||||
ref_kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output_baseline,
|
||||
)
|
||||
|
||||
def trtllm_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
@@ -158,6 +180,7 @@ def benchmark_decode(
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
@@ -237,6 +260,7 @@ if __name__ == "__main__":
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
|
||||
Reference in New Issue
Block a user