From 4147910f1e893ba69aa86a210c73e02ae8a0dfde Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 17 Jan 2026 20:09:48 -0800
Subject: [PATCH] [Model Runner V2] Move mrope_positions buffer to MRopeState
 (#32532)

Previously, InputBuffers allocated the `mrope_positions` buffer
unconditionally, even for models that never use M-RoPE. Move the buffer
into MRopeState, which the model runner only constructs when
`uses_mrope` is set. `InputBatch.mrope_positions` becomes optional, and
the buffer is threaded explicitly through `CudaGraphManager.capture`.

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/cudagraph_utils.py | 11 +++++----
 vllm/v1/worker/gpu/input_batch.py     | 19 ++-------------
 vllm/v1/worker/gpu/mm/mrope_utils.py  | 20 +++++++++++++---
 vllm/v1/worker/gpu/model_runner.py    | 33 +++++++++++++++++++--------
 4 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index abcdb69e4..a19e6383b 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -75,16 +75,17 @@ class CudaGraphManager:
         num_tokens: int,
         model: nn.Module,
         input_buffers: InputBuffers,
+        mrope_positions: torch.Tensor | None,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
         num_reqs = min(num_tokens, self.max_num_reqs)
         input_ids = input_buffers.input_ids[:num_tokens]
-        if not self.uses_mrope:
-            positions = input_buffers.positions[:num_tokens]
-        else:
-            positions = input_buffers.mrope_positions[:, :num_tokens]
+        positions = input_buffers.positions[:num_tokens]
+        if self.uses_mrope:
+            assert mrope_positions is not None
+            positions = mrope_positions[:, :num_tokens]
         attn_metadata = prepare_inputs_to_capture(
             num_reqs,
             num_tokens,
@@ -136,6 +137,7 @@ class CudaGraphManager:
         self,
         model: nn.Module,
         input_buffers: InputBuffers,
+        mrope_positions: torch.Tensor | None,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
@@ -146,6 +148,7 @@ class CudaGraphManager:
             self.capture_graph,
             model=model,
             input_buffers=input_buffers,
+            mrope_positions=mrope_positions,
             block_tables=block_tables,
             attn_metadata_builders=attn_metadata_builders,
             kv_cache_config=kv_cache_config,
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 8f9552e3f..00564710c 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -31,19 +31,6 @@ class InputBuffers:
         )
         self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
 
-        # NOTE: `mrope_positions` is implemented with one additional dummy
-        # position on purpose to make it non-contiguous so that it can work
-        # with torch compile.
-        # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
-        # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
-        # the modality of inputs. For text-only inputs, each dimension has
-        # identical position IDs, making M-RoPE functionally equivalent to
-        # 1D-RoPE.
-        # See page 5 of https://arxiv.org/abs/2409.12191
-        self.mrope_positions = torch.zeros(
-            (3, max_num_tokens + 1), dtype=torch.int64, device=device
-        )
-
 
 @dataclass
 class InputBatch:
@@ -76,7 +63,7 @@ class InputBatch:
     # [num_tokens_after_padding]
     positions: torch.Tensor
     # [3, num_tokens_after_padding]
-    mrope_positions: torch.Tensor
+    mrope_positions: torch.Tensor | None
 
     # layer_name -> Metadata
     attn_metadata: dict[str, Any]
@@ -124,8 +111,6 @@ class InputBatch:
 
         input_ids = input_buffers.input_ids[:num_tokens].zero_()
         positions = input_buffers.positions[:num_tokens].zero_()
-        input_buffers.mrope_positions.zero_()
-        mrope_positions = input_buffers.mrope_positions[:, :num_tokens]
 
         # attn_metadata = defaultdict(lambda: None)
         logits_indices = query_start_loc[1:] - 1
@@ -146,7 +131,7 @@ class InputBatch:
             seq_lens=seq_lens,
             input_ids=input_ids,
             positions=positions,
-            mrope_positions=mrope_positions,
+            mrope_positions=None,
             attn_metadata=None,  # type: ignore
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
diff --git a/vllm/v1/worker/gpu/mm/mrope_utils.py b/vllm/v1/worker/gpu/mm/mrope_utils.py
index c18b9c82e..4c915a5c9 100644
--- a/vllm/v1/worker/gpu/mm/mrope_utils.py
+++ b/vllm/v1/worker/gpu/mm/mrope_utils.py
@@ -11,10 +11,12 @@ class MRopeState:
     def __init__(
         self,
         max_num_reqs: int,
+        max_num_tokens: int,
         max_model_len: int,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
+        self.max_num_tokens = max_num_tokens
         self.max_model_len = max_model_len
         self.device = device
 
@@ -28,6 +30,19 @@ class MRopeState:
         )
         self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
 
+        # NOTE: `mrope_positions` is implemented with one additional dummy
+        # position on purpose to make it non-contiguous so that it can work
+        # with torch compile.
+        # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
+        # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+        # the modality of inputs. For text-only inputs, each dimension has
+        # identical position IDs, making M-RoPE functionally equivalent to
+        # 1D-RoPE.
+        # See page 5 of https://arxiv.org/abs/2409.12191
+        self.mrope_positions = torch.zeros(
+            (3, max_num_tokens + 1), dtype=torch.int64, device=device
+        )
+
     def init_prefill_mrope_positions(
         self,
         req_idx: int,
@@ -58,12 +73,11 @@ class MRopeState:
         query_start_loc: torch.Tensor,
         prefill_lens: torch.Tensor,
         num_computed_tokens: torch.Tensor,
-        mrope_positions: torch.Tensor,
     ) -> None:
         num_reqs = idx_mapping.shape[0]
         _prepare_mrope_positions_kernel[(num_reqs,)](
-            mrope_positions,
-            mrope_positions.stride(0),
+            self.mrope_positions,
+            self.mrope_positions.stride(0),
             self.prefill_mrope_positions.gpu,
             self.prefill_mrope_positions.gpu.stride(0),
             self.max_model_len,
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 6e3eaad4d..6333075ed 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -99,6 +99,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.uses_mrope:
             self.mrope_states = MRopeState(
                 max_num_reqs=self.max_num_reqs,
+                max_num_tokens=self.max_num_tokens,
                 max_model_len=self.max_model_len,
                 device=self.device,
             )
@@ -284,15 +285,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_buffers=self.input_buffers,
             device=self.device,
         )
+        if self.uses_mrope:
+            input_batch.mrope_positions = self.mrope_states.mrope_positions[
+                :, :num_tokens
+            ]
         if not skip_attn:
             self.prepare_dummy_attn_metadata(input_batch)
 
         dp_size = self.parallel_config.data_parallel_size
         num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
         num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
-        if not self.uses_mrope:
-            positions = input_batch.positions
-        else:
+        positions = input_batch.positions
+        if self.uses_mrope:
             positions = input_batch.mrope_positions
         with (
             self.maybe_dummy_run_with_lora(
@@ -371,9 +375,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
         with self.maybe_setup_dummy_loras(self.lora_config):
+            mrope_positions = None
+            if self.uses_mrope:
+                mrope_positions = self.mrope_states.mrope_positions
             self.cudagraph_manager.capture(
                 model=self.model,
                 input_buffers=self.input_buffers,
+                mrope_positions=mrope_positions,
                 block_tables=self.block_tables,
                 attn_metadata_builders=self.attn_metadata_builders,
                 kv_cache_config=self.kv_cache_config,
@@ -566,7 +574,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             query_start_loc,
             self.req_states.prefill_len.gpu,
             self.req_states.num_computed_tokens.gpu,
-            self.input_buffers.mrope_positions,
         )
 
         # Some input token ids are directly read from the last sampled tokens
@@ -604,9 +611,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
         positions = self.input_buffers.positions[:num_tokens_after_padding]
-        mrope_positions = self.input_buffers.mrope_positions[
-            :, :num_tokens_after_padding
-        ]
+        mrope_positions = None
+        if self.uses_mrope:
+            mrope_positions = self.mrope_states.mrope_positions[
+                :, :num_tokens_after_padding
+            ]
         return InputBatch(
             req_ids=req_ids,
             num_reqs=num_reqs,
@@ -936,6 +945,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_buffers=self.input_buffers,
             device=self.device,
         )
+        if self.uses_mrope:
+            input_batch.mrope_positions = self.mrope_states.mrope_positions[
+                :, :num_tokens_after_padding
+            ]
         self.prepare_dummy_attn_metadata(input_batch)
 
         # Run model.
@@ -949,9 +962,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             # Run PyTorch model in eager mode.
             # TODO(woosuk): Support piecewise CUDA graph.
-            if not self.uses_mrope:
-                positions = input_batch.positions
-            else:
+            positions = input_batch.positions
+            if self.uses_mrope:
+                assert input_batch.mrope_positions is not None
                 positions = input_batch.mrope_positions
             with set_forward_context(
                 input_batch.attn_metadata,
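
For reference, a minimal standalone sketch of the buffer layout this patch
relocates and of the post-patch position-selection pattern. The names
MRopeStateSketch and pick_positions are illustrative stand-ins, not vLLM's
actual API; only the `mrope_positions` allocation itself mirrors the
patched code.

    import torch

    class MRopeStateSketch:
        """Hypothetical stand-in for MRopeState in
        vllm/v1/worker/gpu/mm/mrope_utils.py."""

        def __init__(self, max_num_tokens: int, device: torch.device) -> None:
            # One extra dummy column, as in the patch: any slice [:, :n]
            # with n <= max_num_tokens then has row stride
            # max_num_tokens + 1 != n, so the view is non-contiguous,
            # which the NOTE comment (and the PR #12128 discussion it
            # links) says is needed to work with torch.compile.
            self.mrope_positions = torch.zeros(
                (3, max_num_tokens + 1), dtype=torch.int64, device=device
            )

    def pick_positions(
        uses_mrope: bool,
        positions: torch.Tensor,
        mrope_positions: torch.Tensor | None,
        num_tokens: int,
    ) -> torch.Tensor:
        # Mirrors the selection pattern used after this patch in
        # cudagraph_utils.py and model_runner.py: default to the 1D
        # positions and switch to the 3D M-RoPE buffer only when the
        # model uses M-RoPE.
        if uses_mrope:
            assert mrope_positions is not None
            return mrope_positions[:, :num_tokens]
        return positions[:num_tokens]

    # Quick check of the non-contiguity the dummy column guarantees:
    state = MRopeStateSketch(max_num_tokens=8, device=torch.device("cpu"))
    assert not state.mrope_positions[:, :8].is_contiguous()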