[Model Runner V2] Do not initialize sampler for non-last PP ranks (#36824)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Author: Woosuk Kwon
Committed by: GitHub
Date: 2026-03-11 20:55:28 -07:00
Parent: 2ef69456f5
Commit: 2f8b4ce0c0
3 changed files with 75 additions and 50 deletions
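The change follows a standard pipeline-parallel pattern: sampling state is materialized only on the last PP rank, and every consumer either checks for None or asserts before use. A minimal sketch of the idea, with simplified class and attribute names (the actual diff below does this inside GPUModelRunner with is_last_pp_rank, Sampler, RejectionSampler, etc.):

    import torch

    class Sampler:
        def __init__(self, vocab_size: int) -> None:
            self.vocab_size = vocab_size

        def __call__(self, logits: torch.Tensor) -> torch.Tensor:
            return logits.argmax(dim=-1)

    class ModelRunner:
        def __init__(self, vocab_size: int, is_last_pp_rank: bool) -> None:
            self.is_last_pp_rank = is_last_pp_rank
            # Non-last PP ranks never compute logits, so they skip the sampler
            # (and its per-request penalty/logprob state) entirely.
            self.sampler: Sampler | None = None
            if is_last_pp_rank:
                self.sampler = Sampler(vocab_size)

        def sample(self, logits: torch.Tensor) -> torch.Tensor:
            # Use sites assert rather than silently tolerating None, so a call
            # that should never reach a non-last rank fails loudly.
            assert self.sampler is not None
            return self.sampler(logits)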


@@ -438,17 +438,20 @@ def _post_update_kernel(
for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
tl.store(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
token_id,
)
if output_bin_counts_ptr is not None:
token_ptr = (
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ token_id
)
count = tl.load(token_ptr)
tl.store(token_ptr, count + 1)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
@@ -467,7 +470,7 @@ def post_update(
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
output_bin_counts: torch.Tensor | None,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
@@ -487,7 +490,7 @@ def post_update(
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
output_bin_counts.stride(0) if output_bin_counts is not None else 0,
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,

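The kernel change above leans on Triton's handling of None arguments: when the wrapper passes output_bin_counts_ptr=None, the `if output_bin_counts_ptr is not None` branch is resolved when the kernel is specialized and the histogram update drops out; the stride only needs the 0 placeholder because it stays an ordinary integer parameter. A hedged sketch of the same pattern with an illustrative kernel (not the one from the diff):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _count_tokens_kernel(
        token_ids_ptr,    # [num_tokens] sampled token ids
        bin_counts_ptr,   # [vocab_size] histogram, or None to skip counting
        num_tokens,
        BLOCK: tl.constexpr,
    ):
        offs = tl.arange(0, BLOCK)
        mask = offs < num_tokens
        token_ids = tl.load(token_ids_ptr + offs, mask=mask, other=0)
        # With bin_counts_ptr=None, the specialized kernel compiles this
        # branch out entirely.
        if bin_counts_ptr is not None:
            tl.atomic_add(bin_counts_ptr + token_ids, 1, mask=mask)

    def count_tokens(token_ids: torch.Tensor, bin_counts: torch.Tensor | None) -> None:
        BLOCK = triton.next_power_of_2(token_ids.numel())
        _count_tokens_kernel[(1,)](token_ids, bin_counts, token_ids.numel(), BLOCK=BLOCK)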

@@ -183,6 +183,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# General request states.
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
@@ -199,20 +203,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_tokens=self.max_num_tokens,
device=self.device,
)
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.sampler: Sampler | None = None
self.rejection_sampler: RejectionSampler | None = None
self.prompt_logprobs_worker: PromptLogprobsWorker | None = None
self.structured_outputs_worker: StructuredOutputsWorker | None = None
if self.is_last_pp_rank and not self.is_pooling_model:
# Initialize sampling-related workers.
# These components are only set up on the last PP rank and
# for generative (non-pooling) models.
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
# CUDA graphs.
self.decode_query_len = self.num_speculative_steps + 1
@@ -222,21 +240,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
# LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: ExecuteModelState | None = None
@@ -248,8 +256,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tasks: list[SupportedTask] = []
if self.model_config.runner_type == "generate":
tasks.extend(self.model_state.get_supported_generation_tasks())
if self.pooling_runner is not None:
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
if self.is_pooling_model:
# Do not rely on pooling_runner here, since this information is needed
# on the first PP rank, while pooling_runner is only initialized
# on the last PP rank.
tasks.extend(PoolingRunner.get_supported_tasks(self.model))
return tuple(tasks)
def load_model(self, *args, **kwargs) -> None:
@@ -289,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model_state = init_model_state(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model:
if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model)
def get_model(self) -> nn.Module:
@@ -420,6 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None:
assert self.sampler is not None
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
@@ -457,10 +469,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(
logits,
dummy_input_batch,
)
assert self.sampler is not None
self.sampler(logits, dummy_input_batch)
@torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
@@ -558,7 +568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
@@ -589,18 +600,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
if new_req_data.sampling_params is not None:
if self.is_last_pp_rank and new_req_data.sampling_params is not None:
assert self.sampler is not None
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
assert self.prompt_logprobs_worker is not None
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes()
self.model_state.apply_staged_writes()
if self.sampler is not None:
self.sampler.apply_staged_writes()
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests.
@@ -788,6 +802,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
assert self.structured_outputs_worker is not None
self.structured_outputs_worker.apply_grammar_bitmask(
logits,
input_batch,
@@ -797,12 +812,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
sampler_output = self.sampler(
logits,
input_batch,
)
assert self.sampler is not None
sampler_output = self.sampler(logits, input_batch)
else:
# Rejection sampling for spec decoding.
assert self.rejection_sampler is not None
sampler_output = self.rejection_sampler(
logits,
input_batch,
@@ -831,11 +845,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_rejected: torch.Tensor,
) -> None:
# Update the number of computed tokens.
if self.is_last_pp_rank:
assert self.sampler is not None
output_bin_counts = self.sampler.penalties_state.output_bin_counts
else:
output_bin_counts = None
post_update(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens,
self.sampler.penalties_state.output_bin_counts,
output_bin_counts,
sampled_tokens,
num_sampled,
num_rejected,
@@ -1076,6 +1095,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
assert self.prompt_logprobs_worker is not None
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
self.model.compute_logits,
hidden_states,
@@ -1115,6 +1135,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
if self.speculator is not None:
assert self.sampler is not None
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,

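Tying the runner back to the kernel in the first file: non-last PP ranks have no penalties state, so they hand post_update an output_bin_counts of None and the bin-count update is skipped on device. A small sketch of that call-site decision, using the attribute names from the diff but wrapped in a hypothetical helper method on the runner:

    def _output_bin_counts(self) -> torch.Tensor | None:
        # Only the last PP rank owns a sampler, and therefore the penalties
        # state backing the output-token histogram.
        if not self.is_last_pp_rank:
            return None
        assert self.sampler is not None
        return self.sampler.penalties_state.output_bin_counts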

@@ -19,10 +19,11 @@ class PoolingRunner:
def __init__(self, model: nn.Module):
self.model = cast(VllmModelForPooling, model)
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
if not is_pooling_model(self.model):
@staticmethod
def get_supported_tasks(model: nn.Module) -> list[PoolingTask]:
if not is_pooling_model(model):
return []
assert "embed" in self.model.pooler.get_supported_tasks()
assert "embed" in model.pooler.get_supported_tasks()
return ["embed"]
def pool(
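With get_supported_tasks now a staticmethod, the task list can be derived from the model class on any PP rank without constructing a PoolingRunner, which is what the runner's get_supported_tasks change above relies on. A minimal usage sketch, assuming model is an already-loaded pooling model and is_last_pp_rank is a local flag:

    # No PoolingRunner instance is needed, so this works on the first PP rank too.
    tasks = PoolingRunner.get_supported_tasks(model)   # e.g. ["embed"]
    # The runner itself is still only built where pooling actually executes.
    pooling_runner = PoolingRunner(model) if is_last_pp_rank else None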