[Model Runner V2][Minor] Simplify PP logic (#38031)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
@@ -62,3 +62,12 @@ class IntermediateTensors:
 
     def __repr__(self) -> str:
         return f"IntermediateTensors(tensors={self.tensors})"
+
+    @staticmethod
+    def empty_like(
+        intermediate_tensors: "IntermediateTensors",
+    ) -> "IntermediateTensors":
+        tensors = {
+            k: torch.empty_like(v) for k, v in intermediate_tensors.tensors.items()
+        }
+        return IntermediateTensors(tensors)
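
Note: below is a minimal, self-contained sketch of what the new empty_like
helper does, using a stripped-down stand-in class (vLLM's real
IntermediateTensors carries more fields and methods than shown here):

import torch

class IntermediateTensors:
    """Stand-in with only the `tensors` dict the helper touches."""

    def __init__(self, tensors: dict[str, torch.Tensor]):
        self.tensors = tensors

    @staticmethod
    def empty_like(
        intermediate_tensors: "IntermediateTensors",
    ) -> "IntermediateTensors":
        # One fresh, uninitialized tensor per entry, matching the
        # shape/dtype/device of the corresponding source tensor.
        tensors = {
            k: torch.empty_like(v) for k, v in intermediate_tensors.tensors.items()
        }
        return IntermediateTensors(tensors)

src = IntermediateTensors({"hidden_states": torch.randn(8, 16)})
buf = IntermediateTensors.empty_like(src)
assert buf.tensors["hidden_states"].shape == (8, 16)  # same shape, new storage
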
@@ -94,13 +94,8 @@ class CudaGraphManager:
         self.decode_query_len = decode_query_len
 
         self.dp_size = vllm_config.parallel_config.data_parallel_size
-        self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
-        if self.pp_size > 1:
-            self.is_first_pp_rank = get_pp_group().is_first_rank
-            self.is_last_pp_rank = get_pp_group().is_last_rank
-        else:
-            self.is_first_pp_rank = True
-            self.is_last_pp_rank = True
+        self.is_first_pp_rank = get_pp_group().is_first_rank
+        self.is_last_pp_rank = get_pp_group().is_last_rank
 
         self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
         self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
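
Note: the if/else can be dropped because a pipeline-parallel group of size 1
presumably reports its only rank as both first and last, so the unconditional
get_pp_group() calls yield the same True/True the old else-branch hard-coded.
A toy illustration of that property (not vLLM's actual group object):

class ToyPPGroup:
    """Toy stand-in exposing just the two properties the diff relies on."""

    def __init__(self, rank_in_group: int, world_size: int):
        self.rank_in_group = rank_in_group
        self.world_size = world_size

    @property
    def is_first_rank(self) -> bool:
        return self.rank_in_group == 0

    @property
    def is_last_rank(self) -> bool:
        return self.rank_in_group == self.world_size - 1

# With pp_size == 1 there is a single rank, and it is both first and last,
# matching the branch the commit deletes.
g = ToyPPGroup(rank_in_group=0, world_size=1)
assert g.is_first_rank and g.is_last_rank
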
@@ -371,14 +366,11 @@ class ModelCudaGraphManager(CudaGraphManager):
                 self.aux_hidden_states[i][:num_tokens] = aux
         else:
             # Non-last PP rank.
+            assert isinstance(model_output, IntermediateTensors)
             intermediate_tensors = model_output
-            assert isinstance(intermediate_tensors, IntermediateTensors)
             if self.intermediate_tensors is None:
-                self.intermediate_tensors = IntermediateTensors(
-                    {
-                        k: torch.empty_like(v)
-                        for k, v in intermediate_tensors.tensors.items()
-                    }
+                self.intermediate_tensors = IntermediateTensors.empty_like(
+                    intermediate_tensors
                 )
             for k, v in intermediate_tensors.tensors.items():
                 self.intermediate_tensors[k][:num_tokens] = v
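
Note: the hunk above is an instance of a lazily allocated persistent buffer:
allocate once, shaped like a full-size output, then overwrite only the valid
prefix on each step. A hypothetical standalone version of the pattern (names
and shapes here are illustrative, not vLLM's):

import torch

persistent: dict[str, torch.Tensor] | None = None

def stash(output: dict[str, torch.Tensor], num_tokens: int) -> None:
    """Copy `output` into a persistent buffer, allocating it on first use."""
    global persistent
    if persistent is None:
        # First call: mirror shapes/dtypes/devices, uninitialized storage.
        persistent = {k: torch.empty_like(v) for k, v in output.items()}
    for k, v in output.items():
        # Only the first num_tokens rows carry real data; the rest is padding.
        persistent[k][:num_tokens] = v[:num_tokens]

stash({"hidden_states": torch.randn(8, 16)}, num_tokens=5)
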
@@ -132,14 +132,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.output_copy_event = torch.cuda.Event()
 
         # Pipeline parallelism.
-        self.pp_size = self.parallel_config.pipeline_parallel_size
-        self.use_pp = self.pp_size > 1
-        if self.use_pp:
-            self.is_first_pp_rank = get_pp_group().is_first_rank
-            self.is_last_pp_rank = get_pp_group().is_last_rank
-        else:
-            self.is_first_pp_rank = True
-            self.is_last_pp_rank = True
+        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
+        self.is_first_pp_rank = get_pp_group().is_first_rank
+        self.is_last_pp_rank = get_pp_group().is_last_rank
+
         # Persistent buffer for intermediate tensors (non-first PP ranks).
         self.intermediate_tensors: IntermediateTensors | None = None
 
@@ -179,7 +175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             if self.speculative_config.method == "eagle3":
                 # EAGLE3 may require auxiliary hidden states from target model outputs.
                 self.use_aux_hidden_state_outputs = True
-                if self.pp_size > 1:
+                if self.use_pp:
                     raise ValueError("EAGLE3 with pipeline parallel is not supported.")
 
         # Draft tokens propagation - for spec-dec + struct outputs.
@@ -270,8 +266,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logger.info("Loading model from scratch...")
 
             self.model = model_loader.load_model(
-                vllm_config=self.vllm_config,
-                model_config=self.vllm_config.model_config,
+                vllm_config=self.vllm_config, model_config=self.vllm_config.model_config
             )
             if self.lora_config:
                 self.model = self.load_lora_model(
@@ -1026,14 +1021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert intermediate_tensors is not None
             assert self.intermediate_tensors is not None
             n = input_batch.num_tokens_after_padding
-            intermediate_tensors = IntermediateTensors(
+            model_inputs["intermediate_tensors"] = IntermediateTensors(
                 {
                     k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
                     for k, v in self.intermediate_tensors.tensors.items()
-                },
-                intermediate_tensors.kv_connector_output,
+                }
             )
-            model_inputs["intermediate_tensors"] = intermediate_tensors
+            del intermediate_tensors
 
         # Run model.
         if batch_desc.cg_mode == CUDAGraphMode.FULL:
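
Note: the dict comprehension above leans on Tensor.copy_ being an in-place op
that returns its destination, so each value in the new IntermediateTensors is
a view into the persistent buffer that has just received the incoming data.
A small self-contained check of that torch behavior:

import torch

dst = torch.zeros(8, 4)   # stand-in for one persistent buffer entry
src = torch.randn(8, 4)   # stand-in for a received intermediate tensor
n = 5                     # stand-in for num_tokens_after_padding

view = dst[:n].copy_(src[:n])  # in-place copy; returns the destination slice
assert view.data_ptr() == dst.data_ptr()  # a view of dst, not a clone of src
assert torch.equal(dst[:n], src[:n])      # prefix now holds the copied data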