[V1][Core] min_p sampling support (#13191)
Signed-off-by: Aoyu <aoyuzhan@amazon.com>
Co-authored-by: Aoyu <aoyuzhan@amazon.com>
@@ -81,6 +81,8 @@ def _create_default_sampling_metadata(
         top_k=torch.empty(batch_size, ),
         no_top_p=True,
         no_top_k=True,
+        min_p=torch.empty(batch_size, ),
+        no_min_p=True,
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
@@ -336,6 +338,46 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
                 non_penalized_token_id in output_tokens)
 
 
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("min_p", [0.0, 0.1])
+def test_sampler_min_p(device: str, batch_size: int, min_p: float):
+    """
+    Tests that when min_p is applied, tokens with probability below
+    min_p * max_prob are masked with -inf.
+    """
+    torch.set_default_device(device)
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+
+    # Create one dominant token per batch
+    for i in range(batch_size):
+        fake_logits[i, 0] = 10.0  # High logit for first token
+        fake_logits[i, 1:] = 1e-2  # Others remain low
+
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+
+    # Configure min_p parameters
+    sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)
+
+    sampler = Sampler()
+    logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
+    logits = logits.cpu()
+
+    for batch_idx in range(batch_size):
+        for token_id in range(VOCAB_SIZE):
+            if token_id == 0:
+                # Dominant token should always be unmasked
+                assert logits[batch_idx][token_id] != -float("inf")
+            else:
+                if min_p > 0.0:
+                    # Non-dominant tokens should be masked when min_p > 0
+                    assert logits[batch_idx][token_id] == -float("inf")
+                else:
+                    # No masking when min_p is 0
+                    assert logits[batch_idx][token_id] != -float("inf")
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("bias_value", [-0.1, 1.2])