[Bugfix] Fix pooling non-determinism from pinned prompt_lens aliasing (#37775)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -378,3 +378,65 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
|
||||
ref_input_batch.refresh_metadata()
|
||||
|
||||
_compare_objs(input_batch, ref_input_batch)
|
||||
|
||||
|
||||
def _construct_pooling_request(req_id_suffix: int, task: str = "classify"):
    """Build a minimal CachedRequestState representing a pooling request.

    Args:
        req_id_suffix: Integer suffix making the request id unique.
        task: Pooling task name forwarded to PoolingParams
            (default "classify", matching the previous hard-coded value).

    Returns:
        A CachedRequestState with a random-length, random-content prompt,
        no sampling params (pooling requests do not sample), and empty
        block/output state.
    """
    from vllm.pooling_params import PoolingParams

    # Draw the whole prompt in one vectorized call instead of a Python
    # loop of scalar randint draws; length is in [10, MAX_PROMPT_SIZE).
    prompt_len = np.random.randint(10, MAX_PROMPT_SIZE)
    prompt_token_ids = np.random.randint(0, VOCAB_SIZE, size=prompt_len).tolist()

    return CachedRequestState(
        req_id=f"pool_req_{req_id_suffix}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=None,  # pooling requests carry no sampling params
        pooling_params=PoolingParams(task=task),
        mm_features=[],
        block_ids=([],),
        generator=None,
        num_computed_tokens=0,
        output_token_ids=[],
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_pooling_prompt_lens_not_aliased(device: str):
    """PoolingMetadata.prompt_lens must be a snapshot, not a view.

    If prompt_lens aliased the pinned num_prompt_tokens buffer, later
    writes to that buffer (e.g. while preparing the next batch) would
    retroactively corrupt already-returned metadata, producing
    non-deterministic pooling results.
    """
    num_reqs = 4
    batch = InputBatch(
        max_num_reqs=num_reqs * 2,
        max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
        max_num_batched_tokens=num_reqs * (MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS),
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )

    # Populate the batch with pooling requests.
    tracked = []
    for idx in range(num_reqs):
        state = _construct_pooling_request(idx)
        batch.add_request(state)
        tracked.append(state)
    batch.refresh_metadata()

    # Capture the metadata and a defensive copy of its prompt_lens.
    metadata = batch.get_pooling_metadata()
    prompt_lens_snapshot = metadata.prompt_lens.clone()

    # Scribble over the internal pinned buffer, as the next scheduling
    # step adding new requests would.
    batch.num_prompt_tokens_cpu_tensor.fill_(999)

    # The previously returned prompt_lens must be unaffected.
    assert torch.equal(metadata.prompt_lens, prompt_lens_snapshot), (
        "prompt_lens shares memory with internal pinned buffer; "
        "mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
        f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
    )
|
||||
|
||||
Reference in New Issue
Block a user