[Core] Separate out attention metadata building logic from prepare inputs (#26764)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -1054,7 +1054,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
def _get_encoder_seq_lens(
|
def _get_encoder_seq_lens(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduled_encoder_inputs: dict[str, list[int]],
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
num_reqs: int,
|
num_reqs: int,
|
||||||
) -> np.ndarray | None:
|
) -> np.ndarray | None:
|
||||||
@@ -1064,31 +1064,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Build encoder_seq_lens array mapping request indices to
|
# Build encoder_seq_lens array mapping request indices to
|
||||||
# encoder lengths for inputs scheduled in this batch
|
# encoder lengths for inputs scheduled in this batch
|
||||||
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
|
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
|
||||||
for req_id in scheduler_output.scheduled_encoder_inputs:
|
for req_id in scheduled_encoder_inputs:
|
||||||
req_index = self.input_batch.req_id_to_index[req_id]
|
req_index = self.input_batch.req_id_to_index[req_id]
|
||||||
encoder_seq_lens[req_index] = self.max_encoder_len
|
encoder_seq_lens[req_index] = self.max_encoder_len
|
||||||
|
|
||||||
return encoder_seq_lens
|
return encoder_seq_lens
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self, scheduler_output: "SchedulerOutput"
|
self,
|
||||||
|
scheduler_output: "SchedulerOutput",
|
||||||
|
num_scheduled_tokens: np.ndarray,
|
||||||
|
max_num_scheduled_tokens: int,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PerLayerAttnMetadata,
|
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
SpecDecodeMetadata | None,
|
SpecDecodeMetadata | None,
|
||||||
np.ndarray,
|
|
||||||
CommonAttentionMetadata | None,
|
|
||||||
int,
|
|
||||||
UBatchSlices | None,
|
UBatchSlices | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
bool,
|
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
:return: tuple[
|
:return: tuple[
|
||||||
attn_metadata: layer-to-attention_metadata mapping,
|
|
||||||
logits_indices, spec_decode_metadata,
|
logits_indices, spec_decode_metadata,
|
||||||
num_scheduled_tokens, spec_decode_common_attn_metadata,
|
ubatch_slices, num_tokens_across_dp,
|
||||||
max_num_scheduled_tokens, use_cascade_attn
|
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
@@ -1100,12 +1096,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# This way, we can overlap the copy with the following CPU operations.
|
# This way, we can overlap the copy with the following CPU operations.
|
||||||
self.input_batch.block_table.commit_block_table(num_reqs)
|
self.input_batch.block_table.commit_block_table(num_reqs)
|
||||||
|
|
||||||
# Get the number of scheduled tokens for each request.
|
|
||||||
req_ids = self.input_batch.req_ids
|
|
||||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
|
||||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
|
||||||
max_num_scheduled_tokens = max(tokens)
|
|
||||||
|
|
||||||
# Get request indices.
|
# Get request indices.
|
||||||
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
|
||||||
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
|
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
|
||||||
@@ -1232,8 +1222,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Fill unused with 0 for full cuda graph mode.
|
# Fill unused with 0 for full cuda graph mode.
|
||||||
self.seq_lens.np[num_reqs:].fill(0)
|
self.seq_lens.np[num_reqs:].fill(0)
|
||||||
self.seq_lens.copy_to_gpu()
|
self.seq_lens.copy_to_gpu()
|
||||||
seq_lens = self.seq_lens.gpu[:num_reqs]
|
|
||||||
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
|
|
||||||
|
|
||||||
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
|
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
|
||||||
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
||||||
@@ -1305,11 +1293,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
|
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
|
||||||
self.num_decode_draft_tokens.copy_to_gpu()
|
self.num_decode_draft_tokens.copy_to_gpu()
|
||||||
|
|
||||||
logits_indices_padded = None
|
# Hot-Swap lora model
|
||||||
if self.cache_config.kv_sharing_fast_prefill:
|
if self.lora_config:
|
||||||
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
assert (
|
||||||
logits_indices
|
np.sum(num_sampled_tokens)
|
||||||
|
<= self.vllm_config.scheduler_config.max_num_batched_tokens
|
||||||
)
|
)
|
||||||
|
self.set_active_loras(
|
||||||
|
self.input_batch, num_scheduled_tokens, num_sampled_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
logits_indices,
|
||||||
|
spec_decode_metadata,
|
||||||
|
ubatch_slices,
|
||||||
|
num_tokens_across_dp,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_attention_metadata(
|
||||||
|
self,
|
||||||
|
total_num_scheduled_tokens: int,
|
||||||
|
max_num_scheduled_tokens: int,
|
||||||
|
num_reqs: int,
|
||||||
|
ubatch_slices: UBatchSlices | None = None,
|
||||||
|
logits_indices: torch.Tensor | None = None,
|
||||||
|
use_spec_decode: bool = False,
|
||||||
|
for_cudagraph_capture: bool = False,
|
||||||
|
scheduled_encoder_inputs: dict[str, list[int]] | None = None,
|
||||||
|
cascade_attn_prefix_lens: list[list[int]] | None = None,
|
||||||
|
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
|
||||||
|
"""
|
||||||
|
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
|
||||||
|
"""
|
||||||
|
logits_indices_padded = None
|
||||||
|
num_logits_indices = 0
|
||||||
|
if logits_indices is not None:
|
||||||
|
num_logits_indices = logits_indices.size(0)
|
||||||
|
if self.cache_config.kv_sharing_fast_prefill:
|
||||||
|
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||||||
|
logits_indices
|
||||||
|
)
|
||||||
|
|
||||||
# update seq_lens of decode reqs under DCP.
|
# update seq_lens of decode reqs under DCP.
|
||||||
if self.dcp_world_size > 1:
|
if self.dcp_world_size > 1:
|
||||||
@@ -1324,15 +1347,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
attn_metadata: PerLayerAttnMetadata = {}
|
attn_metadata: PerLayerAttnMetadata = {}
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||||
use_cascade_attn = False
|
|
||||||
|
|
||||||
# Used in the below loop.
|
# Used in the below loop
|
||||||
|
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
|
||||||
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
|
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
|
||||||
|
seq_lens = self.seq_lens.gpu[:num_reqs]
|
||||||
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
|
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
|
||||||
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
||||||
:num_reqs
|
:num_reqs
|
||||||
]
|
]
|
||||||
|
dcp_local_seq_lens = (
|
||||||
|
self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
|
||||||
|
)
|
||||||
spec_decode_common_attn_metadata = None
|
spec_decode_common_attn_metadata = None
|
||||||
|
|
||||||
|
if for_cudagraph_capture:
|
||||||
|
# For some attention backends (e.g. FA) with sliding window models we need
|
||||||
|
# to make sure the backend see a max_seq_len that is larger to the sliding
|
||||||
|
# window size when capturing to make sure the correct kernel is selected.
|
||||||
|
max_seq_len = self.max_model_len
|
||||||
|
else:
|
||||||
|
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
|
||||||
|
|
||||||
if use_spec_decode:
|
if use_spec_decode:
|
||||||
self.num_accepted_tokens.np[:num_reqs] = (
|
self.num_accepted_tokens.np[:num_reqs] = (
|
||||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
|
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
|
||||||
@@ -1342,14 +1378,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
# Prepare the attention metadata for each KV cache group and make layers
|
# Prepare the attention metadata for each KV cache group and make layers
|
||||||
# in the same group share the same metadata.
|
# in the same group share the same metadata.
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_gid, kv_cache_group in enumerate(
|
||||||
self.kv_cache_config.kv_cache_groups
|
self.kv_cache_config.kv_cache_groups
|
||||||
):
|
):
|
||||||
encoder_seq_lens = self._get_encoder_seq_lens(
|
encoder_seq_lens = self._get_encoder_seq_lens(
|
||||||
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs
|
scheduled_encoder_inputs or {},
|
||||||
|
kv_cache_group.kv_cache_spec,
|
||||||
|
num_reqs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
|
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||||
# Encoder-only layers do not have KV cache, so we need to
|
# Encoder-only layers do not have KV cache, so we need to
|
||||||
# create a dummy block table and slot mapping for them.
|
# create a dummy block table and slot mapping for them.
|
||||||
blk_table_tensor = torch.zeros(
|
blk_table_tensor = torch.zeros(
|
||||||
@@ -1362,18 +1400,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
num_common_prefix_blocks = 0
|
|
||||||
else:
|
else:
|
||||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
blk_table = self.input_batch.block_table[kv_cache_gid]
|
||||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs)
|
blk_table_tensor = blk_table.get_device_tensor(num_reqs)
|
||||||
slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
|
slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
|
||||||
|
|
||||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||||
# graph mode.
|
# graph mode.
|
||||||
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
|
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
|
||||||
num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[
|
|
||||||
kv_cache_group_id
|
|
||||||
]
|
|
||||||
|
|
||||||
common_attn_metadata = CommonAttentionMetadata(
|
common_attn_metadata = CommonAttentionMetadata(
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
@@ -1388,35 +1422,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
block_table_tensor=blk_table_tensor,
|
block_table_tensor=blk_table_tensor,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
logits_indices_padded=logits_indices_padded,
|
logits_indices_padded=logits_indices_padded,
|
||||||
num_logits_indices=logits_indices.size(0),
|
num_logits_indices=num_logits_indices,
|
||||||
causal=True,
|
causal=True,
|
||||||
encoder_seq_lens=encoder_seq_lens,
|
encoder_seq_lens=encoder_seq_lens,
|
||||||
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
|
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||||
if self.dcp_world_size > 1
|
|
||||||
else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
if self.speculative_config and spec_decode_common_attn_metadata is None:
|
||||||
if isinstance(self.drafter, EagleProposer):
|
if isinstance(self.drafter, EagleProposer):
|
||||||
if (
|
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
|
||||||
self.drafter.attn_layer_names[0]
|
|
||||||
in kv_cache_group_spec.layer_names
|
|
||||||
):
|
|
||||||
spec_decode_common_attn_metadata = common_attn_metadata
|
spec_decode_common_attn_metadata = common_attn_metadata
|
||||||
else:
|
else:
|
||||||
spec_decode_common_attn_metadata = common_attn_metadata
|
spec_decode_common_attn_metadata = common_attn_metadata
|
||||||
|
|
||||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
|
||||||
# Prepare for cascade attention if enabled & beneficial.
|
cascade_attn_prefix_len = (
|
||||||
common_prefix_len = 0
|
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
|
||||||
|
if cascade_attn_prefix_lens
|
||||||
|
else 0
|
||||||
|
)
|
||||||
builder = attn_group.get_metadata_builder()
|
builder = attn_group.get_metadata_builder()
|
||||||
if self.cascade_attn_enabled:
|
|
||||||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
|
||||||
num_scheduled_tokens,
|
|
||||||
num_common_prefix_blocks,
|
|
||||||
attn_group.kv_cache_spec,
|
|
||||||
builder,
|
|
||||||
)
|
|
||||||
|
|
||||||
extra_attn_metadata_args = {}
|
extra_attn_metadata_args = {}
|
||||||
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
||||||
@@ -1434,51 +1459,69 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
for ubid, common_attn_metadata in enumerate(
|
for ubid, common_attn_metadata in enumerate(
|
||||||
common_attn_metadata_list
|
common_attn_metadata_list
|
||||||
):
|
):
|
||||||
attn_metadata_i = attn_group.get_metadata_builder(
|
builder = attn_group.get_metadata_builder(ubatch_id=ubid)
|
||||||
ubatch_id=ubid
|
if for_cudagraph_capture:
|
||||||
).build(
|
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||||
common_prefix_len=common_prefix_len,
|
common_attn_metadata
|
||||||
common_attn_metadata=common_attn_metadata,
|
)
|
||||||
)
|
else:
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
attn_metadata_i = builder.build(
|
||||||
|
common_prefix_len=cascade_attn_prefix_len,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
)
|
||||||
|
for layer_name in kv_cache_group.layer_names:
|
||||||
assert type(attn_metadata) is list
|
assert type(attn_metadata) is list
|
||||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||||
else:
|
else:
|
||||||
assert isinstance(attn_metadata, dict)
|
assert isinstance(attn_metadata, dict)
|
||||||
attn_metadata_i = builder.build(
|
if for_cudagraph_capture:
|
||||||
common_prefix_len=common_prefix_len,
|
attn_metadata_i = builder.build_for_cudagraph_capture(
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata
|
||||||
**extra_attn_metadata_args,
|
)
|
||||||
)
|
else:
|
||||||
use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False)
|
attn_metadata_i = builder.build(
|
||||||
|
common_prefix_len=cascade_attn_prefix_len,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
**extra_attn_metadata_args,
|
||||||
|
)
|
||||||
for layer_name in attn_group.layer_names:
|
for layer_name in attn_group.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
|
|
||||||
# disable cascade attention when DBO
|
return attn_metadata, spec_decode_common_attn_metadata
|
||||||
if ubatch_slices is not None:
|
|
||||||
use_cascade_attn = False
|
|
||||||
|
|
||||||
# Hot-Swap lora model
|
def _compute_cascade_attn_prefix_lens(
|
||||||
if self.lora_config:
|
self,
|
||||||
assert (
|
num_scheduled_tokens: np.ndarray,
|
||||||
np.sum(num_sampled_tokens)
|
num_common_prefix_blocks: list[int],
|
||||||
<= self.vllm_config.scheduler_config.max_num_batched_tokens
|
) -> list[list[int]] | None:
|
||||||
)
|
"""
|
||||||
self.set_active_loras(
|
:return: Optional[cascade_attn_prefix_lens]
|
||||||
self.input_batch, num_scheduled_tokens, num_sampled_tokens
|
cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
|
||||||
)
|
None if we should not use cascade attention
|
||||||
|
"""
|
||||||
|
|
||||||
return (
|
use_cascade_attn = False
|
||||||
attn_metadata,
|
num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
|
||||||
logits_indices,
|
cascade_attn_prefix_lens: list[list[int]] = [
|
||||||
spec_decode_metadata,
|
[] for _ in range(num_kv_cache_groups)
|
||||||
num_scheduled_tokens,
|
]
|
||||||
spec_decode_common_attn_metadata,
|
|
||||||
max_num_scheduled_tokens,
|
for kv_cache_gid in range(num_kv_cache_groups):
|
||||||
ubatch_slices,
|
for attn_group in self.attn_groups[kv_cache_gid]:
|
||||||
num_tokens_across_dp,
|
if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||||
use_cascade_attn,
|
cascade_attn_prefix_len = 0
|
||||||
)
|
else:
|
||||||
|
# 0 if cascade attention should not be used
|
||||||
|
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||||
|
num_scheduled_tokens,
|
||||||
|
num_common_prefix_blocks[kv_cache_gid],
|
||||||
|
attn_group.kv_cache_spec,
|
||||||
|
attn_group.get_metadata_builder(),
|
||||||
|
)
|
||||||
|
cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
|
||||||
|
use_cascade_attn |= cascade_attn_prefix_len > 0
|
||||||
|
|
||||||
|
return cascade_attn_prefix_lens if use_cascade_attn else None
|
||||||
|
|
||||||
def _compute_cascade_attn_prefix_len(
|
def _compute_cascade_attn_prefix_len(
|
||||||
self,
|
self,
|
||||||
@@ -1504,6 +1547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
int: Length of common prefix in tokens.
|
int: Length of common prefix in tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
|
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
|
||||||
if common_prefix_len == 0:
|
if common_prefix_len == 0:
|
||||||
# Common case.
|
# Common case.
|
||||||
@@ -2497,18 +2541,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
"it when the requests need prompt logprobs"
|
"it when the requests need prompt logprobs"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare the decoder inputs.
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
req_ids = self.input_batch.req_ids
|
||||||
|
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||||
|
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
|
||||||
|
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
|
||||||
|
|
||||||
(
|
(
|
||||||
attn_metadata,
|
|
||||||
logits_indices,
|
logits_indices,
|
||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
num_scheduled_tokens_np,
|
|
||||||
spec_decode_common_attn_metadata,
|
|
||||||
max_query_len,
|
|
||||||
ubatch_slices,
|
ubatch_slices,
|
||||||
num_tokens_across_dp,
|
num_tokens_across_dp,
|
||||||
use_cascade_attn,
|
) = self._prepare_inputs(
|
||||||
) = self._prepare_inputs(scheduler_output)
|
scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
cascade_attn_prefix_lens = None
|
||||||
|
# Disable cascade attention when using microbatching (DBO)
|
||||||
|
if self.cascade_attn_enabled and ubatch_slices is None:
|
||||||
|
# Pre-compute cascade attention prefix lengths
|
||||||
|
# NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
|
||||||
|
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
|
||||||
|
num_scheduled_tokens_np,
|
||||||
|
scheduler_output.num_common_prefix_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(lucas): move cudagraph dispatching here:
|
||||||
|
# https://github.com/vllm-project/vllm/issues/23789
|
||||||
|
|
||||||
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
|
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||||
|
attn_metadata, spec_decode_common_attn_metadata = (
|
||||||
|
self._build_attention_metadata(
|
||||||
|
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||||
|
max_num_scheduled_tokens=max_num_scheduled_tokens,
|
||||||
|
num_reqs=num_reqs,
|
||||||
|
ubatch_slices=ubatch_slices,
|
||||||
|
logits_indices=logits_indices,
|
||||||
|
use_spec_decode=use_spec_decode,
|
||||||
|
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
|
||||||
|
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
dp_rank = self.parallel_config.data_parallel_rank
|
dp_rank = self.parallel_config.data_parallel_rank
|
||||||
if ubatch_slices:
|
if ubatch_slices:
|
||||||
@@ -2532,16 +2606,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
scheduler_output, num_input_tokens, intermediate_tensors
|
scheduler_output, num_input_tokens, intermediate_tensors
|
||||||
)
|
)
|
||||||
|
|
||||||
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
uniform_decode = (
|
||||||
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len
|
max_num_scheduled_tokens == self.uniform_decode_query_len
|
||||||
)
|
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
|
||||||
batch_descriptor = BatchDescriptor(
|
batch_descriptor = BatchDescriptor(
|
||||||
num_tokens=num_input_tokens,
|
num_tokens=num_input_tokens,
|
||||||
uniform_decode=uniform_decode,
|
uniform_decode=uniform_decode,
|
||||||
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
|
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
|
||||||
)
|
)
|
||||||
cudagraph_runtime_mode, batch_descriptor = (
|
cudagraph_runtime_mode, batch_descriptor = (
|
||||||
self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn)
|
self.cudagraph_dispatcher.dispatch(
|
||||||
|
batch_descriptor,
|
||||||
|
use_cascade_attn=cascade_attn_prefix_lens is not None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set cudagraph mode to none if calc_kv_scales is true.
|
# Set cudagraph mode to none if calc_kv_scales is true.
|
||||||
@@ -3437,10 +3514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# If force_attention is True, we always capture attention. Otherwise,
|
# If force_attention is True, we always capture attention. Otherwise,
|
||||||
# it only happens for cudagraph_runtime_mode=FULL.
|
# it only happens for cudagraph_runtime_mode=FULL.
|
||||||
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
attn_metadata = {}
|
|
||||||
if ubatch_slices is not None:
|
|
||||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
|
||||||
|
|
||||||
if create_mixed_batch:
|
if create_mixed_batch:
|
||||||
# In the mixed batch mode (used for FI warmup), we use
|
# In the mixed batch mode (used for FI warmup), we use
|
||||||
# shorter sequence lengths to run faster.
|
# shorter sequence lengths to run faster.
|
||||||
@@ -3456,55 +3529,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
|
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
|
||||||
self.query_start_loc.copy_to_gpu()
|
self.query_start_loc.copy_to_gpu()
|
||||||
|
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
attn_metadata, _ = self._build_attention_metadata(
|
||||||
self.kv_cache_config.kv_cache_groups
|
total_num_scheduled_tokens=num_tokens,
|
||||||
):
|
max_num_scheduled_tokens=max_query_len,
|
||||||
common_attn_metadata = CommonAttentionMetadata(
|
num_reqs=num_reqs,
|
||||||
query_start_loc=self.query_start_loc.gpu[: num_reqs + 1],
|
ubatch_slices=ubatch_slices,
|
||||||
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1],
|
for_cudagraph_capture=True,
|
||||||
seq_lens=self.seq_lens.gpu[:num_reqs],
|
)
|
||||||
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
|
|
||||||
num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
|
|
||||||
:num_reqs
|
|
||||||
],
|
|
||||||
num_reqs=num_reqs,
|
|
||||||
num_actual_tokens=num_tokens,
|
|
||||||
max_query_len=max_query_len,
|
|
||||||
max_seq_len=self.max_model_len,
|
|
||||||
block_table_tensor=self.input_batch.block_table[
|
|
||||||
kv_cache_group_id
|
|
||||||
].get_device_tensor(num_reqs),
|
|
||||||
slot_mapping=self.input_batch.block_table[
|
|
||||||
kv_cache_group_id
|
|
||||||
].slot_mapping.gpu[:num_tokens],
|
|
||||||
causal=True,
|
|
||||||
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
|
|
||||||
if self.dcp_world_size > 1
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
|
||||||
if ubatch_slices is not None:
|
|
||||||
common_attn_metadata_list = split_attn_metadata(
|
|
||||||
ubatch_slices, common_attn_metadata
|
|
||||||
)
|
|
||||||
for ubid, common_attn_metadata in enumerate(
|
|
||||||
common_attn_metadata_list
|
|
||||||
):
|
|
||||||
assert common_attn_metadata.max_query_len == 1
|
|
||||||
attn_metadata_i = attn_group.get_metadata_builder(
|
|
||||||
ubatch_id=ubid
|
|
||||||
).build_for_cudagraph_capture(common_attn_metadata)
|
|
||||||
for layer_name in attn_group.layer_names:
|
|
||||||
assert type(attn_metadata) is list
|
|
||||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
|
||||||
else:
|
|
||||||
assert type(attn_metadata) is dict
|
|
||||||
metadata_builder = attn_group.get_metadata_builder()
|
|
||||||
attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
|
|
||||||
common_attn_metadata
|
|
||||||
)
|
|
||||||
for layer_name in attn_group.layer_names:
|
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
|
||||||
|
|
||||||
with self.maybe_dummy_run_with_lora(
|
with self.maybe_dummy_run_with_lora(
|
||||||
self.lora_config,
|
self.lora_config,
|
||||||
@@ -4478,9 +4509,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
list[int]: List of kernel block sizes for each cache group.
|
list[int]: List of kernel block sizes for each cache group.
|
||||||
"""
|
"""
|
||||||
kernel_block_sizes = []
|
kernel_block_sizes = []
|
||||||
for kv_cache_group_id, kv_cache_group in enumerate(
|
for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||||
kv_cache_config.kv_cache_groups
|
|
||||||
):
|
|
||||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||||
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
|
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||||
# All layers in the UniformTypeKVCacheSpecs have the same type,
|
# All layers in the UniformTypeKVCacheSpecs have the same type,
|
||||||
@@ -4492,7 +4521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# This is an attention backend that supports virtual
|
# This is an attention backend that supports virtual
|
||||||
# block splitting. Get the supported block sizes from
|
# block splitting. Get the supported block sizes from
|
||||||
# all backends in the group.
|
# all backends in the group.
|
||||||
attn_groups = self.attn_groups[kv_cache_group_id]
|
attn_groups = self.attn_groups[kv_cache_gid]
|
||||||
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
|
||||||
selected_kernel_size = self.select_common_block_size(
|
selected_kernel_size = self.select_common_block_size(
|
||||||
kv_manager_block_size, attn_groups
|
kv_manager_block_size, attn_groups
|
||||||
|
|||||||
Reference in New Issue
Block a user