diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py
index 468e77113..d9fc4515b 100644
--- a/vllm/v1/worker/gpu/attn_utils.py
+++ b/vllm/v1/worker/gpu/attn_utils.py
@@ -7,17 +7,14 @@ import torch

 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.v1.attention.backend import (
-    AttentionBackend,
-    AttentionMetadataBuilder,
-    CommonAttentionMetadata,
-)
+from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import (
     AttentionSpec,
     KVCacheConfig,
     KVCacheSpec,
+    UniformTypeKVCacheSpecs,
 )
-from vllm.v1.worker.utils import bind_kv_cache
+from vllm.v1.worker.utils import AttentionGroup, bind_kv_cache


 def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
@@ -35,29 +32,56 @@ def init_attn_backend(
     kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
 ):
     attn_backends: dict[str, type[AttentionBackend]] = {}
-    attn_metadata_builders: list[AttentionMetadataBuilder] = []
-    flashinfer_workspace: torch.Tensor | None = None
-    for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
+    attn_groups: list[list[AttentionGroup]] = []
+    attn_backend_workspace: torch.Tensor | None = None
+    for kv_cache_group_id, kv_cache_group_spec in enumerate(
+        kv_cache_config.kv_cache_groups
+    ):
         layer_names = kv_cache_group_spec.layer_names
-        any_layer_name = next(iter(layer_names))
         layer_type = cast(type[Any], AttentionLayerBase)
         attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
-        attn_backend = attn_layers[any_layer_name].get_attn_backend()
+
+        group_map: dict[tuple[tuple[str, str], KVCacheSpec], AttentionGroup] = {}
+        group_order: list[tuple[tuple[str, str], KVCacheSpec]] = []
+
         for layer_name in layer_names:
+            attn_backend = attn_layers[layer_name].get_attn_backend()
             attn_backends[layer_name] = attn_backend
-        attn_metadata_builder = attn_backend.get_builder_cls()(
-            kv_cache_group_spec.kv_cache_spec, layer_names, vllm_config, device
-        )
-        attn_metadata_builders.append(attn_metadata_builder)  # type: ignore
+            layer_kv_cache_spec: KVCacheSpec = kv_cache_group_spec.kv_cache_spec
+            if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]

-        if attn_backend.get_name() == "FLASHINFER":
-            if flashinfer_workspace is None:
-                flashinfer_workspace = attn_metadata_builder._get_workspace_buffer()
+            key = (attn_backend.full_cls_name(), layer_kv_cache_spec)
+            if key not in group_map:
+                group_map[key] = AttentionGroup(
+                    attn_backend,
+                    [layer_name],
+                    layer_kv_cache_spec,
+                    kv_cache_group_id,
+                )
+                group_order.append(key)
             else:
-                attn_metadata_builder.set_workspace_buffer(flashinfer_workspace)
-    return attn_backends, attn_metadata_builders
+                group_map[key].layer_names.append(layer_name)
+
+        groups = [group_map[key] for key in group_order]
+        for group in groups:
+            group.create_metadata_builders(
+                vllm_config=vllm_config,
+                device=device,
+                kernel_block_size=None,
+                num_metadata_builders=1,
+            )
+            builder = group.get_metadata_builder(0)
+            if attn_backend_workspace is None:
+                if hasattr(builder, "_get_workspace_buffer"):
+                    attn_backend_workspace = builder._get_workspace_buffer()
+            else:
+                if hasattr(builder, "set_workspace_buffer"):
+                    builder.set_workspace_buffer(attn_backend_workspace)
+        attn_groups.append(groups)
+    return attn_backends, attn_groups


 def _allocate_kv_cache(kv_cache_config: KVCacheConfig, device: torch.device):
@@ -144,7 +168,7 @@ def build_slot_mappings_by_layer(


 def build_attn_metadata(
-    attn_metadata_builders: list[AttentionMetadataBuilder],
+    attn_groups: list[list[AttentionGroup]],
     num_reqs: int,
     num_tokens: int,
     query_start_loc_gpu: torch.Tensor,
@@ -162,8 +186,8 @@ def build_attn_metadata(
     dcp_local_seq_lens = dcp_local_seq_lens[:num_reqs]

     attn_metadata: dict[str, Any] = {}
-    kv_cache_groups = kv_cache_config.kv_cache_groups
-    for i, kv_cache_spec in enumerate(kv_cache_groups):
+    num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
+    for i in range(num_kv_cache_groups):
         block_table = block_tables[i]
         slot_mapping = slot_mappings[i]

@@ -181,10 +205,11 @@
             dcp_local_seq_lens=dcp_local_seq_lens,
         )

-        attn_metadata_builder = attn_metadata_builders[i]
-        metadata = attn_metadata_builder.build(
-            common_prefix_len=0, common_attn_metadata=common_attn_metadata
-        )
-        for layer_name in kv_cache_spec.layer_names:
-            attn_metadata[layer_name] = metadata
+        for attn_group in attn_groups[i]:
+            attn_metadata_builder = attn_group.get_metadata_builder(0)
+            metadata = attn_metadata_builder.build(
+                common_prefix_len=0, common_attn_metadata=common_attn_metadata
+            )
+            for layer_name in attn_group.layer_names:
+                attn_metadata[layer_name] = metadata
     return attn_metadata
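A note on `init_attn_backend` above: layers are now deduplicated into `AttentionGroup`s keyed by `(backend class, per-layer KV cache spec)`, with first-seen ordering preserved, instead of assuming a single backend per KV cache group. A minimal standalone sketch of that grouping pattern, using simplified stand-ins (`Spec`, `Group`, `group_layers`) rather than vLLM's real classes:

```python
from dataclasses import dataclass, field

# Simplified stand-ins for vLLM's AttentionBackend / KVCacheSpec / AttentionGroup.
@dataclass(frozen=True)
class Spec:
    block_size: int

@dataclass
class Group:
    backend: str
    spec: Spec
    layer_names: list[str] = field(default_factory=list)

def group_layers(layers: dict[str, tuple[str, Spec]]) -> list[Group]:
    """Bucket layers by (backend, spec), preserving first-seen order."""
    group_map: dict[tuple[str, Spec], Group] = {}
    order: list[tuple[str, Spec]] = []
    for name, (backend, spec) in layers.items():
        key = (backend, spec)
        if key not in group_map:
            group_map[key] = Group(backend, spec, [name])
            order.append(key)
        else:
            group_map[key].layer_names.append(name)
    return [group_map[k] for k in order]

layers = {
    "layers.0.attn": ("FLASH_ATTN", Spec(16)),
    "layers.1.attn": ("FLASH_ATTN", Spec(16)),
    "layers.2.attn": ("FLASHINFER", Spec(16)),
}
assert [g.layer_names for g in group_layers(layers)] == [
    ["layers.0.attn", "layers.1.attn"],
    ["layers.2.attn"],
]
```

The explicit `group_order` list mirrors the patch; a plain dict is already insertion-ordered in Python 3.7+, so the list mainly makes the ordering contract explicit.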
diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index e3839894a..7bba7ffb9 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -13,7 +13,6 @@ from vllm.config.compilation import CUDAGraphMode
 from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -22,6 +21,7 @@ from vllm.v1.worker.gpu.attn_utils import (
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
 from vllm.v1.worker.gpu.input_batch import InputBuffers
+from vllm.v1.worker.utils import AttentionGroup


 class CudaGraphManager:
@@ -83,7 +83,7 @@ class CudaGraphManager:
         mrope_positions: torch.Tensor | None,
         inputs_embeds: torch.Tensor | None,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
         has_lora: bool = False,
         uniform_decode: bool = False,
@@ -116,7 +116,7 @@ class CudaGraphManager:
             num_tokens,
             input_buffers,
             block_tables,
-            attn_metadata_builders,
+            attn_groups,
             self.max_model_len,
             kv_cache_config,
             uniform_decode_query_len=(
@@ -232,7 +232,7 @@ class CudaGraphManager:
         mrope_positions: torch.Tensor | None,
         inputs_embeds: torch.Tensor | None,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
         has_lora: bool = False,
     ) -> None:
@@ -244,7 +244,7 @@ class CudaGraphManager:
             mrope_positions=mrope_positions,
             inputs_embeds=inputs_embeds,
             block_tables=block_tables,
-            attn_metadata_builders=attn_metadata_builders,
+            attn_groups=attn_groups,
             kv_cache_config=kv_cache_config,
             has_lora=has_lora,
         )
@@ -286,6 +286,16 @@ class CudaGraphManager:
             cudagraph_mode = self.cudagraph_mode.decode_mode()
         else:
             cudagraph_mode = self.cudagraph_mode.mixed_mode()
+
+        if (
+            cudagraph_mode == CUDAGraphMode.FULL
+            and cudagraph_size is not None
+            and cudagraph_size not in self.graphs
+        ):
+            # If the graph wasn't captured yet, fall back to eager.
+            # This might happen when the dummy run is called before capture.
+            cudagraph_mode = CUDAGraphMode.NONE
+            cudagraph_size = None
         return cudagraph_mode, cudagraph_size

     def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
@@ -354,7 +364,7 @@ def prepare_inputs_to_capture(
     num_tokens: int,
     input_buffers: InputBuffers,
     block_tables: BlockTables,
-    attn_metadata_builders: list[AttentionMetadataBuilder],
+    attn_groups: list[list[AttentionGroup]],
     max_model_len: int,
     kv_cache_config: KVCacheConfig,
     uniform_decode_query_len: int = 0,
@@ -386,7 +396,7 @@ def prepare_inputs_to_capture(
     )

     attn_metadata = build_attn_metadata(
-        attn_metadata_builders=attn_metadata_builders,
+        attn_groups=attn_groups,
         num_reqs=num_reqs,
         num_tokens=num_tokens,
         query_start_loc_gpu=query_start_loc,
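The `@@ -286` hunk above adds a capture-before-replay guard: a FULL cudagraph may only be replayed for sizes that were actually captured, otherwise execution falls back to eager. A small sketch of the pattern, with hypothetical names (`Mode`, `graphs`, `resolve`) standing in for the real `CudaGraphManager` state:

```python
from enum import Enum

class Mode(Enum):
    NONE = 0
    FULL = 1

# Hypothetical registry of captured graphs, keyed by padded batch size.
graphs: dict[int, object] = {8: object(), 16: object()}

def resolve(mode: Mode, size: int | None) -> tuple[Mode, int | None]:
    # A FULL replay is only valid if a graph for this size was captured;
    # otherwise fall back to eager execution, as in the guard above.
    if mode is Mode.FULL and size is not None and size not in graphs:
        return Mode.NONE, None
    return mode, size

assert resolve(Mode.FULL, 16) == (Mode.FULL, 16)
assert resolve(Mode.FULL, 32) == (Mode.NONE, None)  # not captured -> eager
assert resolve(Mode.NONE, None) == (Mode.NONE, None)
```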
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 37f87d7b6..b909b90ad 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -283,7 +283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             cp_interleave=self.cp_interleave,
         )

-        self.attn_backends, self.attn_metadata_builders = init_attn_backend(
+        self.attn_backends, self.attn_groups = init_attn_backend(
             self.kv_cache_config, self.vllm_config, self.device
         )
         check_attention_cp_compatibility(self.vllm_config)
@@ -291,7 +291,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # HACK(woosuk)
         self.speculator.set_attn(
             self.kv_cache_config,
-            self.attn_metadata_builders,
+            self.attn_groups,
             self.block_tables,
         )

@@ -305,9 +305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

-        # Attention groups are not supported.
-        self.attn_groups = []  # type: ignore
-
     def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
         block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
         slot_mappings = self.block_tables.get_dummy_slot_mappings(
@@ -317,7 +314,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             slot_mappings, self.kv_cache_config
         )
         attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             num_reqs=input_batch.num_reqs,
             num_tokens=input_batch.num_tokens,
             query_start_loc_gpu=input_batch.query_start_loc,
@@ -477,7 +474,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             mrope_positions=mrope_positions,
             inputs_embeds=inputs_embeds,
             block_tables=self.block_tables,
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             kv_cache_config=self.kv_cache_config,
             has_lora=self.lora_config is not None,
         )
@@ -712,7 +709,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Layer name -> attention metadata.
         attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             num_reqs=num_reqs,
             num_tokens=num_tokens,
             query_start_loc_gpu=query_start_loc,
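As wired through the runner above, `build_attn_metadata` now performs one `build()` per attention group and shares the resulting metadata object across all of that group's layers. A toy sketch of the fan-out, with simplified stand-ins (`SimpleGroup`, `fan_out` are illustrative, not vLLM APIs):

```python
from dataclasses import dataclass
from typing import Any, Callable

# Simplified stand-in for AttentionGroup: layer names plus a build callable.
@dataclass
class SimpleGroup:
    layer_names: list[str]
    build: Callable[[Any], Any]

def fan_out(groups: list[SimpleGroup], common_metadata: Any) -> dict[str, Any]:
    """One build() call per group; all layers in the group share the result."""
    attn_metadata: dict[str, Any] = {}
    for group in groups:
        metadata = group.build(common_metadata)
        for layer_name in group.layer_names:
            attn_metadata[layer_name] = metadata
    return attn_metadata

groups = [
    SimpleGroup(["layers.0.attn", "layers.1.attn"], lambda m: ("flash", m)),
    SimpleGroup(["layers.2.attn"], lambda m: ("flashinfer", m)),
]
out = fan_out(groups, {"num_reqs": 2})
assert out["layers.0.attn"] is out["layers.1.attn"]  # shared per group
```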
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
index ae7aa4078..c489a172c 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py
@@ -7,7 +7,6 @@ import torch

 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
-from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.cudagraph_utils import (
@@ -17,6 +16,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
 )
 from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
 from vllm.v1.worker.gpu.input_batch import InputBuffers
+from vllm.v1.worker.utils import AttentionGroup


 class EagleCudaGraphManager:
@@ -60,7 +60,7 @@ class EagleCudaGraphManager:
         generate_fn: Callable,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
     ) -> None:
         assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
@@ -77,7 +77,7 @@ class EagleCudaGraphManager:
             num_tokens,
             input_buffers,
             block_tables,
-            attn_metadata_builders,
+            attn_groups,
             self.max_model_len,
             kv_cache_config,
             uniform_decode_query_len=1,
@@ -150,7 +150,7 @@ class EagleCudaGraphManager:
         generate_fn: Callable,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         kv_cache_config: KVCacheConfig,
     ) -> None:
         if self.cudagraph_mode == CUDAGraphMode.NONE:
@@ -165,7 +165,7 @@ class EagleCudaGraphManager:
             generate_fn=generate_fn,
             input_buffers=input_buffers,
             block_tables=block_tables,
-            attn_metadata_builders=attn_metadata_builders,
+            attn_groups=attn_groups,
             kv_cache_config=kv_cache_config,
         )
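Both Eagle files consume the same nested structure: `attn_groups[kv_cache_group_id]` lists the backend/spec groups inside that KV cache group, and each group built by `init_attn_backend` owns a single metadata builder (`num_metadata_builders=1`, fetched with `get_metadata_builder(0)`). A stub illustration of the indexing convention; `StubBuilder`/`StubGroup` are invented here, not vLLM classes:

```python
# Illustrative stubs only; vLLM's real AttentionGroup has more machinery.
class StubBuilder:
    def build(self, common_attn_metadata):
        return {"common": common_attn_metadata}

class StubGroup:
    def __init__(self, layer_names: list[str]):
        self.layer_names = layer_names
        self._builders = [StubBuilder()]  # num_metadata_builders=1

    def get_metadata_builder(self, idx: int) -> StubBuilder:
        return self._builders[idx]

# attn_groups[kv_cache_group_id] -> groups within that KV cache group.
attn_groups: list[list[StubGroup]] = [
    [StubGroup(["layers.0.attn"]), StubGroup(["layers.1.attn"])],  # group 0
    [StubGroup(["layers.2.attn"])],                                # group 1
]
assert attn_groups[0][1].get_metadata_builder(0).build("x") == {"common": "x"}
```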
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
index 3cd8afee7..6cd13cebf 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -10,7 +10,6 @@ from vllm.config.compilation import CUDAGraphMode
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backend import AttentionMetadataBuilder
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -21,6 +20,7 @@ from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
 from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
+from vllm.v1.worker.utils import AttentionGroup

 logger = init_logger(__name__)

@@ -78,11 +78,11 @@ class EagleSpeculator:
     def set_attn(
         self,
         kv_cache_config: KVCacheConfig,
-        attn_metadata_builders: list[AttentionMetadataBuilder],
+        attn_groups: list[list[AttentionGroup]],
         block_tables: BlockTables,
     ) -> None:
         self.kv_cache_config = kv_cache_config
-        self.attn_metadata_builders = attn_metadata_builders
+        self.attn_groups = attn_groups
         self.block_tables = block_tables

     @torch.inference_mode()
@@ -174,7 +174,7 @@ class EagleSpeculator:
             self.generate_draft,
             self.input_buffers,
             self.block_tables,
-            self.attn_metadata_builders,
+            self.attn_groups,
             self.kv_cache_config,
         )

@@ -298,7 +298,7 @@ class EagleSpeculator:
         # FIXME(woosuk): This is UNSAFE!!
         attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
+            attn_groups=self.attn_groups,
             num_reqs=num_reqs,
             num_tokens=num_reqs,
             query_start_loc_gpu=query_start_loc,