From 0a7dd23754ed5e01303e0eb4e64ace5e70251f46 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 12 Jan 2026 13:37:43 -0800 Subject: [PATCH] [Model Runner V2] Add support for M-RoPE (#32143) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/cudagraph_utils.py | 7 +- vllm/v1/worker/gpu/input_batch.py | 23 ++++- vllm/v1/worker/gpu/mm/__init__.py | 0 vllm/v1/worker/gpu/mm/mrope_utils.py | 127 ++++++++++++++++++++++++++ vllm/v1/worker/gpu/model_runner.py | 53 ++++++++++- 5 files changed, 203 insertions(+), 7 deletions(-) create mode 100644 vllm/v1/worker/gpu/mm/__init__.py create mode 100644 vllm/v1/worker/gpu/mm/mrope_utils.py diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index a7c20ec8b..d5095af18 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -25,10 +25,12 @@ class CudaGraphManager: def __init__( self, vllm_config: VllmConfig, + uses_mrope: bool, device: torch.device, ): self.vllm_config = vllm_config self.scheduler_config = vllm_config.scheduler_config + self.uses_mrope = uses_mrope self.device = device self.max_model_len = vllm_config.model_config.max_model_len @@ -79,7 +81,10 @@ class CudaGraphManager: ) -> None: num_reqs = min(num_tokens, self.max_num_reqs) input_ids = input_buffers.input_ids[:num_tokens] - positions = input_buffers.positions[:num_tokens] + if not self.uses_mrope: + positions = input_buffers.positions[:num_tokens] + else: + positions = input_buffers.mrope_positions[:, :num_tokens] attn_metadata = prepare_inputs_to_capture( num_reqs, num_tokens, diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 78889d2ad..8f9552e3f 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -31,6 +31,19 @@ class InputBuffers: ) self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) + # NOTE: `mrope_positions` is implemented with one additional dummy + # position on purpose to make it non-contiguous so that it can work + # with torch compile. + # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923 + # NOTE: When M-RoPE is enabled, position ids are 3D regardless of + # the modality of inputs. For text-only inputs, each dimension has + # identical position IDs, making M-RoPE functionally equivalent to + # 1D-RoPE. + # See page 5 of https://arxiv.org/abs/2409.12191 + self.mrope_positions = torch.zeros( + (3, max_num_tokens + 1), dtype=torch.int64, device=device + ) + @dataclass class InputBatch: @@ -62,6 +75,8 @@ class InputBatch: input_ids: torch.Tensor # [num_tokens_after_padding] positions: torch.Tensor + # [3, num_tokens_after_padding] + mrope_positions: torch.Tensor # layer_name -> Metadata attn_metadata: dict[str, Any] @@ -107,8 +122,11 @@ class InputBatch: input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens query_start_loc = input_buffers.query_start_loc[: num_reqs + 1] - input_ids = input_buffers.input_ids[:num_tokens] - positions = input_buffers.positions[:num_tokens] + input_ids = input_buffers.input_ids[:num_tokens].zero_() + positions = input_buffers.positions[:num_tokens].zero_() + input_buffers.mrope_positions.zero_() + mrope_positions = input_buffers.mrope_positions[:, :num_tokens] + # attn_metadata = defaultdict(lambda: None) logits_indices = query_start_loc[1:] - 1 cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32) @@ -128,6 +146,7 @@ class InputBatch: seq_lens=seq_lens, input_ids=input_ids, positions=positions, + mrope_positions=mrope_positions, attn_metadata=None, # type: ignore logits_indices=logits_indices, cu_num_logits=cu_num_logits, diff --git a/vllm/v1/worker/gpu/mm/__init__.py b/vllm/v1/worker/gpu/mm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm/v1/worker/gpu/mm/mrope_utils.py b/vllm/v1/worker/gpu/mm/mrope_utils.py new file mode 100644 index 000000000..c18b9c82e --- /dev/null +++ b/vllm/v1/worker/gpu/mm/mrope_utils.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.model_executor.models.interfaces import SupportsMRoPE +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor + + +class MRopeState: + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + device: torch.device, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.device = device + + # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) + # wasting a lot of CPU memory. + self.prefill_mrope_positions = StagedWriteTensor( + (max_num_reqs, 3 * max_model_len), + dtype=torch.int32, + device=device, + uva_instead_of_gpu=True, + ) + self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) + + def init_prefill_mrope_positions( + self, + req_idx: int, + mrope_model: SupportsMRoPE, + prefill_token_ids: list[int], + mm_features: list, + ) -> None: + prefill_mrope_positions, prefill_mrope_delta = ( + mrope_model.get_mrope_input_positions( + prefill_token_ids, + mm_features, + ) + ) + for i in range(3): + pos = prefill_mrope_positions[i].tolist() + self.prefill_mrope_positions.stage_write( + req_idx, i * self.max_model_len, pos + ) + self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta + + def apply_staged_writes(self) -> None: + self.prefill_mrope_positions.apply_write() + self.prefill_mrope_delta.copy_to_uva() + + def prepare_mrope_positions( + self, + idx_mapping: torch.Tensor, + query_start_loc: torch.Tensor, + prefill_lens: torch.Tensor, + num_computed_tokens: torch.Tensor, + mrope_positions: torch.Tensor, + ) -> None: + num_reqs = idx_mapping.shape[0] + _prepare_mrope_positions_kernel[(num_reqs,)]( + mrope_positions, + mrope_positions.stride(0), + self.prefill_mrope_positions.gpu, + self.prefill_mrope_positions.gpu.stride(0), + self.max_model_len, + self.prefill_mrope_delta.gpu, + idx_mapping, + query_start_loc, + prefill_lens, + num_computed_tokens, + BLOCK_SIZE=1024, + ) + + +@triton.jit +def _prepare_mrope_positions_kernel( + mrope_positions_ptr, + mrope_positions_stride, + prefill_mrope_positions_ptr, + prefill_mrope_positions_stride0, + prefill_mrope_positions_stride1, + prefill_mrope_delta_ptr, + idx_mapping_ptr, + query_start_loc_ptr, + prefill_lens_ptr, + num_computed_tokens_ptr, + BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + prefill_len = tl.load(prefill_lens_ptr + req_state_idx) + num_computed = tl.load(num_computed_tokens_ptr + req_state_idx) + is_prefill = num_computed < prefill_len + + query_start = tl.load(query_start_loc_ptr + batch_idx) + query_end = tl.load(query_start_loc_ptr + batch_idx + 1) + query_len = query_end - query_start + + mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx) + for i in range(0, query_len, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < query_len + orig_pos = num_computed + block + + for j in tl.static_range(3): + if is_prefill: + # Read from pre-computed M-RoPE positions. + pos = tl.load( + prefill_mrope_positions_ptr + + req_state_idx * prefill_mrope_positions_stride0 + + j * prefill_mrope_positions_stride1 + + orig_pos, + mask=mask, + ) + else: + # Apply M-RoPE delta. + pos = orig_pos + mrope_delta + tl.store( + mrope_positions_ptr + j * mrope_positions_stride + query_start + block, + pos, + mask=mask, + ) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 06dc7467f..7300357a1 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -47,6 +47,7 @@ from vllm.v1.worker.gpu.input_batch import ( prepare_pos_seq_lens, prepare_prefill_inputs, ) +from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu.sample.output import SamplerOutput @@ -94,6 +95,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.max_num_reqs = self.scheduler_config.max_num_seqs self.inputs_embeds_size = self.model_config.get_inputs_embeds_size() + # Multimodal + self.uses_mrope = self.model_config.uses_mrope + if self.uses_mrope: + self.mrope_states = MRopeState( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + device=self.device, + ) + self.use_async_scheduling = self.scheduler_config.async_scheduling self.output_copy_stream = torch.cuda.Stream(self.device) self.output_copy_event = torch.cuda.Event() @@ -132,7 +142,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) # CUDA graphs. - self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device) + self.cudagraph_manager = CudaGraphManager( + self.vllm_config, self.uses_mrope, self.device + ) # Structured outputs worker. self.structured_outputs_worker = StructuredOutputsWorker( max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1), @@ -268,6 +280,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dp_size = self.parallel_config.data_parallel_size num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens) num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32) + if not self.uses_mrope: + positions = input_batch.positions + else: + positions = input_batch.mrope_positions with ( self.maybe_dummy_run_with_lora( self.lora_config, @@ -283,7 +299,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ): hidden_states = self.model( input_ids=input_batch.input_ids, - positions=input_batch.positions, + positions=positions, ) sample_hidden_states = hidden_states[input_batch.logits_indices] return hidden_states, sample_hidden_states @@ -393,8 +409,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): sampling_params=new_req_data.sampling_params, lora_request=new_req_data.lora_request, ) - req_index = self.req_states.req_id_to_index[req_id] + + # Pre-compute M-RoPE positions for prefill. + if self.uses_mrope: + self.mrope_states.init_prefill_mrope_positions( + req_index, + self.model, # type: ignore + new_req_data.prefill_token_ids, + mm_features=[], # TODO + ) + self.block_tables.append_block_ids( req_index, new_req_data.block_ids, overwrite=True ) @@ -411,6 +436,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.req_states.apply_staged_writes() self.block_tables.apply_staged_writes() + if self.uses_mrope: + self.mrope_states.apply_staged_writes() def prepare_inputs( self, @@ -511,6 +538,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) seq_lens = self.input_buffers.seq_lens[:num_reqs] + # Prepare M-RoPE positions. + if self.uses_mrope: + self.mrope_states.prepare_mrope_positions( + idx_mapping, + query_start_loc, + self.req_states.prefill_len.gpu, + self.req_states.num_computed_tokens.gpu, + self.input_buffers.mrope_positions, + ) + # Some input token ids are directly read from the last sampled tokens # and draft tokens. Also, get the logits indices to sample tokens from. logits_indices = combine_sampled_and_draft_tokens( @@ -546,6 +583,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] positions = self.input_buffers.positions[:num_tokens_after_padding] + mrope_positions = self.input_buffers.mrope_positions[ + :, :num_tokens_after_padding + ] return InputBatch( req_ids=req_ids, num_reqs=num_reqs, @@ -561,6 +601,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): seq_lens=seq_lens, input_ids=input_ids, positions=positions, + mrope_positions=mrope_positions, attn_metadata=attn_metadata, logits_indices=logits_indices, cu_num_logits=cu_num_logits, @@ -889,6 +930,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: # Run PyTorch model in eager mode. # TODO(woosuk): Support piecewise CUDA graph. + if not self.uses_mrope: + positions = input_batch.positions + else: + positions = input_batch.mrope_positions with set_forward_context( input_batch.attn_metadata, self.vllm_config, @@ -898,7 +943,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ): hidden_states = self.model( input_ids=input_batch.input_ids, - positions=input_batch.positions, + positions=positions, ) self.execute_model_state = hidden_states, input_batch, sampling_metadata