[Kernel] Add FP8 support with FlashMLA backend (#22668)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("max_seq_len", [512])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("kv_cache_dtype",
-                         ["auto"])  # You can also test "fp8" if needed.
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
-                          num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, device):
+def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
+                                            block_size, num_blocks,
+                                            max_seq_len, batch_size, dtype,
+                                            kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                   kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
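For context, the new `scale` is a per-tensor FP8 scale factor. A minimal sketch of the scaling scheme it implies, using torch.float8_e4m3fn as a stand-in for the cache's FP8 storage; the divide-on-quantize / multiply-on-dequantize convention is an assumption here, not something the diff states:

import torch

# Hypothetical per-tensor FP8 round-trip; an illustration, not vLLM's kernel.
scale = 0.1
x = torch.randn(4, 8)                        # high-precision cache entries
x_fp8 = (x / scale).to(torch.float8_e4m3fn)  # quantize with the scale
x_deq = x_fp8.to(torch.float32) * scale      # dequantize back
print((x - x_deq).abs().max())               # small fp8 rounding error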
@@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
         perm = torch.randperm(num_blocks, device=device)
         block_table[b, :] = perm
 
-    dst = torch.zeros((total_tokens, entry_size),
-                      dtype=src_cache.dtype,
-                      device=device)
+    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)
 
     expected_batches = []
     for b in range(batch_size):
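Note that `dst` is now allocated in the high-precision test `dtype` rather than `src_cache.dtype`: for an "fp8" cache the source buffer holds quantized storage, while the op writes dequantized values into `dst`.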
@@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
 
         gathered_rows = []
         for i in range(tot - 1):
-            gathered_rows.append(src_cache[blocks[i]])
+            block_data = src_cache[blocks[i]]
+            if kv_cache_dtype == "fp8":
+                dequantized_block = torch.empty_like(block_data, dtype=dtype)
+                ops.convert_fp8(dequantized_block, block_data, scale.item())
+                gathered_rows.append(dequantized_block)
+            else:
+                gathered_rows.append(block_data)
         remaining = s - (tot - 1) * block_size
-        gathered_rows.append(src_cache[blocks[-1], :remaining, :])
+        last_block_data = src_cache[blocks[-1], :remaining, :]
+        if kv_cache_dtype == "fp8":
+            dequantized_last_block = torch.empty_like(last_block_data,
+                                                      dtype=dtype)
+            ops.convert_fp8(dequantized_last_block, last_block_data,
+                            scale.item())
+            gathered_rows.append(dequantized_last_block)
+        else:
+            gathered_rows.append(last_block_data)
 
         batch_expected = torch.cat(gathered_rows, dim=0)
         expected_batches.append(batch_expected)
     expected = torch.cat(expected_batches, dim=0)
 
     opcheck(
-        torch.ops._C_cache_ops.gather_cache,
-        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
+        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
+        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+         scale, None),
         test_utils=DEFAULT_OPCHECK_TEST_UTILS,
     )
 
-    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
+    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
+                                       cu_seq_lens, batch_size, kv_cache_dtype,
+                                       scale, None)
     torch.testing.assert_close(dst, expected)
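For reference, a sketch of calling the renamed op end to end. The positional call shape (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, scale, and a final optional argument passed as None) comes from the test above; the shapes, dtypes, and the choice of the "auto" pure-gather path are illustrative assumptions:

import torch
from vllm import _custom_ops as ops

batch_size, block_size, entry_size, num_blocks = 2, 16, 576, 8
device = "cuda"

# A plain (unquantized) MLA cache; kv_cache_dtype="auto" makes the op a
# pure gather. Two sequences of 24 and 40 tokens, so 64 tokens total.
src_cache = torch.randn(num_blocks, block_size, entry_size, device=device)
block_table = torch.arange(num_blocks, dtype=torch.int32,
                           device=device).repeat(batch_size, 1)
cu_seq_lens = torch.tensor([0, 24, 64], dtype=torch.int32, device=device)
dst = torch.zeros(64, entry_size, device=device)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)

# With kv_cache_dtype="fp8" the op would also dequantize each gathered
# block using `scale` on the way out, as the test's reference loop does.
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, cu_seq_lens,
                                   batch_size, "auto", scale, None)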