[Kernel] Add FP8 support with FlashMLA backend (#22668)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Author: Matthew Bonanni
Committed by: GitHub
Date: 2025-08-21 22:26:32 -04:00
Parent: 480bdf5a7b
Commit: 19fe1a0510
19 changed files with 235 additions and 109 deletions
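
The diff below renames the gather_cache custom op to gather_and_maybe_dequant_cache, threads kv_cache_dtype and a dequantization scale through its signature, and extends the MLA gather test to cover the "fp8" KV-cache dtype. As a rough, helper-free sketch (not part of this commit; the shapes, dtypes, int32 block table, and the unquantized "auto" cache layout are illustrative assumptions), the renamed wrapper can be driven directly like this:

# Rough sketch, not from this commit: driving the renamed wrapper directly.
# Shapes, dtypes, the int32 block table, and the unquantized "auto" cache
# layout are illustrative assumptions, not values taken from the test.
import torch

from vllm import _custom_ops as ops

num_blocks, block_size = 8, 16
entry_size = 512 + 64  # kv_lora_rank + qk_rope_head_dim
device = "cuda"

src_cache = torch.randn(num_blocks, block_size, entry_size,
                        dtype=torch.float32, device=device)
seq_lens = [24, 40]
cu_seq_lens = torch.tensor([0, 24, 64], dtype=torch.int32, device=device)
block_table = torch.stack([
    torch.randperm(num_blocks, device=device).to(torch.int32)
    for _ in seq_lens
])
dst = torch.zeros(sum(seq_lens), entry_size,
                  dtype=torch.float32, device=device)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)

# With kv_cache_dtype="fp8" the cache would hold quantized entries and scale
# would be applied while gathering; with "auto" this is a plain gather.
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, cu_seq_lens,
                                   len(seq_lens), "auto", scale, None)
# The first block_size tokens of sequence 0 come from its first block,
# mirroring how the test builds its reference expectation below.
torch.testing.assert_close(dst[:block_size], src_cache[block_table[0, 0]])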

@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("max_seq_len", [512])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("kv_cache_dtype",
-                         ["auto"])  # You can also test "fp8" if needed.
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
-                          num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, device):
+def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
+                                            block_size, num_blocks,
+                                            max_seq_len, batch_size, dtype,
+                                            kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                   kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
@@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
         perm = torch.randperm(num_blocks, device=device)
         block_table[b, :] = perm

-    dst = torch.zeros((total_tokens, entry_size),
-                      dtype=src_cache.dtype,
-                      device=device)
+    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

     expected_batches = []
     for b in range(batch_size):
@@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
         gathered_rows = []
         for i in range(tot - 1):
-            gathered_rows.append(src_cache[blocks[i]])
+            block_data = src_cache[blocks[i]]
+            if kv_cache_dtype == "fp8":
+                dequantized_block = torch.empty_like(block_data, dtype=dtype)
+                ops.convert_fp8(dequantized_block, block_data, scale.item())
+                gathered_rows.append(dequantized_block)
+            else:
+                gathered_rows.append(block_data)
         remaining = s - (tot - 1) * block_size
-        gathered_rows.append(src_cache[blocks[-1], :remaining, :])
+        last_block_data = src_cache[blocks[-1], :remaining, :]
+        if kv_cache_dtype == "fp8":
+            dequantized_last_block = torch.empty_like(last_block_data,
+                                                      dtype=dtype)
+            ops.convert_fp8(dequantized_last_block, last_block_data,
+                            scale.item())
+            gathered_rows.append(dequantized_last_block)
+        else:
+            gathered_rows.append(last_block_data)
         batch_expected = torch.cat(gathered_rows, dim=0)
         expected_batches.append(batch_expected)
     expected = torch.cat(expected_batches, dim=0)

     opcheck(
-        torch.ops._C_cache_ops.gather_cache,
-        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
+        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
+        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+         scale, None),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS,
     )

-    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
+    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
+                                       cu_seq_lens, batch_size, kv_cache_dtype,
+                                       scale, None)
     torch.testing.assert_close(dst, expected)
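
In the fp8 case, the reference expectation above is built by dequantizing each gathered block with ops.convert_fp8 before concatenation. A minimal round-trip sketch of that conversion follows; the uint8 storage dtype for the quantized block and the dtype-based quantize/dequantize direction are assumptions, not taken from this diff:

# Minimal sketch of the fp8 round trip the reference path relies on; the
# uint8 storage dtype and the dtype-based quantize/dequantize direction of
# ops.convert_fp8 are assumptions, not taken from this diff.
import torch

from vllm import _custom_ops as ops

block = torch.rand(16, 576, dtype=torch.float32, device="cuda")  # one cache block
scale = torch.tensor(0.1, dtype=torch.float32, device="cuda")

fp8_block = torch.empty_like(block, dtype=torch.uint8)      # assumed fp8 storage
ops.convert_fp8(fp8_block, block, scale.item())              # quantize with scale
dequantized = torch.empty_like(block, dtype=torch.float32)
ops.convert_fp8(dequantized, fp8_block, scale.item())        # dequantize with scale
print((dequantized - block).abs().max())                     # small fp8 rounding error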