[Perf][Kernel] Persistent TopK scheduler: unified CUDAGraph-safe kernel with dynamic per-row dispatch - DeepSeek-V3.2 DSA decode (#37421)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
75e01a39a1
commit
b55d830ec7
@@ -122,6 +122,39 @@ def compare_top_k_results(
|
||||
return True
|
||||
|
||||
|
||||
def validate_topk_against_reference(
|
||||
logits: torch.Tensor,
|
||||
cuda_indices: torch.Tensor,
|
||||
row_starts: torch.Tensor,
|
||||
row_ends: torch.Tensor,
|
||||
top_k: int,
|
||||
kernel_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Validate CUDA top-k results against PyTorch reference implementation.
|
||||
|
||||
Args:
|
||||
logits: Input logits tensor
|
||||
cuda_indices: CUDA kernel output indices
|
||||
row_starts: Row start positions
|
||||
row_ends: Row end positions
|
||||
top_k: Number of top elements to select
|
||||
kernel_name: Name of the kernel being tested (for error messages)
|
||||
"""
|
||||
num_rows = cuda_indices.shape[0]
|
||||
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
|
||||
for i in range(num_rows):
|
||||
row_end = int(row_ends[i])
|
||||
k_i = min(top_k, row_end)
|
||||
idx = logits[i, :row_end].topk(k_i, dim=-1)[1]
|
||||
torch_indices[i, :k_i] = idx
|
||||
|
||||
assert compare_top_k_results(
|
||||
logits, cuda_indices, torch_indices, row_starts, row_ends, top_k
|
||||
), f"{kernel_name} results don't match torch.topk"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_rows", NUM_ROWS)
|
||||
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
@@ -278,111 +311,540 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@pytest.mark.parametrize(
|
||||
"seq_len_range,test_id",
|
||||
[
|
||||
pytest.param((4000, 8000), "short_sequences", id="short"),
|
||||
pytest.param((8000, 32000), "medium_sequences", id="medium"),
|
||||
pytest.param((32000, 163840), "long_sequences", id="long"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
@pytest.mark.parametrize("top_k", [2048])
|
||||
@pytest.mark.parametrize("next_n", [1, 4])
|
||||
@torch.inference_mode()
|
||||
def test_deepseek_hybrid_topk(clean_logits: bool) -> None:
|
||||
def test_deepseek_persistent_topk(
|
||||
seq_len_range: tuple[int, int],
|
||||
test_id: str,
|
||||
clean_logits: bool,
|
||||
top_k: int,
|
||||
next_n: int,
|
||||
) -> None:
|
||||
"""
|
||||
Test persistent_topk with varying sequence lengths and speculative decoding.
|
||||
Supports speculative decoding with next_n > 1.
|
||||
"""
|
||||
set_random_seed(42 if test_id == "short_sequences" else 43)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
batch_size = 4
|
||||
num_rows = batch_size * next_n
|
||||
|
||||
seq_lens = torch.randint(
|
||||
seq_len_range[0],
|
||||
seq_len_range[1],
|
||||
(batch_size,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Compute row boundaries for speculative decoding
|
||||
row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda")
|
||||
row_indices = torch.arange(num_rows, device="cuda") // next_n
|
||||
next_n_offset = torch.arange(num_rows, device="cuda") % next_n
|
||||
row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1
|
||||
|
||||
logits = create_random_logits(
|
||||
row_starts, row_ends, torch.float32, 42, clean_logits, "random"
|
||||
)
|
||||
|
||||
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
|
||||
if next_n == 1:
|
||||
lengths = seq_lens
|
||||
else:
|
||||
offsets = torch.arange(next_n, device=logits.device, dtype=torch.int32)
|
||||
lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()
|
||||
|
||||
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
|
||||
max_seq_len = int(seq_lens.max().item())
|
||||
torch.ops._C.persistent_topk(
|
||||
logits, lengths, indices, workspace, top_k, max_seq_len
|
||||
)
|
||||
|
||||
validate_topk_against_reference(
|
||||
logits, indices, row_starts, row_ends, top_k, f"persistent_topk ({test_id})"
|
||||
)
|
||||
|
||||
|
||||
def run_large_context_topk_test(
|
||||
batch_size: int,
|
||||
seq_lens: list[int],
|
||||
top_k: int,
|
||||
data_type: str = "random",
|
||||
seed: int = 42,
|
||||
) -> None:
|
||||
"""
|
||||
Helper to run persistent_topk kernel test with given parameters.
|
||||
|
||||
Args:
|
||||
batch_size: Number of rows/sequences
|
||||
seq_lens: List of sequence lengths (one per row)
|
||||
top_k: Number of top elements to select
|
||||
data_type: Type of test data to generate
|
||||
seed: Random seed for reproducibility
|
||||
"""
|
||||
torch.set_default_device("cuda:0")
|
||||
set_random_seed(seed)
|
||||
|
||||
# Create test data
|
||||
num_rows = batch_size
|
||||
max_len = max(seq_lens)
|
||||
lengths = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
|
||||
|
||||
if data_type == "random":
|
||||
logits = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda")
|
||||
elif data_type == "sorted_asc":
|
||||
# Each row gets its own ascending sequence based on its length
|
||||
logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda")
|
||||
for i, length in enumerate(seq_lens):
|
||||
logits[i, :length] = torch.arange(
|
||||
length, dtype=torch.float32, device="cuda"
|
||||
)
|
||||
if length < max_len:
|
||||
logits[i, length:] = float("-inf")
|
||||
elif data_type == "sorted_desc":
|
||||
# Each row gets its own descending sequence based on its length
|
||||
logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda")
|
||||
for i, length in enumerate(seq_lens):
|
||||
logits[i, :length] = torch.arange(
|
||||
length, 0, -1, dtype=torch.float32, device="cuda"
|
||||
)
|
||||
if length < max_len:
|
||||
logits[i, length:] = float("-inf")
|
||||
elif data_type == "all_same":
|
||||
logits = torch.ones(num_rows, max_len, dtype=torch.float32, device="cuda")
|
||||
for i, length in enumerate(seq_lens):
|
||||
if length < max_len:
|
||||
logits[i, length:] = float("-inf")
|
||||
elif data_type == "many_ties":
|
||||
# Only 10 unique values, many duplicates
|
||||
logits = torch.randint(0, 10, (num_rows, max_len), device="cuda").float() / 10.0
|
||||
for i, length in enumerate(seq_lens):
|
||||
if length < max_len:
|
||||
logits[i, length:] = float("-inf")
|
||||
elif data_type == "small_differences":
|
||||
# Very small differences to test float precision
|
||||
base = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda")
|
||||
noise = (
|
||||
torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda") * 1e-6
|
||||
)
|
||||
logits = base + noise
|
||||
for i, length in enumerate(seq_lens):
|
||||
if length < max_len:
|
||||
logits[i, length:] = float("-inf")
|
||||
else:
|
||||
raise ValueError(f"Unknown data_type: {data_type}")
|
||||
|
||||
# Create output tensor
|
||||
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
|
||||
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
|
||||
max_seq_len = max(seq_lens)
|
||||
torch.ops._C.persistent_topk(
|
||||
logits, lengths, indices, workspace, top_k, max_seq_len
|
||||
)
|
||||
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
|
||||
for i in range(num_rows):
|
||||
length = seq_lens[i]
|
||||
k_i = min(top_k, length)
|
||||
if k_i > 0:
|
||||
idx = logits[i, :length].topk(k_i, dim=-1)[1]
|
||||
torch_indices[i, :k_i] = idx
|
||||
if k_i < top_k:
|
||||
torch_indices[i, k_i:] = -1
|
||||
else:
|
||||
torch_indices[i, :] = -1
|
||||
|
||||
# Compare results
|
||||
for i in range(num_rows):
|
||||
length = seq_lens[i]
|
||||
k_i = min(top_k, length)
|
||||
|
||||
if k_i == 0:
|
||||
continue
|
||||
|
||||
cuda_row = indices[i, :k_i].cpu()
|
||||
torch_row = torch_indices[i, :k_i].cpu()
|
||||
|
||||
# Filter out -1 padding values from cuda_row
|
||||
valid_mask = cuda_row >= 0
|
||||
cuda_row = cuda_row[valid_mask]
|
||||
|
||||
# Compare sets (order may differ for ties)
|
||||
cuda_set = set(cuda_row.tolist())
|
||||
torch_set = set(torch_row.tolist())
|
||||
|
||||
if cuda_set == torch_set:
|
||||
continue
|
||||
|
||||
# If sets differ, check if it's due to equal values (ties)
|
||||
cuda_vals = logits[i, cuda_row].cpu()
|
||||
torch_vals = logits[i, torch_row].cpu()
|
||||
|
||||
# Check that min CUDA value >= max of values NOT in top-k
|
||||
if k_i < length:
|
||||
non_topk_indices = torch.tensor(
|
||||
list(set(range(length)) - cuda_set), dtype=torch.int32
|
||||
)
|
||||
if len(non_topk_indices) > 0:
|
||||
non_topk_vals = logits[i, non_topk_indices].cpu()
|
||||
min_cuda_val = cuda_vals.min()
|
||||
max_non_topk = non_topk_vals.max()
|
||||
|
||||
# Allow small tolerance for floating point errors
|
||||
assert min_cuda_val >= max_non_topk - 1e-4, (
|
||||
f"Row {i}: CUDA top-k contains values smaller than non-top-k. "
|
||||
f"Min CUDA: {min_cuda_val}, Max non-top-k: {max_non_topk}, "
|
||||
f"Length: {length}, k: {k_i}, CUDA indices: {sorted(cuda_set)[:10]}..., " # noqa: E501
|
||||
f"Expected indices: {sorted(torch_set)[:10]}..."
|
||||
)
|
||||
|
||||
# For ties, verify the values are close
|
||||
assert torch.allclose(
|
||||
cuda_vals.sort(descending=True)[0],
|
||||
torch_vals.sort(descending=True)[0],
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
), f"""Row {i}: Top-k values don't match.
|
||||
CUDA: {cuda_vals.sort(descending=True)[0][:10]},
|
||||
Torch: {torch_vals.sort(descending=True)[0][:10]}"""
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@pytest.mark.parametrize(
|
||||
"test_config",
|
||||
[
|
||||
# ==================== CATEGORY: Sequence Length Edge Cases ====================
|
||||
pytest.param(
|
||||
{"seq_lens": [1, 10, 100, 2048], "top_k": 2048, "data_type": "random"},
|
||||
id="seq_len_edge_very_small_to_medium",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [2049, 2100, 2500, 3000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="seq_len_edge_above_k",
|
||||
),
|
||||
pytest.param(
|
||||
{"seq_lens": [8000, 16384, 20000], "top_k": 2048, "data_type": "random"},
|
||||
id="algo_transition_filtered_radix",
|
||||
),
|
||||
# ==================== CATEGORY: Data Distributions ====================
|
||||
pytest.param(
|
||||
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "sorted_asc"},
|
||||
id="data_sorted_ascending",
|
||||
),
|
||||
pytest.param(
|
||||
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "sorted_desc"},
|
||||
id="data_sorted_descending",
|
||||
),
|
||||
pytest.param(
|
||||
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "all_same"},
|
||||
id="data_all_same",
|
||||
),
|
||||
pytest.param(
|
||||
{"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "many_ties"},
|
||||
id="data_many_ties",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [5000, 10000],
|
||||
"top_k": 2048,
|
||||
"data_type": "small_differences",
|
||||
},
|
||||
id="data_float_precision",
|
||||
),
|
||||
# ==================== CATEGORY: Alignment / Vectorization ====================
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [2055, 2056, 2057, 2063],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="align_vec_boundaries_low",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [4095, 4096, 4097, 4102],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="align_4k_boundary",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [8191, 8192, 8193, 8198],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="align_8k_boundary",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [16383, 16384, 16385, 16390],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="align_16k_boundary",
|
||||
),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_persistent_topk_correctness(test_config: dict) -> None:
|
||||
"""
|
||||
Comprehensive correctness tests covering:
|
||||
- Sequence length edge cases (trivial, boundary, varied)
|
||||
- Very small sequences (< 100 elements)
|
||||
- Mixed sequence lengths in same batch
|
||||
- Data distributions (sorted, ties, precision)
|
||||
- Memory alignment / vectorization boundaries
|
||||
"""
|
||||
run_large_context_topk_test(
|
||||
batch_size=len(test_config["seq_lens"]),
|
||||
seq_lens=test_config["seq_lens"],
|
||||
top_k=test_config["top_k"],
|
||||
data_type=test_config.get("data_type", "random"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@pytest.mark.parametrize(
|
||||
"test_config",
|
||||
[
|
||||
# ==================== CATEGORY: Batch Size Scalability ====================
|
||||
pytest.param(
|
||||
{"batch_size": 1, "seq_len": 5000, "top_k": 2048},
|
||||
id="batch_1",
|
||||
),
|
||||
pytest.param(
|
||||
{"batch_size": 4, "seq_len": 5000, "top_k": 2048},
|
||||
id="batch_4",
|
||||
),
|
||||
pytest.param(
|
||||
{"batch_size": 32, "seq_len": 5000, "top_k": 2048},
|
||||
id="batch_32",
|
||||
),
|
||||
pytest.param(
|
||||
{"batch_size": 256, "seq_len": 5000, "top_k": 2048},
|
||||
id="batch_256",
|
||||
),
|
||||
# ==================== CATEGORY: Single-CTA vs Multi-CTA ====================
|
||||
pytest.param(
|
||||
{"batch_size": 2, "seq_len": 4096, "top_k": 2048},
|
||||
id="single_cta_4k",
|
||||
),
|
||||
pytest.param(
|
||||
{"batch_size": 2, "seq_len": 8192, "top_k": 2048},
|
||||
id="single_cta_8k",
|
||||
),
|
||||
pytest.param(
|
||||
{"batch_size": 2, "seq_len": 163840, "top_k": 2048},
|
||||
id="multi_cta_163840_dsv3_max",
|
||||
),
|
||||
# ==================== CATEGORY: Extreme Cases ====================
|
||||
pytest.param(
|
||||
{"batch_size": 512, "seq_len": 5000, "top_k": 2048},
|
||||
id="extreme_large_batch",
|
||||
),
|
||||
pytest.param(
|
||||
{"batch_size": 2, "seq_len": 163840, "top_k": 2048},
|
||||
id="extreme_dsv3_max_context",
|
||||
),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_persistent_topk_algorithm_paths(test_config: dict) -> None:
|
||||
"""
|
||||
Test different algorithm execution paths (capped at 163840 for DeepSeek V3.2):
|
||||
- Batch size scalability (1, 4, 32, 256)
|
||||
- Single-CTA vs Multi-CTA execution
|
||||
- Extreme configurations (large batch, max context length)
|
||||
"""
|
||||
run_large_context_topk_test(
|
||||
batch_size=test_config["batch_size"],
|
||||
seq_lens=[test_config["seq_len"]] * test_config["batch_size"],
|
||||
top_k=test_config["top_k"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@torch.inference_mode()
|
||||
def test_persistent_topk_stress() -> None:
|
||||
"""
|
||||
Stress test with random configurations to catch edge cases.
|
||||
Capped at 163840 (DeepSeek V3.2 max context) for realistic testing.
|
||||
"""
|
||||
torch.set_default_device("cuda:0")
|
||||
top_k = 2048
|
||||
|
||||
for seed in range(3):
|
||||
set_random_seed(seed)
|
||||
|
||||
# Random batch size (limited for speed)
|
||||
batch_size = torch.randint(1, 32, (1,)).item()
|
||||
|
||||
# Random sequence lengths capped at DeepSeek V3.2 max context
|
||||
seq_lens = torch.randint(100, 163840, (batch_size,)).tolist()
|
||||
|
||||
run_large_context_topk_test(
|
||||
batch_size=batch_size,
|
||||
seq_lens=seq_lens,
|
||||
top_k=top_k,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@pytest.mark.parametrize(
|
||||
"test_config",
|
||||
[
|
||||
# Mixed batch: rows spanning all four paths (trivial, decode, medium, large)
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [2000, 6000, 30000, 80000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="mixed_all_paths",
|
||||
),
|
||||
# All decode/medium rows (typical decode scenario)
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [2048, 4096, 8192, 16000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="all_decode_medium",
|
||||
),
|
||||
# All large rows
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [70000, 100000, 163840],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="all_large",
|
||||
),
|
||||
# Boundary around LARGE_THRESHOLD (32K)
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [32767, 32768, 32769, 32772],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="large_threshold_boundary",
|
||||
),
|
||||
# Single row medium
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [5000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="single_row_medium",
|
||||
),
|
||||
# Single row large
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [100000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="single_row_large",
|
||||
),
|
||||
# Trivial rows mixed with medium and large
|
||||
pytest.param(
|
||||
{
|
||||
"seq_lens": [100, 2048, 10000, 80000],
|
||||
"top_k": 2048,
|
||||
"data_type": "random",
|
||||
},
|
||||
id="trivial_medium_large_mix",
|
||||
),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_persistent_topk(test_config: dict) -> None:
|
||||
"""
|
||||
Tests specific to the persistent_topk kernel:
|
||||
- Mixed medium/large rows in the same batch (dynamic per-row dispatch)
|
||||
- Boundary around LARGE_THRESHOLD (32K)
|
||||
- Trivial + medium + large rows in a single batch
|
||||
"""
|
||||
run_large_context_topk_test(
|
||||
batch_size=len(test_config["seq_lens"]),
|
||||
seq_lens=test_config["seq_lens"],
|
||||
top_k=test_config["top_k"],
|
||||
data_type=test_config.get("data_type", "random"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@torch.inference_mode()
|
||||
def test_persistent_topk_padded_stride() -> None:
|
||||
"""
|
||||
Test persistent_topk with padded logits (large stride, small seq_len)
|
||||
to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits
|
||||
returns [B, max_model_len] with max_model_len=163840.
|
||||
"""
|
||||
set_random_seed(42)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
top_k = 2048
|
||||
batch_size = 4
|
||||
padded_stride = 163840 # DeepSeek-V3.2 max_model_len
|
||||
actual_seq_lens = [3000, 5000, 8000, 12000]
|
||||
|
||||
# Test case 1: Short sequences (< 8192)
|
||||
batch_size_short = 4
|
||||
next_n = 1
|
||||
num_rows_short = batch_size_short * next_n
|
||||
|
||||
# Create sequences with max length < 8192
|
||||
seq_lens_short = torch.randint(
|
||||
4000, 8000, (batch_size_short,), dtype=torch.int32, device="cuda"
|
||||
# Create padded logits tensor (like fp8_paged_mqa_logits output)
|
||||
logits = torch.full(
|
||||
(batch_size, padded_stride),
|
||||
float("-inf"),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
for i, sl in enumerate(actual_seq_lens):
|
||||
logits[i, :sl] = torch.randn(sl, dtype=torch.float32, device="cuda")
|
||||
|
||||
row_starts_short = torch.zeros(num_rows_short, dtype=torch.int32, device="cuda")
|
||||
row_indices_short = torch.arange(num_rows_short, device="cuda") // next_n
|
||||
next_n_offset_short = torch.arange(num_rows_short, device="cuda") % next_n
|
||||
row_ends_short = (
|
||||
seq_lens_short[row_indices_short] - next_n + next_n_offset_short + 1
|
||||
lengths = torch.tensor(actual_seq_lens, dtype=torch.int32, device="cuda")
|
||||
indices = torch.empty((batch_size, top_k), dtype=torch.int32, device="cuda")
|
||||
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
|
||||
|
||||
torch.ops._C.persistent_topk(
|
||||
logits, lengths, indices, workspace, top_k, max(actual_seq_lens)
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
logits_short = create_random_logits(
|
||||
row_starts_short, row_ends_short, torch.float32, 42, clean_logits, "random"
|
||||
)
|
||||
# Validate against torch.topk
|
||||
for i in range(batch_size):
|
||||
sl = actual_seq_lens[i]
|
||||
k_i = min(top_k, sl)
|
||||
expected = logits[i, :sl].topk(k_i, dim=-1)[1].cpu()
|
||||
actual = indices[i, :k_i].cpu()
|
||||
|
||||
indices_vllm = torch.empty(
|
||||
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
expected_set = set(expected.tolist())
|
||||
actual_set = set(actual.tolist())
|
||||
|
||||
# Use vllm's kernel for short sequences
|
||||
torch.ops._C.top_k_per_row_decode(
|
||||
logits_short,
|
||||
next_n,
|
||||
seq_lens_short,
|
||||
indices_vllm,
|
||||
num_rows_short,
|
||||
logits_short.stride(0),
|
||||
logits_short.stride(1),
|
||||
top_k,
|
||||
)
|
||||
|
||||
# Test case 2: Long sequences (>= 8192) - should use large_context_topk kernel
|
||||
batch_size_long = 4
|
||||
num_rows_long = batch_size_long * next_n
|
||||
|
||||
# Create sequences with max length >= 8192
|
||||
seq_lens_long = torch.randint(
|
||||
8192, 16384, (batch_size_long,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
row_starts_long = torch.zeros(num_rows_long, dtype=torch.int32, device="cuda")
|
||||
row_indices_long = torch.arange(num_rows_long, device="cuda") // next_n
|
||||
next_n_offset_long = torch.arange(num_rows_long, device="cuda") % next_n
|
||||
row_ends_long = seq_lens_long[row_indices_long] - next_n + next_n_offset_long + 1
|
||||
|
||||
logits_long = create_random_logits(
|
||||
row_starts_long, row_ends_long, torch.float32, 43, clean_logits, "random"
|
||||
)
|
||||
|
||||
indices = torch.empty((num_rows_long, top_k), dtype=torch.int32, device="cuda")
|
||||
|
||||
# Use large_context_topk kernel for long sequences
|
||||
if next_n == 1:
|
||||
lengths = seq_lens_long
|
||||
else:
|
||||
offsets = torch.arange(next_n, device=logits_long.device, dtype=torch.int32)
|
||||
lengths = (seq_lens_long.unsqueeze(1) - next_n + 1 + offsets).flatten()
|
||||
|
||||
torch.ops._C.large_context_topk(
|
||||
logits_long,
|
||||
indices,
|
||||
lengths,
|
||||
None,
|
||||
)
|
||||
|
||||
torch_indices_short = torch.empty(
|
||||
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
for i in range(num_rows_short):
|
||||
row_end = int(row_ends_short[i])
|
||||
k_i = min(top_k, row_end)
|
||||
idx = logits_short[i, :row_end].topk(k_i, dim=-1)[1]
|
||||
torch_indices_short[i, :k_i] = idx
|
||||
|
||||
assert compare_top_k_results(
|
||||
logits_short,
|
||||
indices_vllm,
|
||||
torch_indices_short,
|
||||
row_starts_short,
|
||||
row_ends_short,
|
||||
top_k,
|
||||
), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk"
|
||||
|
||||
torch_indices_long = torch.empty(
|
||||
(num_rows_long, top_k), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
for i in range(num_rows_long):
|
||||
row_end = int(row_ends_long[i])
|
||||
k_i = min(top_k, row_end)
|
||||
idx = logits_long[i, :row_end].topk(k_i, dim=-1)[1]
|
||||
torch_indices_long[i, :k_i] = idx
|
||||
|
||||
assert compare_top_k_results(
|
||||
logits_long, indices, torch_indices_long, row_starts_long, row_ends_long, top_k
|
||||
), "large_context_topk kernel (long sequences) doesn't match torch.topk"
|
||||
if expected_set != actual_set:
|
||||
# Allow ties
|
||||
expected_vals = logits[i, expected].cpu().sort(descending=True)[0]
|
||||
actual_vals = logits[i, actual].cpu().sort(descending=True)[0]
|
||||
assert torch.allclose(expected_vals, actual_vals, rtol=1e-4, atol=1e-4), (
|
||||
f"Row {i}: persistent_topk with padded stride doesn't match. "
|
||||
f"seq_len={sl}, stride={padded_stride}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user