diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index c101fad24..fcc4631f2 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -12,6 +12,7 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode +from vllm.distributed.parallel_state import prepare_communication_buffer_for_model from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader @@ -206,6 +207,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): time_after_load - time_before_load, ) + prepare_communication_buffer_for_model(self.model) + if self.do_spec_decode: + speculator_model = getattr(self.speculator, "model", None) + if speculator_model is not None: + prepare_communication_buffer_for_model(speculator_model) + def get_model(self) -> nn.Module: return self.model