[Model Runner V2] Add support for M-RoPE (#32143)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -25,10 +25,12 @@ class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
uses_mrope: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.uses_mrope = uses_mrope
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
@@ -79,7 +81,10 @@ class CudaGraphManager:
|
||||
) -> None:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
if not self.uses_mrope:
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
else:
|
||||
positions = input_buffers.mrope_positions[:, :num_tokens]
|
||||
attn_metadata = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
|
||||
@@ -31,6 +31,19 @@ class InputBuffers:
|
||||
)
|
||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||
|
||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||
# position on purpose to make it non-contiguous so that it can work
|
||||
# with torch compile.
|
||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
||||
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
||||
# the modality of inputs. For text-only inputs, each dimension has
|
||||
# identical position IDs, making M-RoPE functionally equivalent to
|
||||
# 1D-RoPE.
|
||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, max_num_tokens + 1), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
@@ -62,6 +75,8 @@ class InputBatch:
|
||||
input_ids: torch.Tensor
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
# [3, num_tokens_after_padding]
|
||||
mrope_positions: torch.Tensor
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
@@ -107,8 +122,11 @@ class InputBatch:
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
|
||||
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
input_ids = input_buffers.input_ids[:num_tokens].zero_()
|
||||
positions = input_buffers.positions[:num_tokens].zero_()
|
||||
input_buffers.mrope_positions.zero_()
|
||||
mrope_positions = input_buffers.mrope_positions[:, :num_tokens]
|
||||
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
|
||||
@@ -128,6 +146,7 @@ class InputBatch:
|
||||
seq_lens=seq_lens,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
attn_metadata=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
|
||||
0
vllm/v1/worker/gpu/mm/__init__.py
Normal file
0
vllm/v1/worker/gpu/mm/__init__.py
Normal file
127
vllm/v1/worker/gpu/mm/mrope_utils.py
Normal file
127
vllm/v1/worker/gpu/mm/mrope_utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMRoPE
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
|
||||
class MRopeState:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
|
||||
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
|
||||
# wasting a lot of CPU memory.
|
||||
self.prefill_mrope_positions = StagedWriteTensor(
|
||||
(max_num_reqs, 3 * max_model_len),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
uva_instead_of_gpu=True,
|
||||
)
|
||||
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
def init_prefill_mrope_positions(
|
||||
self,
|
||||
req_idx: int,
|
||||
mrope_model: SupportsMRoPE,
|
||||
prefill_token_ids: list[int],
|
||||
mm_features: list,
|
||||
) -> None:
|
||||
prefill_mrope_positions, prefill_mrope_delta = (
|
||||
mrope_model.get_mrope_input_positions(
|
||||
prefill_token_ids,
|
||||
mm_features,
|
||||
)
|
||||
)
|
||||
for i in range(3):
|
||||
pos = prefill_mrope_positions[i].tolist()
|
||||
self.prefill_mrope_positions.stage_write(
|
||||
req_idx, i * self.max_model_len, pos
|
||||
)
|
||||
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.prefill_mrope_positions.apply_write()
|
||||
self.prefill_mrope_delta.copy_to_uva()
|
||||
|
||||
def prepare_mrope_positions(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
prefill_lens: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
mrope_positions: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_prepare_mrope_positions_kernel[(num_reqs,)](
|
||||
mrope_positions,
|
||||
mrope_positions.stride(0),
|
||||
self.prefill_mrope_positions.gpu,
|
||||
self.prefill_mrope_positions.gpu.stride(0),
|
||||
self.max_model_len,
|
||||
self.prefill_mrope_delta.gpu,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
prefill_lens,
|
||||
num_computed_tokens,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_mrope_positions_kernel(
|
||||
mrope_positions_ptr,
|
||||
mrope_positions_stride,
|
||||
prefill_mrope_positions_ptr,
|
||||
prefill_mrope_positions_stride0,
|
||||
prefill_mrope_positions_stride1,
|
||||
prefill_mrope_delta_ptr,
|
||||
idx_mapping_ptr,
|
||||
query_start_loc_ptr,
|
||||
prefill_lens_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
is_prefill = num_computed < prefill_len
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
orig_pos = num_computed + block
|
||||
|
||||
for j in tl.static_range(3):
|
||||
if is_prefill:
|
||||
# Read from pre-computed M-RoPE positions.
|
||||
pos = tl.load(
|
||||
prefill_mrope_positions_ptr
|
||||
+ req_state_idx * prefill_mrope_positions_stride0
|
||||
+ j * prefill_mrope_positions_stride1
|
||||
+ orig_pos,
|
||||
mask=mask,
|
||||
)
|
||||
else:
|
||||
# Apply M-RoPE delta.
|
||||
pos = orig_pos + mrope_delta
|
||||
tl.store(
|
||||
mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
|
||||
pos,
|
||||
mask=mask,
|
||||
)
|
||||
@@ -47,6 +47,7 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
)
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
@@ -94,6 +95,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
|
||||
|
||||
# Multimodal
|
||||
self.uses_mrope = self.model_config.uses_mrope
|
||||
if self.uses_mrope:
|
||||
self.mrope_states = MRopeState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
self.output_copy_stream = torch.cuda.Stream(self.device)
|
||||
self.output_copy_event = torch.cuda.Event()
|
||||
@@ -132,7 +142,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
self.vllm_config, self.uses_mrope, self.device
|
||||
)
|
||||
# Structured outputs worker.
|
||||
self.structured_outputs_worker = StructuredOutputsWorker(
|
||||
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
|
||||
@@ -268,6 +280,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dp_size = self.parallel_config.data_parallel_size
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
|
||||
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
|
||||
if not self.uses_mrope:
|
||||
positions = input_batch.positions
|
||||
else:
|
||||
positions = input_batch.mrope_positions
|
||||
with (
|
||||
self.maybe_dummy_run_with_lora(
|
||||
self.lora_config,
|
||||
@@ -283,7 +299,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
positions=positions,
|
||||
)
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@@ -393,8 +409,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
sampling_params=new_req_data.sampling_params,
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
|
||||
# Pre-compute M-RoPE positions for prefill.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.init_prefill_mrope_positions(
|
||||
req_index,
|
||||
self.model, # type: ignore
|
||||
new_req_data.prefill_token_ids,
|
||||
mm_features=[], # TODO
|
||||
)
|
||||
|
||||
self.block_tables.append_block_ids(
|
||||
req_index, new_req_data.block_ids, overwrite=True
|
||||
)
|
||||
@@ -411,6 +436,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
self.req_states.apply_staged_writes()
|
||||
self.block_tables.apply_staged_writes()
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.apply_staged_writes()
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
@@ -511,6 +538,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
# Prepare M-RoPE positions.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.prepare_mrope_positions(
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
self.input_buffers.mrope_positions,
|
||||
)
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens
|
||||
# and draft tokens. Also, get the logits indices to sample tokens from.
|
||||
logits_indices = combine_sampled_and_draft_tokens(
|
||||
@@ -546,6 +583,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
||||
mrope_positions = self.input_buffers.mrope_positions[
|
||||
:, :num_tokens_after_padding
|
||||
]
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
@@ -561,6 +601,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
seq_lens=seq_lens,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
@@ -889,6 +930,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
# Run PyTorch model in eager mode.
|
||||
# TODO(woosuk): Support piecewise CUDA graph.
|
||||
if not self.uses_mrope:
|
||||
positions = input_batch.positions
|
||||
else:
|
||||
positions = input_batch.mrope_positions
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
@@ -898,7 +943,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_batch.input_ids,
|
||||
positions=input_batch.positions,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
self.execute_model_state = hidden_states, input_batch, sampling_metadata
|
||||
|
||||
Reference in New Issue
Block a user