[ModelRunner V2] Misc code simplification and cleanup (#33266)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from contextlib import contextmanager
 
 import numpy as np
 import torch
@@ -28,33 +27,27 @@ class AsyncOutput(AsyncModelRunnerOutput):
         self.model_runner_output = model_runner_output
         self.sampler_output = sampler_output
         self.num_sampled_tokens = num_sampled_tokens
         self.copy_stream = copy_stream
         self.copy_event = copy_event
 
         default_stream = torch.cuda.current_stream()
-        with torch.cuda.stream(self.copy_stream):
-            self.copy_stream.wait_stream(default_stream)
+        with torch.cuda.stream(copy_stream):
+            copy_stream.wait_stream(default_stream)
 
             self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
+            self.logprobs_tensors: LogprobsTensors | None = None
             if sampler_output.logprobs_tensors is not None:
-                self.logprobs_tensors: LogprobsTensors | None = (
+                self.logprobs_tensors = (
                     sampler_output.logprobs_tensors.to_cpu_nonblocking()
                 )
-            else:
-                self.logprobs_tensors = None
+            self.num_nans: np.ndarray | None = None
             if sampler_output.num_nans is not None:
                 self.num_nans = async_copy_to_np(sampler_output.num_nans)
-            else:
-                self.num_nans = None
             self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens)
-            self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
-            if self.model_runner_output.prompt_logprobs_dict:
-                for k, v in self.model_runner_output.prompt_logprobs_dict.items():
-                    if v is not None:
-                        self.prompt_logprobs_dict[k] = v.to_cpu_nonblocking()
-                    else:
-                        self.prompt_logprobs_dict[k] = None
-        self.copy_event.record(self.copy_stream)
+            self.prompt_logprobs_dict = {
+                k: v.to_cpu_nonblocking() if v is not None else None
+                for k, v in self.model_runner_output.prompt_logprobs_dict.items()
+            }
+        self.copy_event.record(copy_stream)
 
     def get_output(self) -> ModelRunnerOutput:
         self.copy_event.synchronize()
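The constructor above funnels all device-to-host copies through a dedicated copy stream and records an event that get_output() later waits on. A minimal standalone sketch of that pattern (illustrative helper name, not the vLLM API; requires a CUDA device, and the copy only truly overlaps compute when the host destination is pinned):

import torch

def start_async_d2h_copy(
    src: torch.Tensor, copy_stream: torch.cuda.Stream
) -> tuple[torch.Tensor, torch.cuda.Event]:
    # Make the copy stream wait for work already queued on the current
    # (default) stream so it observes the finished tensor contents.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        # non_blocking=True returns immediately; the host data is only
        # valid once the recorded event has completed.
        dst = src.to("cpu", non_blocking=True)
    done = torch.cuda.Event()
    done.record(copy_stream)
    return dst, done

The consumer calls done.synchronize() before reading dst, just as get_output() synchronizes on copy_event.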
@@ -64,18 +57,15 @@ class AsyncOutput(AsyncModelRunnerOutput):
         # Going forward, we should keep the data structures as NumPy arrays
         # rather than Python lists.
         sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
-        num_reqs = len(sampled_token_ids)
-        num_sampled_tokens = self.num_sampled_tokens_np.tolist()
-        for i in range(num_reqs):
-            del sampled_token_ids[i][num_sampled_tokens[i] :]
+        num_sampled_tokens: list[int] = self.num_sampled_tokens_np.tolist()
+        for token_ids, num_tokens in zip(sampled_token_ids, num_sampled_tokens):
+            del token_ids[num_tokens:]
         self.model_runner_output.sampled_token_ids = sampled_token_ids
 
         if self.num_nans is not None:
-            num_nans = self.num_nans.tolist()
-            self.model_runner_output.num_nans_in_logits = {
-                req_id: num_nans[i]
-                for i, req_id in enumerate(self.model_runner_output.req_ids)
-            }
+            self.model_runner_output.num_nans_in_logits = dict(
+                zip(self.model_runner_output.req_ids, self.num_nans.tolist())
+            )
 
         if self.logprobs_tensors is not None:
             self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
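The rewritten body above relies on two small idioms: trimming each request's sampled-token list to its actual length with zip, and pairing request ids with per-request counts via dict(zip(...)). A toy sketch with made-up data (no vLLM objects involved):

import numpy as np

sampled_token_ids = [[11, 12, 13], [21, 22, 23], [31, 32, 33]]
num_sampled_tokens: list[int] = np.array([1, 3, 2], dtype=np.int32).tolist()

# Trim each row in place to the number of tokens actually sampled.
for token_ids, num_tokens in zip(sampled_token_ids, num_sampled_tokens):
    del token_ids[num_tokens:]
assert sampled_token_ids == [[11], [21, 22, 23], [31, 32]]

# Map request ids to per-request values without an explicit index loop.
req_ids = ["req-0", "req-1", "req-2"]
num_nans_in_logits = dict(zip(req_ids, [0, 2, 0]))
assert num_nans_in_logits == {"req-0": 0, "req-1": 2, "req-2": 0}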
@@ -83,16 +73,5 @@ class AsyncOutput(AsyncModelRunnerOutput):
         return self.model_runner_output
 
 
-@contextmanager
-def async_barrier(event: torch.cuda.Event | None):
-    if event is not None:
-        event.synchronize()
-    try:
-        yield
-    finally:
-        if event is not None:
-            event.record()
-
-
 def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
     return x.to("cpu", non_blocking=True).numpy()
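async_copy_to_np, kept above, only enqueues the device-to-host transfer; the returned NumPy array aliases CPU memory whose contents are defined once the copy has finished, so callers have to synchronize (via an event or the stream) before reading it. A hedged usage sketch (requires CUDA; variable names are illustrative):

import numpy as np
import torch

def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
    # Same helper as in the diff: non-blocking D2H copy wrapped as NumPy.
    return x.to("cpu", non_blocking=True).numpy()

if torch.cuda.is_available():
    gpu_values = torch.arange(8, device="cuda", dtype=torch.int32)
    host_values = async_copy_to_np(gpu_values)
    # Not safe to read host_values until the enqueued copy has completed.
    torch.cuda.current_stream().synchronize()
    print(host_values.tolist())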
@@ -123,12 +123,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_async_scheduling = self.scheduler_config.async_scheduling
         self.output_copy_stream = torch.cuda.Stream(self.device)
         self.output_copy_event = torch.cuda.Event()
-        if self.use_async_scheduling:
-            self.input_prep_event = torch.cuda.Event()
-            self.structured_outputs_event = torch.cuda.Event()
-        else:
-            self.input_prep_event = None
-            self.structured_outputs_event = None
 
         if self.speculative_config is not None:
             self.do_spec_decode = True
@@ -179,7 +173,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len
 
-    def get_supported_tasks(self) -> tuple[str]:
+    @staticmethod
+    def get_supported_tasks() -> tuple[str]:
         return ("generate",)
 
     def load_model(self, *args, **kwargs) -> None:
@@ -194,9 +189,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         if self.lora_config:
             self.model = self.load_lora_model(
-                self.model,
-                self.vllm_config,
-                self.device,
+                self.model, self.vllm_config, self.device
             )
         if self.do_spec_decode:
             self.speculator.load_model(self.model)
@@ -238,9 +231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
 
         self.attn_backends, self.attn_metadata_builders = init_attn_backend(
-            self.kv_cache_config,
-            self.vllm_config,
-            self.device,
+            self.kv_cache_config, self.vllm_config, self.device
         )
         if self.do_spec_decode:
             # HACK(woosuk)
@@ -288,11 +279,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     @torch.inference_mode()
     def _dummy_run(
-        self,
-        num_tokens: int,
-        *args,
-        skip_attn: bool = True,
-        **kwargs,
+        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Create a dummy scheduler output.
         num_reqs = min(num_tokens, self.max_num_reqs)
@@ -320,10 +307,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return hidden_states, sample_hidden_states
 
     @torch.inference_mode()
-    def _dummy_sampler_run(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> None:
+    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
         num_reqs = hidden_states.shape[0]
         logits = self.model.compute_logits(hidden_states)
         idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
@@ -337,8 +321,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     @torch.inference_mode()
     def profile_run(self) -> None:
         hidden_states, sample_hidden_states = self._dummy_run(
-            self.max_num_tokens,
-            skip_attn=True,
+            self.max_num_tokens, skip_attn=True
         )
         self._dummy_sampler_run(sample_hidden_states)
         if self.do_spec_decode:
@@ -482,11 +465,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     def update_requests(self, scheduler_output: SchedulerOutput) -> None:
         # Add new blocks for the existing requests.
-        cached_reqs = scheduler_output.scheduled_cached_reqs
-        for i, req_id in enumerate(cached_reqs.req_ids):
-            req_index = self.req_states.req_id_to_index[req_id]
-            req_new_block_ids = cached_reqs.new_block_ids[i]
+        reqs = scheduler_output.scheduled_cached_reqs
+        for req_new_block_ids, req_id in zip(reqs.new_block_ids, reqs.req_ids):
             if req_new_block_ids is not None:
+                req_index = self.req_states.req_id_to_index[req_id]
                 self.block_tables.append_block_ids(
                     req_index, req_new_block_ids, overwrite=False
                 )
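The new loop walks the scheduler's parallel per-request sequences with zip and only resolves an index for requests that actually received new blocks. The same control flow on toy data (plain dicts and lists stand in for the runner's block-table state):

req_ids = ["a", "b", "c"]
new_block_ids = [[7, 8], None, [9]]  # None: nothing new for "b"
req_id_to_index = {"a": 0, "b": 1, "c": 2}
block_tables: dict[int, list[int]] = {0: [1], 1: [2], 2: [3]}

for req_new_block_ids, req_id in zip(new_block_ids, req_ids):
    if req_new_block_ids is not None:
        req_index = req_id_to_index[req_id]
        block_tables[req_index].extend(req_new_block_ids)

assert block_tables == {0: [1, 7, 8], 1: [2], 2: [3, 9]}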
@@ -517,7 +499,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
 
         # Get the number of draft tokens for each request.
-        if not scheduler_output.scheduled_spec_decode_tokens:
+        draft_tokens = scheduler_output.scheduled_spec_decode_tokens
+        if not draft_tokens:
             # No draft token scheduled (common case).
             total_num_draft_tokens = 0
             total_num_logits = num_reqs
@@ -527,12 +510,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             )
             expanded_idx_mapping = idx_mapping
         else:
-            draft_tokens = scheduler_output.scheduled_spec_decode_tokens
             num_draft_tokens = np.array(
-                [
-                    len(draft_tokens[req_id]) if req_id in draft_tokens else 0
-                    for req_id in req_ids
-                ],
+                [len(draft_tokens.get(req_id, ())) for req_id in req_ids],
                 dtype=np.int32,
             )
             total_num_draft_tokens = int(num_draft_tokens.sum())
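Counting draft tokens per request collapses into one comprehension once dict.get with a default is used: requests with no scheduled drafts contribute zero. Illustrated with toy ids and counts (not real scheduler output):

import numpy as np

req_ids = ["a", "b", "c"]
draft_tokens = {"a": [101, 102], "c": [103]}  # "b" has no drafts scheduled

num_draft_tokens = np.array(
    [len(draft_tokens.get(req_id, ())) for req_id in req_ids],
    dtype=np.int32,
)
assert num_draft_tokens.tolist() == [2, 0, 1]
assert int(num_draft_tokens.sum()) == 3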
@@ -544,11 +523,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             np.cumsum(num_logits, out=cu_num_logits_np[1:])
             cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
 
+            max_expand_len = self.num_speculative_steps + 1
             expanded_idx_mapping = expand_idx_mapping(
-                idx_mapping,
-                total_num_logits,
-                cu_num_logits,
-                max_expand_len=self.num_speculative_steps + 1,
+                idx_mapping, total_num_logits, cu_num_logits, max_expand_len
             )
 
         # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
@@ -640,9 +617,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         positions = self.input_buffers.positions[:num_tokens_after_padding]
         mrope_positions = None
         if self.uses_mrope:
-            mrope_positions = self.mrope_states.mrope_positions[
-                :, :num_tokens_after_padding
-            ]
+            mrope_positions = self.mrope_states.mrope_positions
+            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
         return InputBatch(
             req_ids=req_ids,
             num_reqs=num_reqs,
@@ -762,10 +738,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Update the number of computed prefill tokens.
         idx_mapping_np = input_batch.idx_mapping_np
         computed_prefill = self.req_states.num_computed_prefill_tokens
-        # TODO(woosuk): Simplify this.
-        computed_prefill[idx_mapping_np] = np.minimum(
-            computed_prefill[idx_mapping_np] + input_batch.num_scheduled_tokens,
-            self.req_states.prefill_len.np[idx_mapping_np],
+        computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
+        np.minimum(
+            computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
         )
 
     @torch.inference_mode()
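The replacement in this hunk first accumulates the scheduled tokens for the active rows, then clamps the whole array against the prefill lengths in place with np.minimum(..., out=...); rows that were not scheduled are already at or below their cap, so clamping them again is a no-op. A small self-contained sketch with made-up numbers:

import numpy as np

computed_prefill = np.array([4, 10, 0, 7], dtype=np.int64)
prefill_len = np.array([10, 10, 8, 7], dtype=np.int64)

# Rows scheduled this step and how many tokens each of them processed.
idx_mapping_np = np.array([0, 2], dtype=np.int64)
num_scheduled_tokens = np.array([8, 3], dtype=np.int64)

computed_prefill[idx_mapping_np] += num_scheduled_tokens
np.minimum(computed_prefill, prefill_len, out=computed_prefill)

# Row 0 is clamped to its prefill length; untouched rows keep their values.
assert computed_prefill.tolist() == [10, 10, 3, 7]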
@@ -834,8 +809,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Common case.
             # Prepare all the inputs and copy to the input buffers.
             input_batch = self.prepare_inputs(
-                scheduler_output,
-                num_tokens_after_padding,
+                scheduler_output, num_tokens_after_padding
             )
             if self.lora_config:
                 # Activate LoRA adapters.
@@ -107,6 +107,9 @@ class Worker(WorkerBase):
 
         self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
 
+        if self.use_v2_model_runner:
+            logger.info_once("Using V2 Model Runner", scope="global")
+
     def sleep(self, level: int = 1) -> None:
         from vllm.device_allocator.cumem import CuMemAllocator
 