[Perf] Disable clean_logits in deepgemm fp8_mqa_logits kernel (#33568)
This commit is contained in:
@@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
# Test parameters
|
||||
NUM_ROWS = [1, 32, 2050]
|
||||
@@ -20,6 +21,7 @@ def create_random_logits(
|
||||
row_ends: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
clean_logits: bool,
|
||||
data_generation: str,
|
||||
) -> torch.Tensor:
|
||||
"""Create random logits tensor for testing."""
|
||||
@@ -48,8 +50,9 @@ def create_random_logits(
|
||||
)
|
||||
logits = logits_bits.view(dtype)
|
||||
|
||||
for i, end in enumerate(row_ends):
|
||||
logits[i, end:] = float("-inf")
|
||||
if clean_logits:
|
||||
for i, end in enumerate(row_ends):
|
||||
logits[i, end:] = float("-inf")
|
||||
return logits
|
||||
|
||||
|
||||
@@ -121,21 +124,26 @@ def compare_top_k_results(
|
||||
|
||||
@pytest.mark.parametrize("num_rows", NUM_ROWS)
|
||||
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@torch.inference_mode()
|
||||
def test_top_k_per_row(
|
||||
num_rows: int,
|
||||
top_k: int,
|
||||
clean_logits: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Test top_k_per_row.
|
||||
"""
|
||||
set_random_seed(0)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
# Create test data
|
||||
vocab_size = 20000
|
||||
row_starts, row_ends = create_row_boundaries(num_rows, vocab_size)
|
||||
logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random")
|
||||
logits = create_random_logits(
|
||||
row_starts, row_ends, torch.float32, 42, clean_logits, "random"
|
||||
)
|
||||
|
||||
# Create output tensors
|
||||
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
@@ -153,11 +161,12 @@ def test_top_k_per_row(
|
||||
)
|
||||
|
||||
# Run reference implementation
|
||||
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
|
||||
mask_lo = torch_indices >= 0
|
||||
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
|
||||
mask = mask_lo & mask_hi
|
||||
torch_indices = torch_indices.masked_fill(~mask, -1)
|
||||
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
for i in range(num_rows):
|
||||
row_end = int(row_ends[i])
|
||||
k_i = min(top_k, row_end)
|
||||
idx = logits[i, :row_end].topk(k_i, dim=-1)[1]
|
||||
torch_indices[i, :k_i] = idx
|
||||
|
||||
# Compare results
|
||||
assert compare_top_k_results(
|
||||
@@ -170,6 +179,7 @@ def _run_top_k_per_row_decode_test(
|
||||
batch_size: int,
|
||||
next_n: int,
|
||||
vocab_size: int,
|
||||
clean_logits: bool,
|
||||
data_generation: str,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -180,14 +190,18 @@ def _run_top_k_per_row_decode_test(
|
||||
# Create test data
|
||||
num_rows = batch_size * next_n
|
||||
seq_lens = torch.randint(
|
||||
vocab_size, (batch_size,), dtype=torch.int32, device="cuda"
|
||||
low=next_n,
|
||||
high=vocab_size,
|
||||
size=(batch_size,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
|
||||
row_indices = torch.arange(num_rows, device="cuda") // next_n
|
||||
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
|
||||
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
|
||||
logits = create_random_logits(
|
||||
row_starts, row_ends, torch.float32, 42, data_generation
|
||||
row_starts, row_ends, torch.float32, 42, clean_logits, data_generation
|
||||
)
|
||||
|
||||
# Create output tensors
|
||||
@@ -208,11 +222,12 @@ def _run_top_k_per_row_decode_test(
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Run reference implementation
|
||||
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
|
||||
mask_lo = torch_indices >= 0
|
||||
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
|
||||
mask = mask_lo & mask_hi
|
||||
torch_indices = torch_indices.masked_fill(~mask, -1)
|
||||
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
for i in range(num_rows):
|
||||
row_end = int(row_ends[i])
|
||||
k_i = min(top_k, row_end)
|
||||
idx = logits[i, :row_end].topk(k_i, dim=-1)[1]
|
||||
torch_indices[i, :k_i] = idx
|
||||
|
||||
# Compare results
|
||||
assert compare_top_k_results(
|
||||
@@ -223,6 +238,7 @@ def _run_top_k_per_row_decode_test(
|
||||
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("next_n", NEXT_N)
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
@pytest.mark.parametrize("data_generation", DATA_GENERATION)
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@torch.inference_mode()
|
||||
@@ -230,28 +246,32 @@ def test_top_k_per_row_decode(
|
||||
top_k: int,
|
||||
batch_size: int,
|
||||
next_n: int,
|
||||
clean_logits: bool,
|
||||
data_generation: str,
|
||||
) -> None:
|
||||
"""
|
||||
Test top_k_per_row with seq_lens tensor.
|
||||
"""
|
||||
set_random_seed(0)
|
||||
vocab_size = 20000
|
||||
_run_top_k_per_row_decode_test(
|
||||
top_k, batch_size, next_n, vocab_size, data_generation
|
||||
top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_top_k_per_row_decode_large_vocab_size() -> None:
|
||||
def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
|
||||
"""
|
||||
Test top_k_per_row_decode with large vocabulary size.
|
||||
"""
|
||||
set_random_seed(0)
|
||||
top_k = 2048
|
||||
batch_size = 2
|
||||
next_n = 2
|
||||
vocab_size = 300000
|
||||
data_generation = "random"
|
||||
_run_top_k_per_row_decode_test(
|
||||
top_k, batch_size, next_n, vocab_size, data_generation
|
||||
top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user