[BugFix] Fix assert batch_descriptor.num_tokens == num_tokens_padded (#30173)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
assert key == BatchDescriptor(num_tokens=15)
|
||||
|
||||
- # 4. Cascade attention should have a fall back mode
|
||||
+ # 4. disable_full should have a fall back mode (e.g., cascade attention)
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
|
||||
rt_mode, key = dispatcher.dispatch(
|
||||
- num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
|
||||
+ num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
|
||||
)
|
||||
if "PIECEWISE" in cudagraph_mode_str: # string contains check
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
Reference in New Issue
Block a user