diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py
index e2ae3b833..2dc522598 100644
--- a/tests/kernels/attention/test_deepgemm_attention.py
+++ b/tests/kernels/attention/test_deepgemm_attention.py
@@ -95,7 +95,8 @@ def _ref_fp8_mqa_logits(
 @pytest.mark.skipif(
     not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
 )
-def test_deepgemm_fp8_mqa_logits():
+@pytest.mark.parametrize("clean_logits", [True, False])
+def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
     torch.manual_seed(0)
     random.seed(0)
     num_heads, head_dim = 32, 128
@@ -126,7 +127,9 @@ def test_deepgemm_fp8_mqa_logits():
 
     q_fp8 = q.to(torch.float8_e4m3fn)
     kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
-    logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
+    logits = fp8_mqa_logits(
+        q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits
+    )
 
     ref_logits = _ref_fp8_mqa_logits(
         q=q,
@@ -135,13 +138,14 @@ def test_deepgemm_fp8_mqa_logits():
         cu_seqlen_ks=ks,
         cu_seqlen_ke=ke,
     )
 
-    ref_neginf_mask = ref_logits == float("-inf")
-    neginf_mask = logits == float("-inf")
-    assert torch.equal(neginf_mask, ref_neginf_mask)
+    ref_neginf_mask = ref_logits == float("-inf")
+    if clean_logits:
+        neginf_mask = logits == float("-inf")
+        assert torch.equal(neginf_mask, ref_neginf_mask)
 
     ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
-    logits = logits.masked_fill(neginf_mask, 0)
+    logits = logits.masked_fill(ref_neginf_mask, 0)
     diff = calc_diff(logits, ref_logits)
     assert diff < 1e-3, f"{diff=}"
 
@@ -201,7 +205,8 @@ def _ref_fp8_paged_mqa_logits(
 @pytest.mark.skipif(
     not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
 )
-def test_deepgemm_fp8_paged_mqa_logits():
+@pytest.mark.parametrize("clean_logits", [True, False])
+def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
     torch.manual_seed(0)
     random.seed(0)
 
@@ -264,6 +269,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
         block_tables,
         schedule_metadata,
         max_model_len,
+        clean_logits=clean_logits,
     )
 
     ref_logits = _ref_fp8_paged_mqa_logits(
diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py
index 3bf693897..2d9dd2a04 100644
--- a/tests/kernels/test_top_k_per_row.py
+++ b/tests/kernels/test_top_k_per_row.py
@@ -6,6 +6,7 @@ import pytest
 import torch
 
 from vllm.platforms import current_platform
+from vllm.utils.torch_utils import set_random_seed
 
 # Test parameters
 NUM_ROWS = [1, 32, 2050]
@@ -20,6 +21,7 @@ def create_random_logits(
     row_ends: torch.Tensor,
     dtype: torch.dtype,
     seed: int,
+    clean_logits: bool,
     data_generation: str,
 ) -> torch.Tensor:
     """Create random logits tensor for testing."""
@@ -48,8 +50,9 @@ def create_random_logits(
         )
         logits = logits_bits.view(dtype)
 
-    for i, end in enumerate(row_ends):
-        logits[i, end:] = float("-inf")
+    if clean_logits:
+        for i, end in enumerate(row_ends):
+            logits[i, end:] = float("-inf")
 
     return logits
 
@@ -121,21 +124,26 @@ def compare_top_k_results(
 
 @pytest.mark.parametrize("num_rows", NUM_ROWS)
 @pytest.mark.parametrize("top_k", TOP_K_VALUES)
+@pytest.mark.parametrize("clean_logits", [True, False])
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
 @torch.inference_mode()
 def test_top_k_per_row(
     num_rows: int,
     top_k: int,
+    clean_logits: bool,
 ) -> None:
     """
     Test top_k_per_row.
""" + set_random_seed(0) torch.set_default_device("cuda:0") # Create test data vocab_size = 20000 row_starts, row_ends = create_row_boundaries(num_rows, vocab_size) - logits = create_random_logits(row_starts, row_ends, torch.float32, 42, "random") + logits = create_random_logits( + row_starts, row_ends, torch.float32, 42, clean_logits, "random" + ) # Create output tensors indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") @@ -153,11 +161,12 @@ def test_top_k_per_row( ) # Run reference implementation - torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1] - mask_lo = torch_indices >= 0 - mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 - mask = mask_lo & mask_hi - torch_indices = torch_indices.masked_fill(~mask, -1) + torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + for i in range(num_rows): + row_end = int(row_ends[i]) + k_i = min(top_k, row_end) + idx = logits[i, :row_end].topk(k_i, dim=-1)[1] + torch_indices[i, :k_i] = idx # Compare results assert compare_top_k_results( @@ -170,6 +179,7 @@ def _run_top_k_per_row_decode_test( batch_size: int, next_n: int, vocab_size: int, + clean_logits: bool, data_generation: str, ) -> None: """ @@ -180,14 +190,18 @@ def _run_top_k_per_row_decode_test( # Create test data num_rows = batch_size * next_n seq_lens = torch.randint( - vocab_size, (batch_size,), dtype=torch.int32, device="cuda" + low=next_n, + high=vocab_size, + size=(batch_size,), + dtype=torch.int32, + device="cuda", ) row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda") row_indices = torch.arange(num_rows, device="cuda") // next_n next_n_offset = torch.arange(num_rows, device="cuda") % next_n row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1 logits = create_random_logits( - row_starts, row_ends, torch.float32, 42, data_generation + row_starts, row_ends, torch.float32, 42, clean_logits, data_generation ) # Create output tensors @@ -208,11 +222,12 @@ def _run_top_k_per_row_decode_test( torch.cuda.synchronize() # Run reference implementation - torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1] - mask_lo = torch_indices >= 0 - mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0 - mask = mask_lo & mask_hi - torch_indices = torch_indices.masked_fill(~mask, -1) + torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + for i in range(num_rows): + row_end = int(row_ends[i]) + k_i = min(top_k, row_end) + idx = logits[i, :row_end].topk(k_i, dim=-1)[1] + torch_indices[i, :k_i] = idx # Compare results assert compare_top_k_results( @@ -223,6 +238,7 @@ def _run_top_k_per_row_decode_test( @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("batch_size", BATCH_SIZE) @pytest.mark.parametrize("next_n", NEXT_N) +@pytest.mark.parametrize("clean_logits", [True, False]) @pytest.mark.parametrize("data_generation", DATA_GENERATION) @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") @torch.inference_mode() @@ -230,28 +246,32 @@ def test_top_k_per_row_decode( top_k: int, batch_size: int, next_n: int, + clean_logits: bool, data_generation: str, ) -> None: """ Test top_k_per_row with seq_lens tensor. 
""" + set_random_seed(0) vocab_size = 20000 _run_top_k_per_row_decode_test( - top_k, batch_size, next_n, vocab_size, data_generation + top_k, batch_size, next_n, vocab_size, clean_logits, data_generation ) @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@pytest.mark.parametrize("clean_logits", [True, False]) @torch.inference_mode() -def test_top_k_per_row_decode_large_vocab_size() -> None: +def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None: """ Test top_k_per_row_decode with large vocabulary size. """ + set_random_seed(0) top_k = 2048 batch_size = 2 next_n = 2 vocab_size = 300000 data_generation = "random" _run_top_k_per_row_decode_test( - top_k, batch_size, next_n, vocab_size, data_generation + top_k, batch_size, next_n, vocab_size, clean_logits, data_generation ) diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 77fe4c063..9ca7a42b7 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -108,6 +108,7 @@ def sparse_attn_indexer( weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, + clean_logits=False, ) num_rows = logits.shape[0] @@ -157,6 +158,7 @@ def sparse_attn_indexer( decode_metadata.block_table, decode_metadata.schedule_metadata, max_model_len=max_model_len, + clean_logits=False, ) num_rows = logits.shape[0] diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 129e9c9fa..19e85ff62 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -242,6 +242,7 @@ def fp8_mqa_logits( weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, + clean_logits: bool, ) -> torch.Tensor: """Compute FP8 MQA logits for a single sequence without KV paging. @@ -256,6 +257,7 @@ def fp8_mqa_logits( shape [M], dtype int32. cu_seqlen_ke: End indices (exclusive) for valid K per query position, shape [M], dtype int32. + clean_logits: Whether to clean the unfilled logits into `-inf`. Returns: Logits tensor of shape [M, N], dtype `torch.float32`. @@ -263,7 +265,9 @@ def fp8_mqa_logits( _lazy_init() if _fp8_mqa_logits_impl is None: return _missing() - return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + return _fp8_mqa_logits_impl( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=clean_logits + ) def get_paged_mqa_logits_metadata( @@ -295,6 +299,7 @@ def fp8_paged_mqa_logits( block_tables: torch.Tensor, schedule_metadata: torch.Tensor, max_model_len: int, + clean_logits: bool, ) -> torch.Tensor: """Compute FP8 MQA logits using paged KV-cache. @@ -312,6 +317,7 @@ def fp8_paged_mqa_logits( schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; used to distribute work across SMs. max_model_len: Maximum sequence length used to size the logits output. + clean_logits: Whether to clean the unfilled logits into `-inf`. Returns: Logits tensor of shape [B * next_n, max_model_len], dtype @@ -328,7 +334,7 @@ def fp8_paged_mqa_logits( block_tables, schedule_metadata, max_model_len, - clean_logits=True, + clean_logits=clean_logits, )