[Bugfix] Fix pooling non-determinism from pinned prompt_lens aliasing (#37775)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-21 22:22:24 -05:00
committed by GitHub
parent e78bc74268
commit 66f927f205
2 changed files with 63 additions and 1 deletion

View File

@@ -378,3 +378,65 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch)
def _construct_pooling_request(req_id_suffix: int):
    """Build a fresh CachedRequestState configured for a pooling task.

    The request carries a randomly sized prompt (10 to MAX_PROMPT_SIZE - 1
    tokens, ids drawn uniformly from the vocabulary) and classification
    pooling params; everything else is left empty/zeroed.
    """
    # Imported lazily so importing this test module does not require
    # the pooling_params module up front.
    from vllm.pooling_params import PoolingParams

    prompt_len = np.random.randint(10, MAX_PROMPT_SIZE)
    prompt_token_ids = [np.random.randint(0, VOCAB_SIZE) for _ in range(prompt_len)]

    return CachedRequestState(
        req_id=f"pool_req_{req_id_suffix}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=None,
        pooling_params=PoolingParams(task="classify"),
        mm_features=[],
        block_ids=([],),
        generator=None,
        num_computed_tokens=0,
        output_token_ids=[],
    )
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_pooling_prompt_lens_not_aliased(device: str):
    """Verify that prompt_lens in PoolingMetadata does not share memory
    with the internal num_prompt_tokens pinned buffer. Guards against possible
    non-determinism in pooling metadata due to mutations to the internal buffer.
    """
    batch_size = 4
    max_len = MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS

    input_batch = InputBatch(
        max_num_reqs=2 * batch_size,
        max_model_len=max_len,
        max_num_batched_tokens=batch_size * max_len,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )

    # Populate the batch with pooling requests, keeping references alive.
    reqs = []
    for idx in range(batch_size):
        state = _construct_pooling_request(idx)
        input_batch.add_request(state)
        reqs.append(state)
    input_batch.refresh_metadata()

    # Take a snapshot: prompt_lens must be an independent copy, not a view.
    metadata = input_batch.get_pooling_metadata()
    prompt_lens_snapshot = metadata.prompt_lens.clone()

    # Clobber the internal pinned buffer, as the next batch's bookkeeping would.
    input_batch.num_prompt_tokens_cpu_tensor.fill_(999)

    # If prompt_lens aliased the buffer, it would now read 999 everywhere.
    assert torch.equal(metadata.prompt_lens, prompt_lens_snapshot), (
        "prompt_lens shares memory with internal pinned buffer; "
        "mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
        f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
    )