[Model Runner V2] Move mrope_positions buffer to MRopeState (#32532)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -75,16 +75,17 @@ class CudaGraphManager:
|
|||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
input_buffers: InputBuffers,
|
input_buffers: InputBuffers,
|
||||||
|
mrope_positions: torch.Tensor | None,
|
||||||
block_tables: BlockTables,
|
block_tables: BlockTables,
|
||||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
) -> None:
|
) -> None:
|
||||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||||
input_ids = input_buffers.input_ids[:num_tokens]
|
input_ids = input_buffers.input_ids[:num_tokens]
|
||||||
if not self.uses_mrope:
|
|
||||||
positions = input_buffers.positions[:num_tokens]
|
positions = input_buffers.positions[:num_tokens]
|
||||||
else:
|
if self.uses_mrope:
|
||||||
positions = input_buffers.mrope_positions[:, :num_tokens]
|
assert mrope_positions is not None
|
||||||
|
positions = mrope_positions[:, :num_tokens]
|
||||||
attn_metadata = prepare_inputs_to_capture(
|
attn_metadata = prepare_inputs_to_capture(
|
||||||
num_reqs,
|
num_reqs,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -136,6 +137,7 @@ class CudaGraphManager:
|
|||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
input_buffers: InputBuffers,
|
input_buffers: InputBuffers,
|
||||||
|
mrope_positions: torch.Tensor | None,
|
||||||
block_tables: BlockTables,
|
block_tables: BlockTables,
|
||||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
@@ -146,6 +148,7 @@ class CudaGraphManager:
|
|||||||
self.capture_graph,
|
self.capture_graph,
|
||||||
model=model,
|
model=model,
|
||||||
input_buffers=input_buffers,
|
input_buffers=input_buffers,
|
||||||
|
mrope_positions=mrope_positions,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
attn_metadata_builders=attn_metadata_builders,
|
attn_metadata_builders=attn_metadata_builders,
|
||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
|
|||||||
@@ -31,19 +31,6 @@ class InputBuffers:
|
|||||||
)
|
)
|
||||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
|
||||||
# position on purpose to make it non-contiguous so that it can work
|
|
||||||
# with torch compile.
|
|
||||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
|
||||||
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
|
||||||
# the modality of inputs. For text-only inputs, each dimension has
|
|
||||||
# identical position IDs, making M-RoPE functionally equivalent to
|
|
||||||
# 1D-RoPE.
|
|
||||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
|
||||||
self.mrope_positions = torch.zeros(
|
|
||||||
(3, max_num_tokens + 1), dtype=torch.int64, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InputBatch:
|
class InputBatch:
|
||||||
@@ -76,7 +63,7 @@ class InputBatch:
|
|||||||
# [num_tokens_after_padding]
|
# [num_tokens_after_padding]
|
||||||
positions: torch.Tensor
|
positions: torch.Tensor
|
||||||
# [3, num_tokens_after_padding]
|
# [3, num_tokens_after_padding]
|
||||||
mrope_positions: torch.Tensor
|
mrope_positions: torch.Tensor | None
|
||||||
|
|
||||||
# layer_name -> Metadata
|
# layer_name -> Metadata
|
||||||
attn_metadata: dict[str, Any]
|
attn_metadata: dict[str, Any]
|
||||||
@@ -124,8 +111,6 @@ class InputBatch:
|
|||||||
|
|
||||||
input_ids = input_buffers.input_ids[:num_tokens].zero_()
|
input_ids = input_buffers.input_ids[:num_tokens].zero_()
|
||||||
positions = input_buffers.positions[:num_tokens].zero_()
|
positions = input_buffers.positions[:num_tokens].zero_()
|
||||||
input_buffers.mrope_positions.zero_()
|
|
||||||
mrope_positions = input_buffers.mrope_positions[:, :num_tokens]
|
|
||||||
|
|
||||||
# attn_metadata = defaultdict(lambda: None)
|
# attn_metadata = defaultdict(lambda: None)
|
||||||
logits_indices = query_start_loc[1:] - 1
|
logits_indices = query_start_loc[1:] - 1
|
||||||
@@ -146,7 +131,7 @@ class InputBatch:
|
|||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
mrope_positions=mrope_positions,
|
mrope_positions=None,
|
||||||
attn_metadata=None, # type: ignore
|
attn_metadata=None, # type: ignore
|
||||||
logits_indices=logits_indices,
|
logits_indices=logits_indices,
|
||||||
cu_num_logits=cu_num_logits,
|
cu_num_logits=cu_num_logits,
|
||||||
|
|||||||
@@ -11,10 +11,12 @@ class MRopeState:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_num_reqs: int,
|
max_num_reqs: int,
|
||||||
|
max_num_tokens: int,
|
||||||
max_model_len: int,
|
max_model_len: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
self.max_num_reqs = max_num_reqs
|
self.max_num_reqs = max_num_reqs
|
||||||
|
self.max_num_tokens = max_num_tokens
|
||||||
self.max_model_len = max_model_len
|
self.max_model_len = max_model_len
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
@@ -28,6 +30,19 @@ class MRopeState:
|
|||||||
)
|
)
|
||||||
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
||||||
|
|
||||||
|
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||||
|
# position on purpose to make it non-contiguous so that it can work
|
||||||
|
# with torch compile.
|
||||||
|
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
||||||
|
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
||||||
|
# the modality of inputs. For text-only inputs, each dimension has
|
||||||
|
# identical position IDs, making M-RoPE functionally equivalent to
|
||||||
|
# 1D-RoPE.
|
||||||
|
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||||
|
self.mrope_positions = torch.zeros(
|
||||||
|
(3, max_num_tokens + 1), dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
def init_prefill_mrope_positions(
|
def init_prefill_mrope_positions(
|
||||||
self,
|
self,
|
||||||
req_idx: int,
|
req_idx: int,
|
||||||
@@ -58,12 +73,11 @@ class MRopeState:
|
|||||||
query_start_loc: torch.Tensor,
|
query_start_loc: torch.Tensor,
|
||||||
prefill_lens: torch.Tensor,
|
prefill_lens: torch.Tensor,
|
||||||
num_computed_tokens: torch.Tensor,
|
num_computed_tokens: torch.Tensor,
|
||||||
mrope_positions: torch.Tensor,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
num_reqs = idx_mapping.shape[0]
|
num_reqs = idx_mapping.shape[0]
|
||||||
_prepare_mrope_positions_kernel[(num_reqs,)](
|
_prepare_mrope_positions_kernel[(num_reqs,)](
|
||||||
mrope_positions,
|
self.mrope_positions,
|
||||||
mrope_positions.stride(0),
|
self.mrope_positions.stride(0),
|
||||||
self.prefill_mrope_positions.gpu,
|
self.prefill_mrope_positions.gpu,
|
||||||
self.prefill_mrope_positions.gpu.stride(0),
|
self.prefill_mrope_positions.gpu.stride(0),
|
||||||
self.max_model_len,
|
self.max_model_len,
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
self.mrope_states = MRopeState(
|
self.mrope_states = MRopeState(
|
||||||
max_num_reqs=self.max_num_reqs,
|
max_num_reqs=self.max_num_reqs,
|
||||||
|
max_num_tokens=self.max_num_tokens,
|
||||||
max_model_len=self.max_model_len,
|
max_model_len=self.max_model_len,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
@@ -284,15 +285,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
input_buffers=self.input_buffers,
|
input_buffers=self.input_buffers,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
if self.uses_mrope:
|
||||||
|
input_batch.mrope_positions = self.mrope_states.mrope_positions[
|
||||||
|
:, :num_tokens
|
||||||
|
]
|
||||||
if not skip_attn:
|
if not skip_attn:
|
||||||
self.prepare_dummy_attn_metadata(input_batch)
|
self.prepare_dummy_attn_metadata(input_batch)
|
||||||
|
|
||||||
dp_size = self.parallel_config.data_parallel_size
|
dp_size = self.parallel_config.data_parallel_size
|
||||||
num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
|
num_tokens_across_dp = make_num_tokens_across_dp(dp_size, num_tokens)
|
||||||
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
|
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
|
||||||
if not self.uses_mrope:
|
|
||||||
positions = input_batch.positions
|
positions = input_batch.positions
|
||||||
else:
|
if self.uses_mrope:
|
||||||
positions = input_batch.mrope_positions
|
positions = input_batch.mrope_positions
|
||||||
with (
|
with (
|
||||||
self.maybe_dummy_run_with_lora(
|
self.maybe_dummy_run_with_lora(
|
||||||
@@ -371,9 +375,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||||
|
|
||||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||||
|
mrope_positions = None
|
||||||
|
if self.uses_mrope:
|
||||||
|
mrope_positions = self.mrope_states.mrope_positions
|
||||||
self.cudagraph_manager.capture(
|
self.cudagraph_manager.capture(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
input_buffers=self.input_buffers,
|
input_buffers=self.input_buffers,
|
||||||
|
mrope_positions=mrope_positions,
|
||||||
block_tables=self.block_tables,
|
block_tables=self.block_tables,
|
||||||
attn_metadata_builders=self.attn_metadata_builders,
|
attn_metadata_builders=self.attn_metadata_builders,
|
||||||
kv_cache_config=self.kv_cache_config,
|
kv_cache_config=self.kv_cache_config,
|
||||||
@@ -566,7 +574,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
query_start_loc,
|
query_start_loc,
|
||||||
self.req_states.prefill_len.gpu,
|
self.req_states.prefill_len.gpu,
|
||||||
self.req_states.num_computed_tokens.gpu,
|
self.req_states.num_computed_tokens.gpu,
|
||||||
self.input_buffers.mrope_positions,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Some input token ids are directly read from the last sampled tokens
|
# Some input token ids are directly read from the last sampled tokens
|
||||||
@@ -604,7 +611,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||||
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
||||||
mrope_positions = self.input_buffers.mrope_positions[
|
mrope_positions = None
|
||||||
|
if self.uses_mrope:
|
||||||
|
mrope_positions = self.mrope_states.mrope_positions[
|
||||||
:, :num_tokens_after_padding
|
:, :num_tokens_after_padding
|
||||||
]
|
]
|
||||||
return InputBatch(
|
return InputBatch(
|
||||||
@@ -936,6 +945,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
input_buffers=self.input_buffers,
|
input_buffers=self.input_buffers,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
if self.uses_mrope:
|
||||||
|
input_batch.mrope_positions = self.mrope_states.mrope_positions[
|
||||||
|
:, :num_tokens_after_padding
|
||||||
|
]
|
||||||
self.prepare_dummy_attn_metadata(input_batch)
|
self.prepare_dummy_attn_metadata(input_batch)
|
||||||
|
|
||||||
# Run model.
|
# Run model.
|
||||||
@@ -949,9 +962,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
# Run PyTorch model in eager mode.
|
# Run PyTorch model in eager mode.
|
||||||
# TODO(woosuk): Support piecewise CUDA graph.
|
# TODO(woosuk): Support piecewise CUDA graph.
|
||||||
if not self.uses_mrope:
|
|
||||||
positions = input_batch.positions
|
positions = input_batch.positions
|
||||||
else:
|
if self.uses_mrope:
|
||||||
|
assert input_batch.mrope_positions is not None
|
||||||
positions = input_batch.mrope_positions
|
positions = input_batch.mrope_positions
|
||||||
with set_forward_context(
|
with set_forward_context(
|
||||||
input_batch.attn_metadata,
|
input_batch.attn_metadata,
|
||||||
|
|||||||
Reference in New Issue
Block a user