[V1][Core] min_p sampling support (#13191)

Signed-off-by: Aoyu <aoyuzhan@amazon.com>
Co-authored-by: Aoyu <aoyuzhan@amazon.com>
This commit is contained in:
Aoyu
2025-02-15 07:50:05 +08:00
committed by GitHub
parent 3bcb8c75da
commit a12934d3ec
4 changed files with 96 additions and 0 deletions

View File

@@ -81,6 +81,8 @@ def _create_default_sampling_metadata(
top_k=torch.empty(batch_size, ),
no_top_p=True,
no_top_k=True,
min_p=torch.empty(batch_size, ),
no_min_p=True,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
@@ -336,6 +338,46 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
non_penalized_token_id in output_tokens)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("min_p", [0.0, 0.1])
def test_sampler_min_p(device: str, batch_size: int, min_p: float):
    """
    Tests that when min_p is applied, tokens with probability below
    min_p * max_prob are masked with -inf.

    Every row is given one dominant token (index 0) and identical low
    logits everywhere else. The dominant token must never be masked; the
    remaining tokens must all be masked when min_p > 0 and must all be
    left untouched when min_p == 0.
    """
    torch.set_default_device(device)
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    # Create one dominant token per batch row: a high logit for token 0
    # and a uniformly low logit for every other token.
    fake_logits[:batch_size, 0] = 10.0
    fake_logits[:batch_size, 1:] = 1e-2
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # Configure min_p parameters
    sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)
    sampler = Sampler()
    logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
    logits = logits.cpu()

    neg_inf = -float("inf")
    # Check whole slices at once instead of asserting per element; the
    # comparisons are the same ones a per-token loop would make.
    dominant = logits[:batch_size, 0]
    others = logits[:batch_size, 1:VOCAB_SIZE]

    # Dominant token should always be unmasked
    assert (dominant != neg_inf).all()
    if min_p > 0.0:
        # Non-dominant tokens should be masked when min_p > 0
        assert (others == neg_inf).all()
    else:
        # No masking when min_p is 0
        assert (others != neg_inf).all()
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])