[Model Runner V2] Minor code cleanup (#29570)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -35,7 +35,10 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
|
||||
from vllm.v1.worker.gpu.dp_utils import (
|
||||
get_batch_metadata_across_dp,
|
||||
make_num_tokens_across_dp,
|
||||
)
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
@@ -255,12 +258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if not skip_attn:
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
|
||||
if self.dp_size == 1:
|
||||
num_tokens_across_dp: torch.Tensor | None = None
|
||||
else:
|
||||
num_tokens_across_dp = torch.full(
|
||||
(self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
|
||||
with (
|
||||
self.maybe_dummy_run_with_lora(
|
||||
@@ -816,7 +814,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.req_states.last_sampled_tokens,
|
||||
next_prefill_tokens,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
return draft_tokens
|
||||
|
||||
def get_cudagraph_and_dp_padding(
|
||||
@@ -1006,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
|
||||
)
|
||||
if self.do_spec_decode:
|
||||
_ = self.propose_draft(
|
||||
draft_tokens = self.propose_draft(
|
||||
input_batch,
|
||||
sampling_metadata,
|
||||
hidden_states,
|
||||
@@ -1014,6 +1011,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
|
||||
if self.use_async_scheduling:
|
||||
return async_output
|
||||
|
||||
Reference in New Issue
Block a user