[BugFix] Fix potential cuda-graph IMA (illegal memory access) (#21196)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -59,11 +59,6 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
-    def __post_init__(self):
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
-        # mode.
-        self.slot_mapping[self.num_actual_tokens:].fill_(-1)
-
 
 M = TypeVar("M")
 
@@ -684,7 +684,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
 
-        # Fill unused with -1. Needed for reshape_and_cache
+        # Fill unused with 0 for full cuda graph mode.
         self.seq_lens[num_reqs:].fill_(0)
         # Note: pad query_start_loc to be non-decreasing, as kernels
         # like FlashAttention requires that
@@ -704,6 +704,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             blk_table = self.input_batch.block_table[kv_cache_group_id]
             blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
             slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
+
+            # Fill unused with -1. Needed for reshape_and_cache in full cuda
+            # graph mode.
+            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
Reference in New Issue
Block a user