[Model Runner V2][Minor] Simplify PP logic (#38031)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -62,3 +62,12 @@ class IntermediateTensors:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"IntermediateTensors(tensors={self.tensors})"
|
||||
|
||||
@staticmethod
|
||||
def empty_like(
|
||||
intermediate_tensors: "IntermediateTensors",
|
||||
) -> "IntermediateTensors":
|
||||
tensors = {
|
||||
k: torch.empty_like(v) for k, v in intermediate_tensors.tensors.items()
|
||||
}
|
||||
return IntermediateTensors(tensors)
|
||||
|
||||
@@ -94,13 +94,8 @@ class CudaGraphManager:
|
||||
self.decode_query_len = decode_query_len
|
||||
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||
if self.pp_size > 1:
|
||||
self.is_first_pp_rank = get_pp_group().is_first_rank
|
||||
self.is_last_pp_rank = get_pp_group().is_last_rank
|
||||
else:
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
|
||||
self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
|
||||
@@ -371,14 +366,11 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
self.aux_hidden_states[i][:num_tokens] = aux
|
||||
else:
|
||||
# Non-last PP rank.
|
||||
assert isinstance(model_output, IntermediateTensors)
|
||||
intermediate_tensors = model_output
|
||||
assert isinstance(intermediate_tensors, IntermediateTensors)
|
||||
if self.intermediate_tensors is None:
|
||||
self.intermediate_tensors = IntermediateTensors(
|
||||
{
|
||||
k: torch.empty_like(v)
|
||||
for k, v in intermediate_tensors.tensors.items()
|
||||
}
|
||||
self.intermediate_tensors = IntermediateTensors.empty_like(
|
||||
intermediate_tensors
|
||||
)
|
||||
for k, v in intermediate_tensors.tensors.items():
|
||||
self.intermediate_tensors[k][:num_tokens] = v
|
||||
|
||||
@@ -132,14 +132,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.output_copy_event = torch.cuda.Event()
|
||||
|
||||
# Pipeline parallelism.
|
||||
self.pp_size = self.parallel_config.pipeline_parallel_size
|
||||
self.use_pp = self.pp_size > 1
|
||||
if self.use_pp:
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
self.is_first_pp_rank = get_pp_group().is_first_rank
|
||||
self.is_last_pp_rank = get_pp_group().is_last_rank
|
||||
else:
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
|
||||
# Persistent buffer for intermediate tensors (non-first PP ranks).
|
||||
self.intermediate_tensors: IntermediateTensors | None = None
|
||||
|
||||
@@ -179,7 +175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.speculative_config.method == "eagle3":
|
||||
# EAGLE3 may require auxiliary hidden states from target model outputs.
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
if self.pp_size > 1:
|
||||
if self.use_pp:
|
||||
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
|
||||
|
||||
# Draft tokens propagation - for spec-dec + struct outputs.
|
||||
@@ -270,8 +266,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
logger.info("Loading model from scratch...")
|
||||
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.vllm_config.model_config,
|
||||
vllm_config=self.vllm_config, model_config=self.vllm_config.model_config
|
||||
)
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(
|
||||
@@ -1026,14 +1021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert intermediate_tensors is not None
|
||||
assert self.intermediate_tensors is not None
|
||||
n = input_batch.num_tokens_after_padding
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
model_inputs["intermediate_tensors"] = IntermediateTensors(
|
||||
{
|
||||
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
|
||||
for k, v in self.intermediate_tensors.tensors.items()
|
||||
},
|
||||
intermediate_tensors.kv_connector_output,
|
||||
}
|
||||
)
|
||||
model_inputs["intermediate_tensors"] = intermediate_tensors
|
||||
del intermediate_tensors
|
||||
|
||||
# Run model.
|
||||
if batch_desc.cg_mode == CUDAGraphMode.FULL:
|
||||
|
||||
Reference in New Issue
Block a user