[Core] Separate out attention metadata building logic from prepare inputs (#26764)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-11-09 13:51:43 -05:00
committed by GitHub
parent 289eb6c537
commit 636efd10a5

View File

@@ -1054,7 +1054,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_encoder_seq_lens( def _get_encoder_seq_lens(
self, self,
scheduler_output: "SchedulerOutput", scheduled_encoder_inputs: dict[str, list[int]],
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
num_reqs: int, num_reqs: int,
) -> np.ndarray | None: ) -> np.ndarray | None:
@@ -1064,31 +1064,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Build encoder_seq_lens array mapping request indices to # Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch # encoder lengths for inputs scheduled in this batch
encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32) encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
for req_id in scheduler_output.scheduled_encoder_inputs: for req_id in scheduled_encoder_inputs:
req_index = self.input_batch.req_id_to_index[req_id] req_index = self.input_batch.req_id_to_index[req_id]
encoder_seq_lens[req_index] = self.max_encoder_len encoder_seq_lens[req_index] = self.max_encoder_len
return encoder_seq_lens return encoder_seq_lens
def _prepare_inputs( def _prepare_inputs(
self, scheduler_output: "SchedulerOutput" self,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: np.ndarray,
max_num_scheduled_tokens: int,
) -> tuple[ ) -> tuple[
PerLayerAttnMetadata,
torch.Tensor, torch.Tensor,
SpecDecodeMetadata | None, SpecDecodeMetadata | None,
np.ndarray,
CommonAttentionMetadata | None,
int,
UBatchSlices | None, UBatchSlices | None,
torch.Tensor | None, torch.Tensor | None,
bool,
]: ]:
""" """
:return: tuple[ :return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
logits_indices, spec_decode_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata, ubatch_slices, num_tokens_across_dp,
max_num_scheduled_tokens, use_cascade_attn
] ]
""" """
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1100,12 +1096,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This way, we can overlap the copy with the following CPU operations. # This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit_block_table(num_reqs) self.input_batch.block_table.commit_block_table(num_reqs)
# Get the number of scheduled tokens for each request.
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices. # Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
@@ -1232,8 +1222,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Fill unused with 0 for full cuda graph mode. # Fill unused with 0 for full cuda graph mode.
self.seq_lens.np[num_reqs:].fill(0) self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu() self.seq_lens.copy_to_gpu()
seq_lens = self.seq_lens.gpu[:num_reqs]
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
num_tokens_np = np.array(num_tokens, dtype=np.int32) num_tokens_np = np.array(num_tokens, dtype=np.int32)
@@ -1305,11 +1293,46 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
self.num_decode_draft_tokens.copy_to_gpu() self.num_decode_draft_tokens.copy_to_gpu()
logits_indices_padded = None # Hot-Swap lora model
if self.cache_config.kv_sharing_fast_prefill: if self.lora_config:
logits_indices_padded = self._prepare_kv_sharing_fast_prefill( assert (
logits_indices np.sum(num_sampled_tokens)
<= self.vllm_config.scheduler_config.max_num_batched_tokens
) )
self.set_active_loras(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
return (
logits_indices,
spec_decode_metadata,
ubatch_slices,
num_tokens_across_dp,
)
def _build_attention_metadata(
self,
total_num_scheduled_tokens: int,
max_num_scheduled_tokens: int,
num_reqs: int,
ubatch_slices: UBatchSlices | None = None,
logits_indices: torch.Tensor | None = None,
use_spec_decode: bool = False,
for_cudagraph_capture: bool = False,
scheduled_encoder_inputs: dict[str, list[int]] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
logits_indices_padded = None
num_logits_indices = 0
if logits_indices is not None:
num_logits_indices = logits_indices.size(0)
if self.cache_config.kv_sharing_fast_prefill:
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices
)
# update seq_lens of decode reqs under DCP. # update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
@@ -1324,15 +1347,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata: PerLayerAttnMetadata = {} attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None: if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))] attn_metadata = [dict() for _ in range(len(ubatch_slices))]
use_cascade_attn = False
# Used in the below loop. # Used in the below loop
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
seq_lens = self.seq_lens.gpu[:num_reqs]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs] seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs :num_reqs
] ]
dcp_local_seq_lens = (
self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None
)
spec_decode_common_attn_metadata = None spec_decode_common_attn_metadata = None
if for_cudagraph_capture:
# For some attention backends (e.g. FA) with sliding window models we need
# to make sure the backend see a max_seq_len that is larger to the sliding
# window size when capturing to make sure the correct kernel is selected.
max_seq_len = self.max_model_len
else:
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
if use_spec_decode: if use_spec_decode:
self.num_accepted_tokens.np[:num_reqs] = ( self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs] self.input_batch.num_accepted_tokens_cpu[:num_reqs]
@@ -1342,14 +1378,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare the attention metadata for each KV cache group and make layers # Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata. # in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate( for kv_cache_gid, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups self.kv_cache_config.kv_cache_groups
): ):
encoder_seq_lens = self._get_encoder_seq_lens( encoder_seq_lens = self._get_encoder_seq_lens(
scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs scheduled_encoder_inputs or {},
kv_cache_group.kv_cache_spec,
num_reqs,
) )
if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to # Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them. # create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros( blk_table_tensor = torch.zeros(
@@ -1362,18 +1400,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int64, dtype=torch.int64,
device=self.device, device=self.device,
) )
num_common_prefix_blocks = 0
else: else:
blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table = self.input_batch.block_table[kv_cache_gid]
blk_table_tensor = blk_table.get_device_tensor(num_reqs) blk_table_tensor = blk_table.get_device_tensor(num_reqs)
slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
# Fill unused with -1. Needed for reshape_and_cache in full cuda # Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode. # graph mode.
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[
kv_cache_group_id
]
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
@@ -1388,35 +1422,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
block_table_tensor=blk_table_tensor, block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
logits_indices_padded=logits_indices_padded, logits_indices_padded=logits_indices_padded,
num_logits_indices=logits_indices.size(0), num_logits_indices=num_logits_indices,
causal=True, causal=True,
encoder_seq_lens=encoder_seq_lens, encoder_seq_lens=encoder_seq_lens,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] dcp_local_seq_lens=dcp_local_seq_lens,
if self.dcp_world_size > 1
else None,
) )
if self.speculative_config and spec_decode_common_attn_metadata is None: if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
if ( if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
self.drafter.attn_layer_names[0]
in kv_cache_group_spec.layer_names
):
spec_decode_common_attn_metadata = common_attn_metadata spec_decode_common_attn_metadata = common_attn_metadata
else: else:
spec_decode_common_attn_metadata = common_attn_metadata spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]: for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
# Prepare for cascade attention if enabled & beneficial. cascade_attn_prefix_len = (
common_prefix_len = 0 cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens
else 0
)
builder = attn_group.get_metadata_builder() builder = attn_group.get_metadata_builder()
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_common_prefix_blocks,
attn_group.kv_cache_spec,
builder,
)
extra_attn_metadata_args = {} extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
@@ -1434,51 +1459,69 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for ubid, common_attn_metadata in enumerate( for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list common_attn_metadata_list
): ):
attn_metadata_i = attn_group.get_metadata_builder( builder = attn_group.get_metadata_builder(ubatch_id=ubid)
ubatch_id=ubid if for_cudagraph_capture:
).build( attn_metadata_i = builder.build_for_cudagraph_capture(
common_prefix_len=common_prefix_len, common_attn_metadata
common_attn_metadata=common_attn_metadata, )
) else:
for layer_name in kv_cache_group_spec.layer_names: attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
)
for layer_name in kv_cache_group.layer_names:
assert type(attn_metadata) is list assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i attn_metadata[ubid][layer_name] = attn_metadata_i
else: else:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
attn_metadata_i = builder.build( if for_cudagraph_capture:
common_prefix_len=common_prefix_len, attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata=common_attn_metadata, common_attn_metadata
**extra_attn_metadata_args, )
) else:
use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False) attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
for layer_name in attn_group.layer_names: for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
# disable cascade attention when DBO return attn_metadata, spec_decode_common_attn_metadata
if ubatch_slices is not None:
use_cascade_attn = False
# Hot-Swap lora model def _compute_cascade_attn_prefix_lens(
if self.lora_config: self,
assert ( num_scheduled_tokens: np.ndarray,
np.sum(num_sampled_tokens) num_common_prefix_blocks: list[int],
<= self.vllm_config.scheduler_config.max_num_batched_tokens ) -> list[list[int]] | None:
) """
self.set_active_loras( :return: Optional[cascade_attn_prefix_lens]
self.input_batch, num_scheduled_tokens, num_sampled_tokens cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
) None if we should not use cascade attention
"""
return ( use_cascade_attn = False
attn_metadata, num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
logits_indices, cascade_attn_prefix_lens: list[list[int]] = [
spec_decode_metadata, [] for _ in range(num_kv_cache_groups)
num_scheduled_tokens, ]
spec_decode_common_attn_metadata,
max_num_scheduled_tokens, for kv_cache_gid in range(num_kv_cache_groups):
ubatch_slices, for attn_group in self.attn_groups[kv_cache_gid]:
num_tokens_across_dp, if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
use_cascade_attn, cascade_attn_prefix_len = 0
) else:
# 0 if cascade attention should not be used
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_common_prefix_blocks[kv_cache_gid],
attn_group.kv_cache_spec,
attn_group.get_metadata_builder(),
)
cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
use_cascade_attn |= cascade_attn_prefix_len > 0
return cascade_attn_prefix_lens if use_cascade_attn else None
def _compute_cascade_attn_prefix_len( def _compute_cascade_attn_prefix_len(
self, self,
@@ -1504,6 +1547,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Returns: Returns:
int: Length of common prefix in tokens. int: Length of common prefix in tokens.
""" """
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
if common_prefix_len == 0: if common_prefix_len == 0:
# Common case. # Common case.
@@ -2497,18 +2541,48 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"it when the requests need prompt logprobs" "it when the requests need prompt logprobs"
) )
# Prepare the decoder inputs. num_reqs = self.input_batch.num_reqs
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
( (
attn_metadata,
logits_indices, logits_indices,
spec_decode_metadata, spec_decode_metadata,
num_scheduled_tokens_np,
spec_decode_common_attn_metadata,
max_query_len,
ubatch_slices, ubatch_slices,
num_tokens_across_dp, num_tokens_across_dp,
use_cascade_attn, ) = self._prepare_inputs(
) = self._prepare_inputs(scheduler_output) scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
)
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
if self.cascade_attn_enabled and ubatch_slices is None:
# Pre-compute cascade attention prefix lengths
# NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
num_scheduled_tokens_np,
scheduler_output.num_common_prefix_blocks,
)
# TODO(lucas): move cudagraph dispatching here:
# https://github.com/vllm-project/vllm/issues/23789
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
attn_metadata, spec_decode_common_attn_metadata = (
self._build_attention_metadata(
total_num_scheduled_tokens=total_num_scheduled_tokens,
max_num_scheduled_tokens=max_num_scheduled_tokens,
num_reqs=num_reqs,
ubatch_slices=ubatch_slices,
logits_indices=logits_indices,
use_spec_decode=use_spec_decode,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
)
)
dp_rank = self.parallel_config.data_parallel_rank dp_rank = self.parallel_config.data_parallel_rank
if ubatch_slices: if ubatch_slices:
@@ -2532,16 +2606,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output, num_input_tokens, intermediate_tensors scheduler_output, num_input_tokens, intermediate_tensors
) )
uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( uniform_decode = (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len max_num_scheduled_tokens == self.uniform_decode_query_len
) ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
batch_descriptor = BatchDescriptor( batch_descriptor = BatchDescriptor(
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0, has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
) )
cudagraph_runtime_mode, batch_descriptor = ( cudagraph_runtime_mode, batch_descriptor = (
self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) self.cudagraph_dispatcher.dispatch(
batch_descriptor,
use_cascade_attn=cascade_attn_prefix_lens is not None,
)
) )
# Set cudagraph mode to none if calc_kv_scales is true. # Set cudagraph mode to none if calc_kv_scales is true.
@@ -3437,10 +3514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# If force_attention is True, we always capture attention. Otherwise, # If force_attention is True, we always capture attention. Otherwise,
# it only happens for cudagraph_runtime_mode=FULL. # it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
if create_mixed_batch: if create_mixed_batch:
# In the mixed batch mode (used for FI warmup), we use # In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster. # shorter sequence lengths to run faster.
@@ -3456,55 +3529,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
for kv_cache_group_id, kv_cache_group_spec in enumerate( attn_metadata, _ = self._build_attention_metadata(
self.kv_cache_config.kv_cache_groups total_num_scheduled_tokens=num_tokens,
): max_num_scheduled_tokens=max_query_len,
common_attn_metadata = CommonAttentionMetadata( num_reqs=num_reqs,
query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], ubatch_slices=ubatch_slices,
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], for_cudagraph_capture=True,
seq_lens=self.seq_lens.gpu[:num_reqs], )
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs
],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=self.max_model_len,
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id
].get_device_tensor(num_reqs),
slot_mapping=self.input_batch.block_table[
kv_cache_group_id
].slot_mapping.gpu[:num_tokens],
causal=True,
dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
if self.dcp_world_size > 1
else None,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
if ubatch_slices is not None:
common_attn_metadata_list = split_attn_metadata(
ubatch_slices, common_attn_metadata
)
for ubid, common_attn_metadata in enumerate(
common_attn_metadata_list
):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = attn_group.get_metadata_builder(
ubatch_id=ubid
).build_for_cudagraph_capture(common_attn_metadata)
for layer_name in attn_group.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
assert type(attn_metadata) is dict
metadata_builder = attn_group.get_metadata_builder()
attn_metadata_i = metadata_builder.build_for_cudagraph_capture(
common_attn_metadata
)
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora( with self.maybe_dummy_run_with_lora(
self.lora_config, self.lora_config,
@@ -4478,9 +4509,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
list[int]: List of kernel block sizes for each cache group. list[int]: List of kernel block sizes for each cache group.
""" """
kernel_block_sizes = [] kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate( for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_config.kv_cache_groups
):
kv_cache_spec = kv_cache_group.kv_cache_spec kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type, # All layers in the UniformTypeKVCacheSpecs have the same type,
@@ -4492,7 +4521,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# This is an attention backend that supports virtual # This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from # block splitting. Get the supported block sizes from
# all backends in the group. # all backends in the group.
attn_groups = self.attn_groups[kv_cache_group_id] attn_groups = self.attn_groups[kv_cache_gid]
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
selected_kernel_size = self.select_common_block_size( selected_kernel_size = self.select_common_block_size(
kv_manager_block_size, attn_groups kv_manager_block_size, attn_groups