diff --git a/vllm/sequence.py b/vllm/sequence.py index 3e12f148b..176306236 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -62,3 +62,12 @@ class IntermediateTensors: def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" + + @staticmethod + def empty_like( + intermediate_tensors: "IntermediateTensors", + ) -> "IntermediateTensors": + tensors = { + k: torch.empty_like(v) for k, v in intermediate_tensors.tensors.items() + } + return IntermediateTensors(tensors) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index b4c2b9579..d918131c6 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -94,13 +94,8 @@ class CudaGraphManager: self.decode_query_len = decode_query_len self.dp_size = vllm_config.parallel_config.data_parallel_size - self.pp_size = vllm_config.parallel_config.pipeline_parallel_size - if self.pp_size > 1: - self.is_first_pp_rank = get_pp_group().is_first_rank - self.is_last_pp_rank = get_pp_group().is_last_rank - else: - self.is_first_pp_rank = True - self.is_last_pp_rank = True + self.is_first_pp_rank = get_pp_group().is_first_rank + self.is_last_pp_rank = get_pp_group().is_last_rank self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {} self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None @@ -371,14 +366,11 @@ class ModelCudaGraphManager(CudaGraphManager): self.aux_hidden_states[i][:num_tokens] = aux else: # Non-last PP rank. + assert isinstance(model_output, IntermediateTensors) intermediate_tensors = model_output - assert isinstance(intermediate_tensors, IntermediateTensors) if self.intermediate_tensors is None: - self.intermediate_tensors = IntermediateTensors( - { - k: torch.empty_like(v) - for k, v in intermediate_tensors.tensors.items() - } + self.intermediate_tensors = IntermediateTensors.empty_like( + intermediate_tensors ) for k, v in intermediate_tensors.tensors.items(): self.intermediate_tensors[k][:num_tokens] = v diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index f58d3c7dc..acded972a 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -132,14 +132,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.output_copy_event = torch.cuda.Event() # Pipeline parallelism. - self.pp_size = self.parallel_config.pipeline_parallel_size - self.use_pp = self.pp_size > 1 - if self.use_pp: - self.is_first_pp_rank = get_pp_group().is_first_rank - self.is_last_pp_rank = get_pp_group().is_last_rank - else: - self.is_first_pp_rank = True - self.is_last_pp_rank = True + self.use_pp = self.parallel_config.pipeline_parallel_size > 1 + self.is_first_pp_rank = get_pp_group().is_first_rank + self.is_last_pp_rank = get_pp_group().is_last_rank + # Persistent buffer for intermediate tensors (non-first PP ranks). self.intermediate_tensors: IntermediateTensors | None = None @@ -179,7 +175,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.speculative_config.method == "eagle3": # EAGLE3 may require auxiliary hidden states from target model outputs. self.use_aux_hidden_state_outputs = True - if self.pp_size > 1: + if self.use_pp: raise ValueError("EAGLE3 with pipeline parallel is not supported.") # Draft tokens propagation - for spec-dec + struct outputs. @@ -270,8 +266,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): logger.info("Loading model from scratch...") self.model = model_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.vllm_config.model_config, + vllm_config=self.vllm_config, model_config=self.vllm_config.model_config ) if self.lora_config: self.model = self.load_lora_model( @@ -1026,14 +1021,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert intermediate_tensors is not None assert self.intermediate_tensors is not None n = input_batch.num_tokens_after_padding - intermediate_tensors = IntermediateTensors( + model_inputs["intermediate_tensors"] = IntermediateTensors( { k: v[:n].copy_(intermediate_tensors.tensors[k][:n]) for k, v in self.intermediate_tensors.tensors.items() - }, - intermediate_tensors.kv_connector_output, + } ) - model_inputs["intermediate_tensors"] = intermediate_tensors + del intermediate_tensors # Run model. if batch_desc.cg_mode == CUDAGraphMode.FULL: