[Model Runner V2] Add Support for XD-RoPE (#36817)
Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
@@ -320,6 +320,9 @@ class ModelCudaGraphManager(CudaGraphManager):
        model_inputs = {
            "input_ids": input_buffers.input_ids[:num_tokens],
            "positions": input_buffers.positions[:num_tokens],
            # TODO: Pass intermediate_tensors for PP CUDA graph
            # support (https://github.com/vllm-project/vllm/pull/35162).
            "intermediate_tensors": None,
            **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
        }
        model_output = model(**model_inputs)
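A note on the capture-time input dict above: because `**model_state.prepare_dummy_inputs(...)` is expanded last, any keys it returns (for example, multi-dim `positions`) replace the defaults listed before it, which is the same override contract the runner documents for `prepare_inputs`. A minimal, self-contained sketch of that dict-merge precedence (toy keys, not the real buffers):

```python
# Toy illustration: later **-expanded entries override earlier keys.
defaults = {"positions": "1d_positions", "intermediate_tensors": None}
overrides = {"positions": "3d_rope_positions"}  # e.g., from prepare_dummy_inputs()
model_inputs = {**defaults, **overrides}
assert model_inputs["positions"] == "3d_rope_positions"
```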
vllm/v1/worker/gpu/mm/mrope_utils.py (deleted)
@@ -1,136 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.model_executor.models.interfaces import SupportsMRoPE
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


class MRopeState:
    def __init__(
        self,
        max_num_reqs: int,
        max_num_tokens: int,
        max_model_len: int,
        device: torch.device,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.max_model_len = max_model_len
        self.device = device

        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
        # wasting a lot of CPU memory.
        self.prefill_mrope_positions = StagedWriteTensor(
            (max_num_reqs * 3, max_model_len),
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)

        # NOTE: `mrope_positions` is implemented with one additional dummy
        # position on purpose to make it non-contiguous so that it can work
        # with torch compile.
        # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
        # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
        # the modality of inputs. For text-only inputs, each dimension has
        # identical position IDs, making M-RoPE functionally equivalent to
        # 1D-RoPE.
        # See page 5 of https://arxiv.org/abs/2409.12191
        self.mrope_positions = torch.zeros(
            (3, max_num_tokens + 1), dtype=torch.int64, device=device
        )

    def init_prefill_mrope_positions(
        self,
        req_idx: int,
        mrope_model: SupportsMRoPE,
        prefill_token_ids: list[int],
        mm_features: list,
    ) -> None:
        prefill_mrope_positions, prefill_mrope_delta = (
            mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
        )
        for i in range(3):
            pos = prefill_mrope_positions[i].tolist()
            self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
        self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta

    def apply_staged_writes(self) -> None:
        self.prefill_mrope_positions.apply_write()
        self.prefill_mrope_delta.copy_to_uva()

    def prepare_mrope_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        num_reqs = idx_mapping.shape[0]
        _prepare_mrope_positions_kernel[(num_reqs,)](
            self.mrope_positions,
            self.mrope_positions.stride(0),
            self.prefill_mrope_positions.gpu,
            3 * self.max_model_len,
            self.max_model_len,
            self.prefill_mrope_delta.gpu,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
        )


@triton.jit
def _prepare_mrope_positions_kernel(
    mrope_positions_ptr,
    mrope_positions_stride,
    prefill_mrope_positions_ptr,
    prefill_mrope_positions_stride0,
    prefill_mrope_positions_stride1,
    prefill_mrope_delta_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(3):
            if is_prefill:
                # Read from pre-computed M-RoPE positions.
                pos = tl.load(
                    prefill_mrope_positions_ptr
                    + req_state_idx * prefill_mrope_positions_stride0
                    + j * prefill_mrope_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                # Apply M-RoPE delta.
                pos = orig_pos + mrope_delta
            tl.store(
                mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
                pos,
                mask=mask,
            )
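The two kernel branches above encode the M-RoPE contract: prefill positions come from a precomputed per-modality table, while decode offsets the plain 1D token index by a per-request delta so all dimensions resume just past the largest prefill position. A toy numeric sketch of the decode rule, assuming the usual `max + 1 - len` delta definition (hypothetical values, plain Python, not the Triton kernel):

```python
# Toy sketch of the decode branch: pos = orig_pos + delta for every dim.
# Suppose a 6-token multimodal prefill produced 3D positions whose maximum
# value is 3 (image patches advance positions more slowly than text).
prefill_max_pos, prefill_len = 3, 6
delta = prefill_max_pos + 1 - prefill_len  # -2; stored once per request

# The first decode token has original (1D) index == prefill_len.
orig_pos = prefill_len
decode_pos = [orig_pos + delta] * 3  # identical across the three dims
assert decode_pos == [4, 4, 4]  # continues right after prefill_max_pos
```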
vllm/v1/worker/gpu/mm/rope.py (new file, 197 lines)
@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast

import torch
import torch.nn as nn

from vllm.config import ModelConfig
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsXDRoPE
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor


class RopeState:
    """Unified state for multi-dimensional RoPE variants (M-RoPE, XD-RoPE).

    M-RoPE: 3 dims, uses position delta for decode.
    XD-RoPE: 3 or 4 dims, delta is 0 (decode uses orig_pos for all dims).

    NOTE: `positions` is implemented with one additional dummy position on
    purpose to make it non-contiguous so that it can work with torch compile.
    See detailed explanation in
    https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

    NOTE: When M-RoPE is enabled, position ids are 3D regardless of the
    modality of inputs. For text-only inputs, each dimension has identical
    position IDs, making M-RoPE functionally equivalent to 1D-RoPE.
    See page 5 of https://arxiv.org/abs/2409.12191
    """

    def __init__(
        self,
        num_dims: int,
        has_delta: bool,
        max_num_reqs: int,
        max_num_tokens: int,
        max_model_len: int,
        device: torch.device,
    ):
        self.num_dims = num_dims
        self.has_delta = has_delta
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.max_model_len = max_model_len
        self.device = device

        # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
        # wasting a lot of CPU memory.
        self.prefill_positions = StagedWriteTensor(
            (max_num_reqs * num_dims, max_model_len),
            dtype=torch.int32,
            device=device,
            uva_instead_of_gpu=True,
        )
        self.positions = torch.zeros(
            (num_dims, max_num_tokens + 1), dtype=torch.int64, device=device
        )

        # Delta is non-zero for M-RoPE, always 0 for XD-RoPE.
        self.prefill_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)

    def init_prefill_positions(
        self,
        req_idx: int,
        model: nn.Module,
        prefill_token_ids: list[int],
        mm_features: list,
    ) -> None:
        if self.has_delta:
            mrope_model = cast(SupportsMRoPE, model)
            prefill_positions, delta = mrope_model.get_mrope_input_positions(
                prefill_token_ids, mm_features
            )
            self.prefill_delta.np[req_idx] = delta
        else:
            xdrope_model = cast(SupportsXDRoPE, model)
            prefill_positions = xdrope_model.get_xdrope_input_positions(
                prefill_token_ids, mm_features
            )

        for i in range(self.num_dims):
            pos = prefill_positions[i].tolist()
            self.prefill_positions.stage_write(self.num_dims * req_idx + i, 0, pos)

    def apply_staged_writes(self) -> None:
        self.prefill_positions.apply_write()
        if self.has_delta:
            self.prefill_delta.copy_to_uva()

    def get_positions(self, num_tokens: int) -> torch.Tensor:
        return self.positions[:, :num_tokens]

    def prepare_positions(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        prefill_lens: torch.Tensor,
        num_computed_tokens: torch.Tensor,
    ) -> None:
        num_reqs = idx_mapping.shape[0]
        _prepare_rope_positions_kernel[(num_reqs,)](
            self.positions,
            self.positions.stride(0),
            self.prefill_positions.gpu,
            self.num_dims * self.max_model_len,
            self.max_model_len,
            self.prefill_delta.gpu,
            idx_mapping,
            query_start_loc,
            prefill_lens,
            num_computed_tokens,
            BLOCK_SIZE=1024,
            NUM_DIMS=self.num_dims,
        )


def get_rope_state(
    model_config: ModelConfig,
    model: nn.Module,
    max_num_reqs: int,
    max_num_tokens: int,
    max_model_len: int,
    device: torch.device,
) -> RopeState | None:
    """Create a RopeState if the model uses multi-dimensional RoPE."""
    if model_config.uses_mrope:
        assert isinstance(model, SupportsMRoPE)
        return RopeState(
            num_dims=3,
            has_delta=True,
            max_num_reqs=max_num_reqs,
            max_num_tokens=max_num_tokens,
            max_model_len=max_model_len,
            device=device,
        )
    if model_config.uses_xdrope_dim > 0:
        assert isinstance(model, SupportsXDRoPE)
        return RopeState(
            num_dims=model_config.uses_xdrope_dim,
            has_delta=False,
            max_num_reqs=max_num_reqs,
            max_num_tokens=max_num_tokens,
            max_model_len=max_model_len,
            device=device,
        )
    return None


@triton.jit
def _prepare_rope_positions_kernel(
    positions_ptr,
    positions_stride,
    prefill_positions_ptr,
    prefill_positions_stride0,
    prefill_positions_stride1,
    prefill_delta_ptr,
    idx_mapping_ptr,
    query_start_loc_ptr,
    prefill_lens_ptr,
    num_computed_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
    NUM_DIMS: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
    is_prefill = num_computed < prefill_len

    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    delta = tl.load(prefill_delta_ptr + req_state_idx)

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < query_len
        orig_pos = num_computed + block

        for j in tl.static_range(NUM_DIMS):
            if is_prefill:
                pos = tl.load(
                    prefill_positions_ptr
                    + req_state_idx * prefill_positions_stride0
                    + j * prefill_positions_stride1
                    + orig_pos,
                    mask=mask,
                )
            else:
                pos = orig_pos + delta
            tl.store(
                positions_ptr + j * positions_stride + query_start + block,
                pos,
                mask=mask,
            )
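The `get_rope_state` factory keeps the model runner agnostic to the variant: M-RoPE models get `num_dims=3, has_delta=True`, XD-RoPE models get their configured dimension count with `has_delta=False`, and everything else gets `None` so the ordinary 1D positions are used. A self-contained toy mirroring just that dispatch (the `MiniConfig` stand-in is hypothetical, not vLLM's `ModelConfig`):

```python
# Self-contained toy mirroring get_rope_state's dispatch logic.
from dataclasses import dataclass

@dataclass
class MiniConfig:
    uses_mrope: bool = False
    uses_xdrope_dim: int = 0  # 0 means "not an XD-RoPE model"

def rope_dims(cfg: MiniConfig) -> tuple[int, bool] | None:
    """Return (num_dims, has_delta), or None for plain 1D-RoPE."""
    if cfg.uses_mrope:
        return 3, True  # M-RoPE: 3 dims plus a decode delta
    if cfg.uses_xdrope_dim > 0:
        return cfg.uses_xdrope_dim, False  # XD-RoPE: delta is always 0
    return None

assert rope_dims(MiniConfig(uses_mrope=True)) == (3, True)
assert rope_dims(MiniConfig(uses_xdrope_dim=4)) == (4, False)
assert rope_dims(MiniConfig()) is None
```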
@@ -992,6 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            "input_ids": input_batch.input_ids,
            "positions": input_batch.positions,
            "inputs_embeds": inputs_embeds,
            "intermediate_tensors": intermediate_tensors,
            # NOTE: Values returned by `prepare_inputs` will override the default
            # values above.
            **self.model_state.prepare_inputs(input_batch, self.req_states),
@@ -1000,7 +1001,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # Update for non-first PP ranks.
            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = None
            model_inputs["intermediate_tensors"] = intermediate_tensors
            assert intermediate_tensors is not None

        # Run model.
        if batch_desc.cg_mode == CUDAGraphMode.FULL:
@@ -13,7 +13,7 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
 from vllm.v1.worker.gpu.input_batch import InputBatch
 from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
-from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
+from vllm.v1.worker.gpu.mm.rope import get_rope_state
 from vllm.v1.worker.gpu.model_states.interface import ModelState
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.utils import AttentionGroup
@@ -52,29 +52,28 @@ class DefaultModelState(ModelState):
             device=self.device,
         )

-        self.uses_mrope = self.model_config.uses_mrope
-        if self.uses_mrope:
-            self.mrope_state = MRopeState(
-                max_num_reqs=self.max_num_reqs,
-                max_num_tokens=self.max_num_tokens,
-                max_model_len=self.max_model_len,
-                device=self.device,
-            )
+        self.rope_state = get_rope_state(
+            self.model_config,
+            model,
+            max_num_reqs=self.max_num_reqs,
+            max_num_tokens=self.max_num_tokens,
+            max_model_len=self.max_model_len,
+            device=self.device,
+        )

     def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
-        if self.uses_mrope:
-            # Pre-compute M-RoPE positions for prefill.
+        if self.rope_state is not None:
             assert new_req_data.prefill_token_ids is not None
-            self.mrope_state.init_prefill_mrope_positions(
+            self.rope_state.init_prefill_positions(
                 req_index,
-                self.model,  # type: ignore
+                self.model,
                 new_req_data.prefill_token_ids,
                 mm_features=new_req_data.mm_features,
             )

     def apply_staged_writes(self) -> None:
-        if self.uses_mrope:
-            self.mrope_state.apply_staged_writes()
+        if self.rope_state is not None:
+            self.rope_state.apply_staged_writes()

     def get_mm_embeddings(
         self,
@@ -109,31 +108,26 @@ class DefaultModelState(ModelState):

     def prepare_inputs(
         self, input_batch: InputBatch, req_states: RequestState
-    ) -> dict[str, Any]:
-        if not self.uses_mrope:
-            # Common case (1D positions).
-            return {}
+    ) -> dict[str, torch.Tensor | None]:
+        if self.rope_state is None:
+            return {}  # Common case (1D positions).

-        # Prepare M-RoPE positions.
-        self.mrope_state.prepare_mrope_positions(
+        self.rope_state.prepare_positions(
             input_batch.idx_mapping,
             input_batch.query_start_loc,
             req_states.prefill_len.gpu,
             req_states.num_computed_tokens.gpu,
         )
-        mrope_positions = self.mrope_state.mrope_positions[
-            :, : input_batch.num_tokens_after_padding
-        ]
-        return {"positions": mrope_positions}
+        positions = self.rope_state.get_positions(input_batch.num_tokens_after_padding)
+        return {"positions": positions}

     def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
         model_inputs = {}
         if self.supports_mm_inputs:
             inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
             model_inputs["inputs_embeds"] = inputs_embeds
-        if self.uses_mrope:
-            mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
-            model_inputs["positions"] = mrope_positions
+        if self.rope_state is not None:
+            model_inputs["positions"] = self.rope_state.get_positions(num_tokens)
         return model_inputs

     def prepare_attn(
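Taken together, the runner-side lifecycle is: `add_request` precomputes per-dim prefill positions, `apply_staged_writes` flushes them (plus any M-RoPE delta) to UVA memory, and `prepare_inputs` / `prepare_dummy_inputs` launch the Triton kernel and hand the resulting `positions` tensor to the model, overriding the default 1D input. A toy CPU re-implementation of the kernel's selection rule, simplified from the per-step `is_prefill` check to a per-token check (hypothetical tensors, not the real vLLM path):

```python
# Toy CPU sketch of the per-token position rule, for one request.
import torch

num_dims, prefill_len, delta = 3, 4, 0  # delta != 0 only for M-RoPE
# Hypothetical per-dim prefill table; the real one comes from the model's
# get_mrope_input_positions / get_xdrope_input_positions.
prefill_positions = torch.arange(prefill_len).repeat(num_dims, 1)

def position(dim: int, token_idx: int) -> int:
    if token_idx < prefill_len:
        # Prefill: read the precomputed multi-dim table.
        return int(prefill_positions[dim, token_idx])
    # Decode: plain 1D index plus the per-request delta (0 for XD-RoPE).
    return token_idx + delta

assert [position(0, t) for t in range(6)] == [0, 1, 2, 3, 4, 5]
```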