[Model Runner V2] Move mrope_positions buffer to MRopeState (#32532)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -75,16 +75,17 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
if not self.uses_mrope:
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
else:
|
||||
positions = input_buffers.mrope_positions[:, :num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
if self.uses_mrope:
|
||||
assert mrope_positions is not None
|
||||
positions = mrope_positions[:, :num_tokens]
|
||||
attn_metadata = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
@@ -136,6 +137,7 @@ class CudaGraphManager:
|
||||
self,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -146,6 +148,7 @@ class CudaGraphManager:
|
||||
self.capture_graph,
|
||||
model=model,
|
||||
input_buffers=input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
block_tables=block_tables,
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
kv_cache_config=kv_cache_config,
|
||||
|
||||
@@ -31,19 +31,6 @@ class InputBuffers:
|
||||
)
|
||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||
|
||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||
# position on purpose to make it non-contiguous so that it can work
|
||||
# with torch compile.
|
||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
||||
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
||||
# the modality of inputs. For text-only inputs, each dimension has
|
||||
# identical position IDs, making M-RoPE functionally equivalent to
|
||||
# 1D-RoPE.
|
||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, max_num_tokens + 1), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
@@ -76,7 +63,7 @@ class InputBatch:
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
# [3, num_tokens_after_padding]
|
||||
mrope_positions: torch.Tensor
|
||||
mrope_positions: torch.Tensor | None
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
@@ -124,8 +111,6 @@ class InputBatch:
|
||||
|
||||
input_ids = input_buffers.input_ids[:num_tokens].zero_()
|
||||
positions = input_buffers.positions[:num_tokens].zero_()
|
||||
input_buffers.mrope_positions.zero_()
|
||||
mrope_positions = input_buffers.mrope_positions[:, :num_tokens]
|
||||
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
@@ -146,7 +131,7 @@ class InputBatch:
|
||||
seq_lens=seq_lens,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
mrope_positions=None,
|
||||
attn_metadata=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
|
||||
@@ -11,10 +11,12 @@ class MRopeState:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
|
||||
@@ -28,6 +30,19 @@ class MRopeState:
|
||||
)
|
||||
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||
# position on purpose to make it non-contiguous so that it can work
|
||||
# with torch compile.
|
||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
||||
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
||||
# the modality of inputs. For text-only inputs, each dimension has
|
||||
# identical position IDs, making M-RoPE functionally equivalent to
|
||||
# 1D-RoPE.
|
||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, max_num_tokens + 1), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
def init_prefill_mrope_positions(
|
||||
self,
|
||||
req_idx: int,
|
||||
@@ -58,12 +73,11 @@ class MRopeState:
|
||||
query_start_loc: torch.Tensor,
|
||||
prefill_lens: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
mrope_positions: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_prepare_mrope_positions_kernel[(num_reqs,)](
|
||||
mrope_positions,
|
||||
mrope_positions.stride(0),
|
||||
self.mrope_positions,
|
||||
self.mrope_positions.stride(0),
|
||||
self.prefill_mrope_positions.gpu,
|
||||
self.prefill_mrope_positions.gpu.stride(0),
|
||||
self.max_model_len,
|
||||
|
||||
@@ -99,6 +99,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if self.uses_mrope:
|
||||
self.mrope_states = MRopeState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -284,15 +285,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
if self.uses_mrope:
|
||||
input_batch.mrope_positions = self.mrope_states.mrope_positions[
|
||||
:, :num_tokens
|
||||
]
|
||||
if not skip_attn:
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
|
||||
dp_size = self.parallel_config.data_parallel_size
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
|
||||
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
|
||||
if not self.uses_mrope:
|
||||
positions = input_batch.positions
|
||||
else:
|
||||
positions = input_batch.positions
|
||||
if self.uses_mrope:
|
||||
positions = input_batch.mrope_positions
|
||||
with (
|
||||
self.maybe_dummy_run_with_lora(
|
||||
@@ -371,9 +375,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||
mrope_positions = None
|
||||
if self.uses_mrope:
|
||||
mrope_positions = self.mrope_states.mrope_positions
|
||||
self.cudagraph_manager.capture(
|
||||
model=self.model,
|
||||
input_buffers=self.input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
block_tables=self.block_tables,
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
@@ -566,7 +574,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
self.input_buffers.mrope_positions,
|
||||
)
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens
|
||||
@@ -604,9 +611,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
||||
mrope_positions = self.input_buffers.mrope_positions[
|
||||
:, :num_tokens_after_padding
|
||||
]
|
||||
mrope_positions = None
|
||||
if self.uses_mrope:
|
||||
mrope_positions = self.mrope_states.mrope_positions[
|
||||
:, :num_tokens_after_padding
|
||||
]
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
@@ -936,6 +945,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
if self.uses_mrope:
|
||||
input_batch.mrope_positions = self.mrope_states.mrope_positions[
|
||||
:, :num_tokens_after_padding
|
||||
]
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
|
||||
# Run model.
|
||||
@@ -949,9 +962,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
# Run PyTorch model in eager mode.
|
||||
# TODO(woosuk): Support piecewise CUDA graph.
|
||||
if not self.uses_mrope:
|
||||
positions = input_batch.positions
|
||||
else:
|
||||
positions = input_batch.positions
|
||||
if self.uses_mrope:
|
||||
assert input_batch.mrope_positions is not None
|
||||
positions = input_batch.mrope_positions
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
|
||||
Reference in New Issue
Block a user