[V1][Core] Fix memory issue with logits & sampling (#13776)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang
2025-03-08 06:11:04 -08:00
committed by GitHub
parent 0b7f06b447
commit 8d5aa466fb
3 changed files with 69 additions and 29 deletions

View File

@@ -142,7 +142,16 @@ def test_end_to_end(model: str, use_v1: bool):
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
# now the memory usage is mostly cudagraph memory pool,
# and it should be less than the model weights (1B model, 2GiB weights)
assert used_bytes < 2 * GiB_bytes
# NOTE: In V1, the memory buffer for logits (max_num_reqs x vocab_size)
# is captured but cannot be releasesd from PyTorch due to a known bug,
# therefore high memory usage after `llm.sleep` is called is expected.
# FIXME(youkaichao & ywang96): Fix memory buffer issue with sleep mode
# in V1.
if use_v1:
assert used_bytes < 7 * GiB_bytes
else:
assert used_bytes < 2 * GiB_bytes
llm.wake_up()
output2 = llm.generate(prompt, sampling_params)