[v1] AttentionMetadata for each layer (#17394)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -30,6 +30,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                         check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -157,9 +158,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Sampler
         self.sampler = Sampler()
 
-        # Lazy initialization
+        # Lazy initializations
         # self.model: nn.Module  # Set after load_model
+        # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
+        # self.kv_cache_config: KVCacheConfig
 
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -488,7 +492,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
+    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
               Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -585,20 +589,39 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
 
-        # Prepare for cascade attention if enabled & beneficial.
-        common_prefix_len = 0
-        if self.cascade_attn_enabled:
-            common_prefix_len = self._compute_cascade_attn_prefix_len(
-                num_scheduled_tokens,
-                scheduler_output.num_common_prefix_blocks,
-            )
+        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
+            self.device, non_blocking=True)
+        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
+                                                   non_blocking=True)
+        common_attn_metadata = CommonAttentionMetadata(
+            query_start_loc=query_start_loc, seq_lens=seq_lens)
 
-        attn_metadata = self.attn_metadata_builder.build(
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            common_prefix_len=common_prefix_len,
-        )
+        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        # Prepare the attention metadata for each KV cache group and make layers
+        # in the same group share the same metadata.
+        # NOTE(Chen): there is exactly one KV cache group that contains all
+        # attention layers in the model for now, so the current logic for
+        # getting attn_metadata is not related to kv_cache_group information.
+        # Will extend this part to support multiple KV cache groups later.
+        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+
+            # Prepare for cascade attention if enabled & beneficial.
+            common_prefix_len = 0
+            if self.cascade_attn_enabled:
+                common_prefix_len = self._compute_cascade_attn_prefix_len(
+                    num_scheduled_tokens,
+                    scheduler_output.num_common_prefix_blocks,
+                )
+
+            attn_metadata_i = self.attn_metadata_builder.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata)
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
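The net effect of this hunk: _prepare_inputs now fills a dict keyed by layer name, and every layer in a KV cache group shares one metadata object. A minimal, standalone sketch of that shape (the group name, layer names, and LayerMetadata class below are illustrative stand-ins, not the real vLLM types):

# Minimal sketch of the per-layer metadata mapping built above.
# `LayerMetadata` and the group/layer names are illustrative placeholders,
# not the real vLLM FlashAttentionMetadata / KV cache group objects.
from dataclasses import dataclass

@dataclass
class LayerMetadata:
    num_actual_tokens: int
    max_query_len: int
    common_prefix_len: int

kv_cache_groups = {
    # one group -> all attention layers in it share the same metadata object
    "full_attention": ["model.layers.0.self_attn", "model.layers.1.self_attn"],
}

attn_metadata: dict[str, LayerMetadata] = {}
for group_name, layer_names in kv_cache_groups.items():
    metadata_i = LayerMetadata(num_actual_tokens=128,
                               max_query_len=16,
                               common_prefix_len=0)
    for layer_name in layer_names:
        attn_metadata[layer_name] = metadata_i

# Both layers see the identical (shared) object:
assert attn_metadata["model.layers.0.self_attn"] is attn_metadata[
    "model.layers.1.self_attn"]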
@@ -608,7 +631,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
-            logits_indices = attn_metadata.query_start_loc[1:] - 1
+            logits_indices = query_start_loc[1:] - 1
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
@@ -1230,6 +1253,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         next_token_ids = torch.tensor(next_token_ids,
                                       dtype=torch.int32,
                                       device=self.device)
+        eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
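With attn_metadata now a dict, the EAGLE speculative-decoding path can no longer read fields off the metadata directly; it first selects the entry for the drafter's own attention layer and then uses that entry exactly as before. A hedged sketch of the lookup (EagleMeta and the layer name are illustrative stand-ins for the real FlashAttentionMetadata entry keyed by self.drafter.attn_layer_name):

# Sketch of the per-layer lookup the drafter path performs above.
from dataclasses import dataclass

import torch

@dataclass
class EagleMeta:
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor

attn_metadata = {
    "model.layers.0.self_attn":
        EagleMeta(slot_mapping=torch.arange(8),
                  query_start_loc=torch.tensor([0, 4, 8])),
}
# Plays the role of self.drafter.attn_layer_name in the diff above.
drafter_layer_name = "model.layers.0.self_attn"

# Select the drafter layer's entry once, then read fields off it as before.
eagle_attn_metadata = attn_metadata[drafter_layer_name]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc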
@@ -1241,8 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     dim=-1)
             else:
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
-            target_slot_mapping = attn_metadata.slot_mapping
-            cu_num_tokens = attn_metadata.query_start_loc
+            target_slot_mapping = eagle_attn_metadata.slot_mapping
+            cu_num_tokens = eagle_attn_metadata.query_start_loc
         else:
             # TODO(woosuk): Refactor this.
             num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -1256,7 +1280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 device=self.device,
             )
             cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                attn_metadata.query_start_loc,
+                eagle_attn_metadata.query_start_loc,
                 num_rejected_tokens,
             )
             target_token_ids = self.input_ids[token_indices]
@@ -1266,7 +1290,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_hidden_states = hidden_states[token_indices]
-            target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+            target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                token_indices]
 
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
@@ -1275,7 +1300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             target_slot_mapping=target_slot_mapping,
             next_token_ids=next_token_ids,
             cu_num_tokens=cu_num_tokens,
-            block_table=attn_metadata.block_table,
+            block_table=eagle_attn_metadata.block_table,
             sampling_metadata=sampling_metadata,
         )
         spec_token_ids = draft_token_ids.tolist()
@@ -1708,6 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             raise NotImplementedError(
                 "Hybrid models with more than one KV cache type are not "
                 "supported yet.")
+        self.kv_cache_config = kv_cache_config
 
         kv_caches: dict[str, torch.Tensor] = {}
@@ -588,7 +588,14 @@ class TPUModelRunner:
         # Padded to avoid recompiling when `num_reqs` varies.
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
-        return attn_metadata, logits_indices, padded_num_reqs
+
+        layer_names = get_layers_from_vllm_config(self.vllm_config,
+                                                  Attention).keys()
+        per_layer_attn_metadata = {
+            layer_name: attn_metadata
+            for layer_name in layer_names
+        }
+        return per_layer_attn_metadata, logits_indices, padded_num_reqs
 
     def _scatter_placeholders(
         self,
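On TPU the runner still builds a single metadata object; the per-layer dict simply fans that one object out to every attention layer name. The same pattern reduced to plain Python (the layer names and metadata value below are placeholders; in the diff the names come from get_layers_from_vllm_config(self.vllm_config, Attention).keys()):

# Fan-out pattern used above: one metadata object, keyed by every layer name.
layer_names = ["model.layers.0.self_attn", "model.layers.1.self_attn"]
attn_metadata = object()  # stands in for the TPU attention metadata

per_layer_attn_metadata = {
    layer_name: attn_metadata
    for layer_name in layer_names
}
# Every entry refers to the same shared object.
assert all(v is attn_metadata for v in per_layer_attn_metadata.values())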
@@ -956,7 +963,14 @@ class TPUModelRunner:
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
 
-        with set_forward_context(attn_metadata, self.vllm_config, 0):
+        layer_names = get_layers_from_vllm_config(self.vllm_config,
+                                                  Attention).keys()
+        per_layer_attn_metadata = {
+            layer_name: attn_metadata
+            for layer_name in layer_names
+        }
+
+        with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
                              inputs_embeds=inputs_embeds)
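Downstream, set_forward_context now receives the dict rather than a single metadata object, so whatever consumes the forward context can pull out just the entry for its own layer. A hedged sketch of that consumption side (ForwardContext and ToyAttention below are simplified stand-ins, not vLLM's actual forward-context API):

# Simplified stand-in for reading per-layer metadata out of a forward context.
from dataclasses import dataclass
from typing import Any

@dataclass
class ForwardContext:
    attn_metadata: dict[str, Any]

class ToyAttention:
    def __init__(self, layer_name: str):
        self.layer_name = layer_name

    def forward(self, ctx: ForwardContext) -> Any:
        # Each layer looks up only its own metadata entry.
        return ctx.attn_metadata[self.layer_name]

ctx = ForwardContext(attn_metadata={"model.layers.0.self_attn": "meta0"})
layer = ToyAttention("model.layers.0.self_attn")
assert layer.forward(ctx) == "meta0"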