[BugFix] Fix assert batch_descriptor.num_tokens == num_tokens_padded (#30173)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-12-09 10:36:12 -05:00
committed by GitHub
parent 5dcd593baf
commit 56037dfa2f
6 changed files with 65 additions and 33 deletions

View File

@@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fall back mode
# 4. disable_full should have a fall back mode (e.g., cascade attention)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE