[Model Runner V2] Move mrope_positions buffer to MRopeState (#32532)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2026-01-17 20:09:48 -08:00
Committed by: GitHub
Parent: 3055232ba0
Commit: 4147910f1e
4 changed files with 49 additions and 34 deletions


@@ -75,16 +75,17 @@ class CudaGraphManager:
         num_tokens: int,
         model: nn.Module,
         input_buffers: InputBuffers,
+        mrope_positions: torch.Tensor | None,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
         num_reqs = min(num_tokens, self.max_num_reqs)
         input_ids = input_buffers.input_ids[:num_tokens]
-        if not self.uses_mrope:
-            positions = input_buffers.positions[:num_tokens]
-        else:
-            positions = input_buffers.mrope_positions[:, :num_tokens]
+        positions = input_buffers.positions[:num_tokens]
+        if self.uses_mrope:
+            assert mrope_positions is not None
+            positions = mrope_positions[:, :num_tokens]
         attn_metadata = prepare_inputs_to_capture(
             num_reqs,
             num_tokens,
@@ -136,6 +137,7 @@ class CudaGraphManager:
         self,
         model: nn.Module,
         input_buffers: InputBuffers,
+        mrope_positions: torch.Tensor | None,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
@@ -146,6 +148,7 @@ class CudaGraphManager:
             self.capture_graph,
             model=model,
             input_buffers=input_buffers,
+            mrope_positions=mrope_positions,
             block_tables=block_tables,
             attn_metadata_builders=attn_metadata_builders,
             kv_cache_config=kv_cache_config,
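
The effect of the hunks above is that graph capture no longer reaches into InputBuffers for M-RoPE positions; the buffer now arrives as an explicit, optional argument. A minimal sketch of the selection pattern with stand-in tensors in place of the real InputBuffers (the helper name select_positions is illustrative, not part of vLLM):

    import torch

    def select_positions(
        uses_mrope: bool,
        positions_buf: torch.Tensor,  # [max_num_tokens], 1D rotary positions
        mrope_positions: torch.Tensor | None,  # [3, max_num_tokens + 1] or None
        num_tokens: int,
    ) -> torch.Tensor:
        # Default path: plain 1D positions sliced to the batch size.
        positions = positions_buf[:num_tokens]
        if uses_mrope:
            # M-RoPE models must be handed their 3D buffer explicitly.
            assert mrope_positions is not None
            positions = mrope_positions[:, :num_tokens]
        return positions

    # Text-only model: 1D positions.
    assert select_positions(False, torch.arange(16), None, 8).shape == (8,)

    # M-RoPE model: one row each for temporal/height/width position ids.
    mrope = torch.zeros(3, 17, dtype=torch.int64)
    assert select_positions(True, torch.arange(16), mrope, 8).shape == (3, 8)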


@@ -31,19 +31,6 @@ class InputBuffers:
         )
         self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
-        # NOTE: `mrope_positions` is implemented with one additional dummy
-        # position on purpose to make it non-contiguous so that it can work
-        # with torch compile.
-        # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
-        # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
-        # the modality of inputs. For text-only inputs, each dimension has
-        # identical position IDs, making M-RoPE functionally equivalent to
-        # 1D-RoPE.
-        # See page 5 of https://arxiv.org/abs/2409.12191
-        self.mrope_positions = torch.zeros(
-            (3, max_num_tokens + 1), dtype=torch.int64, device=device
-        )
-
 
 
 @dataclass
 class InputBatch:
@@ -76,7 +63,7 @@ class InputBatch:
     # [num_tokens_after_padding]
     positions: torch.Tensor
     # [3, num_tokens_after_padding]
-    mrope_positions: torch.Tensor
+    mrope_positions: torch.Tensor | None
     # layer_name -> Metadata
     attn_metadata: dict[str, Any]
 
@@ -124,8 +111,6 @@ class InputBatch:
         input_ids = input_buffers.input_ids[:num_tokens].zero_()
         positions = input_buffers.positions[:num_tokens].zero_()
-        input_buffers.mrope_positions.zero_()
-        mrope_positions = input_buffers.mrope_positions[:, :num_tokens]
 
         # attn_metadata = defaultdict(lambda: None)
 
         logits_indices = query_start_loc[1:] - 1
@@ -146,7 +131,7 @@ class InputBatch:
             seq_lens=seq_lens,
             input_ids=input_ids,
             positions=positions,
-            mrope_positions=mrope_positions,
+            mrope_positions=None,
             attn_metadata=None,  # type: ignore
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
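
With the buffer gone from InputBuffers, InputBatch.mrope_positions becomes an optional field that the dummy-batch constructor leaves as None; the model runner patches it in only for M-RoPE models (see the hunks below). A small sketch of that pattern, with an illustrative Batch standing in for InputBatch:

    from dataclasses import dataclass

    import torch

    @dataclass
    class Batch:  # stand-in for vLLM's InputBatch
        # [num_tokens_after_padding]
        positions: torch.Tensor
        # [3, num_tokens_after_padding]; None unless the model uses M-RoPE
        mrope_positions: torch.Tensor | None = None

    batch = Batch(positions=torch.arange(8))
    assert batch.mrope_positions is None  # text-only path carries no 3D ids

    # An M-RoPE-aware caller fills the field in afterwards, as
    # GPUModelRunner now does from MRopeState's buffer.
    batch.mrope_positions = torch.zeros(3, 8, dtype=torch.int64)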


@@ -11,10 +11,12 @@ class MRopeState:
 
     def __init__(
         self,
         max_num_reqs: int,
+        max_num_tokens: int,
         max_model_len: int,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
+        self.max_num_tokens = max_num_tokens
         self.max_model_len = max_model_len
         self.device = device
@@ -28,6 +30,19 @@ class MRopeState:
         )
         self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
+        # NOTE: `mrope_positions` is implemented with one additional dummy
+        # position on purpose to make it non-contiguous so that it can work
+        # with torch compile.
+        # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
+        # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+        # the modality of inputs. For text-only inputs, each dimension has
+        # identical position IDs, making M-RoPE functionally equivalent to
+        # 1D-RoPE.
+        # See page 5 of https://arxiv.org/abs/2409.12191
+        self.mrope_positions = torch.zeros(
+            (3, max_num_tokens + 1), dtype=torch.int64, device=device
+        )
+
 
     def init_prefill_mrope_positions(
         self,
         req_idx: int,
@@ -58,12 +73,11 @@ class MRopeState:
         query_start_loc: torch.Tensor,
         prefill_lens: torch.Tensor,
         num_computed_tokens: torch.Tensor,
-        mrope_positions: torch.Tensor,
     ) -> None:
         num_reqs = idx_mapping.shape[0]
         _prepare_mrope_positions_kernel[(num_reqs,)](
-            mrope_positions,
-            mrope_positions.stride(0),
+            self.mrope_positions,
+            self.mrope_positions.stride(0),
             self.prefill_mrope_positions.gpu,
             self.prefill_mrope_positions.gpu.stride(0),
             self.max_model_len,
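
The buffer keeps its extra dummy column when it moves into MRopeState, so any [:, :num_tokens] slice of it stays non-contiguous, which is the torch compile workaround the NOTE points at. A quick self-contained check of that property:

    import torch

    max_num_tokens = 8
    # One extra column, as in MRopeState.__init__ above.
    mrope_positions = torch.zeros(3, max_num_tokens + 1, dtype=torch.int64)

    view = mrope_positions[:, :max_num_tokens]
    # The row stride stays max_num_tokens + 1, so the sliced view can
    # never be contiguous.
    assert view.stride() == (max_num_tokens + 1, 1)
    assert not view.is_contiguous()

    # Without the dummy column, the same slice would be contiguous.
    assert torch.zeros(3, max_num_tokens)[:, :max_num_tokens].is_contiguous()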


@@ -99,6 +99,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.uses_mrope:
             self.mrope_states = MRopeState(
                 max_num_reqs=self.max_num_reqs,
+                max_num_tokens=self.max_num_tokens,
                 max_model_len=self.max_model_len,
                 device=self.device,
             )
@@ -284,15 +285,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_buffers=self.input_buffers,
             device=self.device,
         )
+        if self.uses_mrope:
+            input_batch.mrope_positions = self.mrope_states.mrope_positions[
+                :, :num_tokens
+            ]
         if not skip_attn:
             self.prepare_dummy_attn_metadata(input_batch)
 
         dp_size = self.parallel_config.data_parallel_size
         num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
         num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
-        if not self.uses_mrope:
-            positions = input_batch.positions
-        else:
-            positions = input_batch.mrope_positions
+        positions = input_batch.positions
+        if self.uses_mrope:
+            positions = input_batch.mrope_positions
 
         with (
             self.maybe_dummy_run_with_lora(
@@ -371,9 +375,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
         with self.maybe_setup_dummy_loras(self.lora_config):
+            mrope_positions = None
+            if self.uses_mrope:
+                mrope_positions = self.mrope_states.mrope_positions
             self.cudagraph_manager.capture(
                 model=self.model,
                 input_buffers=self.input_buffers,
+                mrope_positions=mrope_positions,
                 block_tables=self.block_tables,
                 attn_metadata_builders=self.attn_metadata_builders,
                 kv_cache_config=self.kv_cache_config,
@@ -566,7 +574,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             query_start_loc,
             self.req_states.prefill_len.gpu,
             self.req_states.num_computed_tokens.gpu,
-            self.input_buffers.mrope_positions,
         )
 
         # Some input token ids are directly read from the last sampled tokens
@@ -604,9 +611,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
         positions = self.input_buffers.positions[:num_tokens_after_padding]
-        mrope_positions = self.input_buffers.mrope_positions[
-            :, :num_tokens_after_padding
-        ]
+        mrope_positions = None
+        if self.uses_mrope:
+            mrope_positions = self.mrope_states.mrope_positions[
+                :, :num_tokens_after_padding
+            ]
 
         return InputBatch(
             req_ids=req_ids,
             num_reqs=num_reqs,
@@ -936,6 +945,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_buffers=self.input_buffers,
             device=self.device,
         )
+        if self.uses_mrope:
+            input_batch.mrope_positions = self.mrope_states.mrope_positions[
+                :, :num_tokens_after_padding
+            ]
         self.prepare_dummy_attn_metadata(input_batch)
 
         # Run model.
@@ -949,9 +962,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             # Run PyTorch model in eager mode.
             # TODO(woosuk): Support piecewise CUDA graph.
-            if not self.uses_mrope:
-                positions = input_batch.positions
-            else:
-                positions = input_batch.mrope_positions
+            positions = input_batch.positions
+            if self.uses_mrope:
+                assert input_batch.mrope_positions is not None
+                positions = input_batch.mrope_positions
             with set_forward_context(
                 input_batch.attn_metadata,
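
Taken together, the commit makes MRopeState the sole owner of the buffer: it allocates mrope_positions and its kernel writes into it, so neither InputBuffers nor prepare_mrope_positions needs the tensor threaded through. A rough reduction of the resulting shape of the class, with the Triton kernel replaced by a plain text-only fill (names abridged; a sketch, not the vLLM implementation):

    import torch

    class MRopeStateSketch:
        """Owns the 3D position buffer instead of borrowing it from InputBuffers."""

        def __init__(self, max_num_tokens: int, device: torch.device):
            self.max_num_tokens = max_num_tokens
            # Extra dummy column: see the non-contiguity NOTE in
            # MRopeState.__init__ above.
            self.mrope_positions = torch.zeros(
                (3, max_num_tokens + 1), dtype=torch.int64, device=device
            )

        def prepare_mrope_positions(self, positions_1d: torch.Tensor) -> None:
            # Post-commit, the method writes into self.mrope_positions directly;
            # pre-commit, the caller passed the buffer in. The real code launches
            # _prepare_mrope_positions_kernel; a text-only fill stands in here.
            n = positions_1d.shape[0]
            self.mrope_positions[:, :n] = positions_1d  # broadcast over 3 rows

    state = MRopeStateSketch(max_num_tokens=8, device=torch.device("cpu"))
    state.prepare_mrope_positions(torch.arange(4))
    # For text-only inputs all three rows are identical, so M-RoPE
    # degenerates to 1D RoPE (per the paper cited in the NOTE).
    assert torch.equal(state.mrope_positions[0, :4], state.mrope_positions[2, :4])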