diff --git a/tests/v1/sample/test_batched_count_greater_than.py b/tests/v1/sample/test_batched_count_greater_than.py
new file mode 100644
index 000000000..c9ace93c6
--- /dev/null
+++ b/tests/v1/sample/test_batched_count_greater_than.py
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Test that batched_count_greater_than does not trigger 0/1 specialization
+recompiles when batch_size varies."""
+
+import torch
+
+from vllm.platforms import current_platform
+from vllm.v1.sample.ops.logprobs import batched_count_greater_than
+from vllm.v1.sample.sampler import Sampler
+
+DEVICE = current_platform.device_type
+
+
+def test_batched_count_greater_than_correctness():
+    """Basic correctness: counts elements >= the corresponding value."""
+    x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device=DEVICE)
+    values = torch.tensor([[2.0], [5.0]], device=DEVICE)
+    result = batched_count_greater_than(x, values)
+    expected = torch.tensor([2, 2], device=DEVICE)
+    torch.testing.assert_close(result, expected)
+
+
+def test_gather_logprobs_no_recompile():
+    """Sampler.gather_logprobs with batch_size=1 then 2 must not recompile.
+
+    This guards against 0/1 specialization: dynamo normally specializes on
+    tensor sizes 0 and 1, causing a recompile when the size first exceeds 1.
+    The mark_unbacked calls in gather_logprobs prevent this.
+    """
+    torch._dynamo.reset()
+
+    compile_count = 0
+    orig_backend = current_platform.simple_compile_backend
+
+    def counting_backend(gm, example_inputs):
+        nonlocal compile_count
+        compile_count += 1
+        if orig_backend == "inductor":
+            return torch._inductor.compile(gm, example_inputs)
+        return gm
+
+    # Monkey-patch batched_count_greater_than with our counting backend
+    # so we can detect recompiles through the production code path.
+    import vllm.v1.sample.ops.logprobs as logprobs_module
+    import vllm.v1.sample.sampler as sampler_module
+
+    unwrapped = batched_count_greater_than._torchdynamo_orig_callable
+    patched = torch.compile(unwrapped, backend=counting_backend)
+    orig_fn = logprobs_module.batched_count_greater_than
+
+    logprobs_module.batched_count_greater_than = patched
+    sampler_module.batched_count_greater_than = patched
+
+    try:
+        vocab_size = 32
+        num_logprobs = 3
+
+        # Call 1: batch_size=1
+        logprobs1 = torch.randn(1, vocab_size, device=DEVICE)
+        token_ids1 = torch.randint(
+            0, vocab_size, (1,), device=DEVICE, dtype=torch.int64
+        )
+        Sampler.gather_logprobs(logprobs1, num_logprobs, token_ids1)
+        assert compile_count == 1, f"Expected 1 compile, got {compile_count}"
+
+        # Call 2: batch_size=2 — should NOT recompile
+        logprobs2 = torch.randn(2, vocab_size, device=DEVICE)
+        token_ids2 = torch.randint(
+            0, vocab_size, (2,), device=DEVICE, dtype=torch.int64
+        )
+        Sampler.gather_logprobs(logprobs2, num_logprobs, token_ids2)
+        assert compile_count == 1, (
+            f"Recompiled on batch_size 1->2 (0/1 specialization). "
+            f"Expected 1 compile, got {compile_count}"
+        )
+
+        # Call 3: batch_size=8 — should NOT recompile
+        logprobs3 = torch.randn(8, vocab_size, device=DEVICE)
+        token_ids3 = torch.randint(
+            0, vocab_size, (8,), device=DEVICE, dtype=torch.int64
+        )
+        Sampler.gather_logprobs(logprobs3, num_logprobs, token_ids3)
+        assert compile_count == 1, (
+            f"Recompiled on batch_size change. Expected 1 compile, got {compile_count}"
+        )
+    finally:
+        # Restore original function
+        logprobs_module.batched_count_greater_than = orig_fn
+        sampler_module.batched_count_greater_than = orig_fn
+        torch._dynamo.reset()
diff --git a/vllm/v1/sample/ops/logprobs.py b/vllm/v1/sample/ops/logprobs.py
index cf36d46e1..9cdcfa456 100644
--- a/vllm/v1/sample/ops/logprobs.py
+++ b/vllm/v1/sample/ops/logprobs.py
@@ -7,7 +7,7 @@ import torch
 from vllm.platforms import current_platform
 
 
-@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
+@torch.compile(backend=current_platform.simple_compile_backend)
 def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
     """
     Counts elements in each row of x that are greater than the corresponding
@@ -22,4 +22,6 @@ def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.T
     Returns:
         torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
     """
+    torch._check(x.shape[0] >= 1)
+    torch._check(x.shape[0] == values.shape[0])
     return (x >= values).sum(-1)
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index acb969360..534135135 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -324,6 +324,12 @@ class Sampler(nn.Module):
         token_logprobs = logprobs.gather(-1, token_ids)
 
         # Compute the ranks of the actual token.
+        # Avoid 0/1 specialization recompile on the batch dimension
+        # of the compiled batched_count_greater_than. mark_unbacked makes
+        # the size fully symbolic so dynamo doesn't specialize when
+        # batch_size transitions from 1 to >=2.
+        torch._dynamo.decorators.mark_unbacked(logprobs, 0)
+        torch._dynamo.decorators.mark_unbacked(token_logprobs, 0)
         token_ranks = batched_count_greater_than(logprobs, token_logprobs)
 
         # Concatenate together with the topk.