[Model Runner V2] Add support for M-RoPE (#32143)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2026-01-12 13:37:43 -08:00
committed by GitHub
parent dec28688c5
commit 0a7dd23754
5 changed files with 203 additions and 7 deletions

View File

@@ -25,10 +25,12 @@ class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
uses_mrope: bool,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
@@ -79,7 +81,10 @@ class CudaGraphManager:
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
if not self.uses_mrope:
positions = input_buffers.positions[:num_tokens]
else:
positions = input_buffers.mrope_positions[:, :num_tokens]
attn_metadata = prepare_inputs_to_capture(
num_reqs,
num_tokens,

View File

@@ -31,6 +31,19 @@ class InputBuffers:
)
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
(3, max_num_tokens + 1), dtype=torch.int64, device=device
)
@dataclass
class InputBatch:
@@ -62,6 +75,8 @@ class InputBatch:
input_ids: torch.Tensor
# [num_tokens_after_padding]
positions: torch.Tensor
# [3, num_tokens_after_padding]
mrope_positions: torch.Tensor
# layer_name -> Metadata
attn_metadata: dict[str, Any]
@@ -107,8 +122,11 @@ class InputBatch:
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
input_ids = input_buffers.input_ids[:num_tokens].zero_()
positions = input_buffers.positions[:num_tokens].zero_()
input_buffers.mrope_positions.zero_()
mrope_positions = input_buffers.mrope_positions[:, :num_tokens]
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
@@ -128,6 +146,7 @@ class InputBatch:
seq_lens=seq_lens,
input_ids=input_ids,
positions=positions,
mrope_positions=mrope_positions,
attn_metadata=None, # type: ignore
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,

View File

View File

@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.models.interfaces import SupportsMRoPE
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class MRopeState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.device = device
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# wasting a lot of CPU memory.
self.prefill_mrope_positions = StagedWriteTensor(
(max_num_reqs, 3 * max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
def init_prefill_mrope_positions(
self,
req_idx: int,
mrope_model: SupportsMRoPE,
prefill_token_ids: list[int],
mm_features: list,
) -> None:
prefill_mrope_positions, prefill_mrope_delta = (
mrope_model.get_mrope_input_positions(
prefill_token_ids,
mm_features,
)
)
for i in range(3):
pos = prefill_mrope_positions[i].tolist()
self.prefill_mrope_positions.stage_write(
req_idx, i * self.max_model_len, pos
)
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
def apply_staged_writes(self) -> None:
self.prefill_mrope_positions.apply_write()
self.prefill_mrope_delta.copy_to_uva()
def prepare_mrope_positions(
self,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
prefill_lens: torch.Tensor,
num_computed_tokens: torch.Tensor,
mrope_positions: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_prepare_mrope_positions_kernel[(num_reqs,)](
mrope_positions,
mrope_positions.stride(0),
self.prefill_mrope_positions.gpu,
self.prefill_mrope_positions.gpu.stride(0),
self.max_model_len,
self.prefill_mrope_delta.gpu,
idx_mapping,
query_start_loc,
prefill_lens,
num_computed_tokens,
BLOCK_SIZE=1024,
)
@triton.jit
def _prepare_mrope_positions_kernel(
mrope_positions_ptr,
mrope_positions_stride,
prefill_mrope_positions_ptr,
prefill_mrope_positions_stride0,
prefill_mrope_positions_stride1,
prefill_mrope_delta_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
is_prefill = num_computed < prefill_len
query_start = tl.load(query_start_loc_ptr + batch_idx)
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
orig_pos = num_computed + block
for j in tl.static_range(3):
if is_prefill:
# Read from pre-computed M-RoPE positions.
pos = tl.load(
prefill_mrope_positions_ptr
+ req_state_idx * prefill_mrope_positions_stride0
+ j * prefill_mrope_positions_stride1
+ orig_pos,
mask=mask,
)
else:
# Apply M-RoPE delta.
pos = orig_pos + mrope_delta
tl.store(
mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
pos,
mask=mask,
)

View File

@@ -47,6 +47,7 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_pos_seq_lens,
prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.output import SamplerOutput
@@ -94,6 +95,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
# Multimodal
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
self.mrope_states = MRopeState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
device=self.device,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
self.output_copy_event = torch.cuda.Event()
@@ -132,7 +142,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
self.cudagraph_manager = CudaGraphManager(
self.vllm_config, self.uses_mrope, self.device
)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
@@ -268,6 +280,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dp_size = self.parallel_config.data_parallel_size
num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
if not self.uses_mrope:
positions = input_batch.positions
else:
positions = input_batch.mrope_positions
with (
self.maybe_dummy_run_with_lora(
self.lora_config,
@@ -283,7 +299,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
positions=positions,
)
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@@ -393,8 +409,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_params=new_req_data.sampling_params,
lora_request=new_req_data.lora_request,
)
req_index = self.req_states.req_id_to_index[req_id]
# Pre-compute M-RoPE positions for prefill.
if self.uses_mrope:
self.mrope_states.init_prefill_mrope_positions(
req_index,
self.model, # type: ignore
new_req_data.prefill_token_ids,
mm_features=[], # TODO
)
self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True
)
@@ -411,6 +436,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.apply_staged_writes()
self.block_tables.apply_staged_writes()
if self.uses_mrope:
self.mrope_states.apply_staged_writes()
def prepare_inputs(
self,
@@ -511,6 +538,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
# Prepare M-RoPE positions.
if self.uses_mrope:
self.mrope_states.prepare_mrope_positions(
idx_mapping,
query_start_loc,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
self.input_buffers.mrope_positions,
)
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens(
@@ -546,6 +583,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
mrope_positions = self.input_buffers.mrope_positions[
:, :num_tokens_after_padding
]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
@@ -561,6 +601,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
seq_lens=seq_lens,
input_ids=input_ids,
positions=positions,
mrope_positions=mrope_positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
@@ -889,6 +930,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
# Run PyTorch model in eager mode.
# TODO(woosuk): Support piecewise CUDA graph.
if not self.uses_mrope:
positions = input_batch.positions
else:
positions = input_batch.mrope_positions
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
@@ -898,7 +943,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
):
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=input_batch.positions,
positions=positions,
)
self.execute_model_state = hidden_states, input_batch, sampling_metadata