[Model Runner V2] Do not initialize sampler for non-last PP ranks (#36824)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -438,17 +438,20 @@ def _post_update_kernel(
|
||||
|
||||
for i in range(num_sampled):
|
||||
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
|
||||
token_ptr = (
|
||||
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
|
||||
)
|
||||
count = tl.load(token_ptr)
|
||||
count += 1
|
||||
tl.store(token_ptr, count)
|
||||
tl.store(
|
||||
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
|
||||
token_id,
|
||||
)
|
||||
|
||||
if output_bin_counts_ptr is not None:
|
||||
token_ptr = (
|
||||
output_bin_counts_ptr
|
||||
+ req_state_idx * output_bin_counts_stride
|
||||
+ token_id
|
||||
)
|
||||
count = tl.load(token_ptr)
|
||||
tl.store(token_ptr, count + 1)
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + req_id)
|
||||
query_end = tl.load(query_start_loc_ptr + req_id + 1)
|
||||
query_len = query_end - query_start
|
||||
@@ -467,7 +470,7 @@ def post_update(
|
||||
# [max_num_reqs]
|
||||
last_sampled_tokens: torch.Tensor,
|
||||
# [max_num_reqs, vocab_size]
|
||||
output_bin_counts: torch.Tensor,
|
||||
output_bin_counts: torch.Tensor | None,
|
||||
# [num_reqs, num_speculative_steps + 1]
|
||||
sampled_tokens: torch.Tensor,
|
||||
# [num_reqs]
|
||||
@@ -487,7 +490,7 @@ def post_update(
|
||||
num_computed_tokens,
|
||||
last_sampled_tokens,
|
||||
output_bin_counts,
|
||||
output_bin_counts.stride(0),
|
||||
output_bin_counts.stride(0) if output_bin_counts is not None else 0,
|
||||
sampled_tokens,
|
||||
sampled_tokens.stride(0),
|
||||
num_sampled,
|
||||
|
||||
@@ -183,6 +183,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Draft tokens propagation - for spec-dec + struct outputs.
|
||||
self.draft_tokens_handler = DraftTokensHandler(self.device)
|
||||
|
||||
# Pooling models.
|
||||
self.is_pooling_model = self.model_config.runner_type == "pooling"
|
||||
self.pooling_runner: PoolingRunner | None = None
|
||||
|
||||
# General request states.
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
@@ -199,20 +203,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
device=self.device,
|
||||
)
|
||||
self.sampler = Sampler(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
req_states=self.req_states,
|
||||
logprobs_mode=self.model_config.logprobs_mode,
|
||||
num_speculative_tokens=self.num_speculative_steps + 1,
|
||||
)
|
||||
self.rejection_sampler = RejectionSampler(
|
||||
self.sampler,
|
||||
num_speculative_steps=self.num_speculative_steps,
|
||||
use_strict_rejection_sampling=use_strict_rejection_sampling,
|
||||
)
|
||||
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
|
||||
|
||||
self.sampler: Sampler | None = None
|
||||
self.rejection_sampler: RejectionSampler | None = None
|
||||
self.prompt_logprobs_worker: PromptLogprobsWorker | None = None
|
||||
self.structured_outputs_worker: StructuredOutputsWorker | None = None
|
||||
if self.is_last_pp_rank and not self.is_pooling_model:
|
||||
# Initialize sampling-related workers.
|
||||
# These components are only set up on the last PP rank and
|
||||
# for generative (non-pooling) models.
|
||||
self.sampler = Sampler(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
req_states=self.req_states,
|
||||
logprobs_mode=self.model_config.logprobs_mode,
|
||||
num_speculative_tokens=self.num_speculative_steps + 1,
|
||||
)
|
||||
self.rejection_sampler = RejectionSampler(
|
||||
self.sampler,
|
||||
num_speculative_steps=self.num_speculative_steps,
|
||||
use_strict_rejection_sampling=use_strict_rejection_sampling,
|
||||
)
|
||||
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
|
||||
self.structured_outputs_worker = StructuredOutputsWorker(
|
||||
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# CUDA graphs.
|
||||
self.decode_query_len = self.num_speculative_steps + 1
|
||||
@@ -222,21 +240,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.compilation_config.cudagraph_mode,
|
||||
decode_query_len=self.decode_query_len,
|
||||
)
|
||||
# Structured outputs worker.
|
||||
self.structured_outputs_worker = StructuredOutputsWorker(
|
||||
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
)
|
||||
# LoRA-related workers.
|
||||
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
|
||||
# KV Connector if configured.
|
||||
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
|
||||
|
||||
# Pooling models.
|
||||
self.is_pooling_model = self.model_config.runner_type == "pooling"
|
||||
self.pooling_runner: PoolingRunner | None = None
|
||||
|
||||
# For transferring state from execute_model to subsequent sample_tokens call.
|
||||
self.execute_model_state: ExecuteModelState | None = None
|
||||
|
||||
@@ -248,8 +256,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
tasks: list[SupportedTask] = []
|
||||
if self.model_config.runner_type == "generate":
|
||||
tasks.extend(self.model_state.get_supported_generation_tasks())
|
||||
if self.pooling_runner is not None:
|
||||
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
|
||||
if self.is_pooling_model:
|
||||
# Do not rely on pooling_runner here, since this information is needed
|
||||
# on the first PP rank, while pooling_runner is only initialized
|
||||
# on the last PP rank.
|
||||
tasks.extend(PoolingRunner.get_supported_tasks(self.model))
|
||||
return tuple(tasks)
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
@@ -289,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model_state = init_model_state(
|
||||
self.vllm_config, self.model, self.encoder_cache, self.device
|
||||
)
|
||||
if self.is_pooling_model:
|
||||
if self.is_pooling_model and self.is_last_pp_rank:
|
||||
self.pooling_runner = PoolingRunner(self.model)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
@@ -420,6 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# dummy run the eagle speculator's propose to ensure DP/EP sync.
|
||||
if self.speculator is not None:
|
||||
assert self.sampler is not None
|
||||
self.speculator.propose(
|
||||
input_batch=input_batch,
|
||||
attn_metadata=attn_metadata,
|
||||
@@ -457,10 +469,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
|
||||
# top_k, top_p, and logprobs, using less GPU memory than what is possible
|
||||
# during actual execution.
|
||||
self.sampler(
|
||||
logits,
|
||||
dummy_input_batch,
|
||||
)
|
||||
assert self.sampler is not None
|
||||
self.sampler(logits, dummy_input_batch)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
|
||||
@@ -558,7 +568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.req_states.remove_request(req_id)
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.remove_request(req_id)
|
||||
self.prompt_logprobs_worker.remove_request(req_id)
|
||||
if self.prompt_logprobs_worker is not None:
|
||||
self.prompt_logprobs_worker.remove_request(req_id)
|
||||
self.lora_state.remove_request(req_id)
|
||||
|
||||
def free_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
@@ -589,18 +600,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
|
||||
|
||||
if new_req_data.sampling_params is not None:
|
||||
if self.is_last_pp_rank and new_req_data.sampling_params is not None:
|
||||
assert self.sampler is not None
|
||||
self.sampler.add_request(
|
||||
req_index, prompt_len, new_req_data.sampling_params
|
||||
)
|
||||
assert self.prompt_logprobs_worker is not None
|
||||
self.prompt_logprobs_worker.add_request(
|
||||
req_id, req_index, new_req_data.sampling_params
|
||||
)
|
||||
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.apply_staged_writes()
|
||||
self.sampler.apply_staged_writes()
|
||||
self.model_state.apply_staged_writes()
|
||||
if self.sampler is not None:
|
||||
self.sampler.apply_staged_writes()
|
||||
|
||||
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||
# Add new blocks for the existing requests.
|
||||
@@ -788,6 +802,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if grammar_output is not None:
|
||||
# Apply grammar bitmask to the logits in-place.
|
||||
assert self.structured_outputs_worker is not None
|
||||
self.structured_outputs_worker.apply_grammar_bitmask(
|
||||
logits,
|
||||
input_batch,
|
||||
@@ -797,12 +812,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if input_batch.num_draft_tokens == 0:
|
||||
# No draft tokens (common case).
|
||||
sampler_output = self.sampler(
|
||||
logits,
|
||||
input_batch,
|
||||
)
|
||||
assert self.sampler is not None
|
||||
sampler_output = self.sampler(logits, input_batch)
|
||||
else:
|
||||
# Rejection sampling for spec decoding.
|
||||
assert self.rejection_sampler is not None
|
||||
sampler_output = self.rejection_sampler(
|
||||
logits,
|
||||
input_batch,
|
||||
@@ -831,11 +845,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_rejected: torch.Tensor,
|
||||
) -> None:
|
||||
# Update the number of computed tokens.
|
||||
if self.is_last_pp_rank:
|
||||
assert self.sampler is not None
|
||||
output_bin_counts = self.sampler.penalties_state.output_bin_counts
|
||||
else:
|
||||
output_bin_counts = None
|
||||
post_update(
|
||||
input_batch.idx_mapping,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
self.req_states.last_sampled_tokens,
|
||||
self.sampler.penalties_state.output_bin_counts,
|
||||
output_bin_counts,
|
||||
sampled_tokens,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
@@ -1076,6 +1095,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Broadcast to non-last PP ranks (handles spec decode multi-token).
|
||||
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
|
||||
|
||||
assert self.prompt_logprobs_worker is not None
|
||||
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
|
||||
self.model.compute_logits,
|
||||
hidden_states,
|
||||
@@ -1115,6 +1135,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
|
||||
)
|
||||
if self.speculator is not None:
|
||||
assert self.sampler is not None
|
||||
draft_tokens = self.speculator.propose(
|
||||
input_batch,
|
||||
attn_metadata,
|
||||
|
||||
@@ -19,10 +19,11 @@ class PoolingRunner:
|
||||
def __init__(self, model: nn.Module):
|
||||
self.model = cast(VllmModelForPooling, model)
|
||||
|
||||
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||
if not is_pooling_model(self.model):
|
||||
@staticmethod
|
||||
def get_supported_tasks(model: nn.Module) -> list[PoolingTask]:
|
||||
if not is_pooling_model(model):
|
||||
return []
|
||||
assert "embed" in self.model.pooler.get_supported_tasks()
|
||||
assert "embed" in model.pooler.get_supported_tasks()
|
||||
return ["embed"]
|
||||
|
||||
def pool(
|
||||
|
||||
Reference in New Issue
Block a user