diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 2ec3cb2a2..202470c7b 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -384,6 +384,7 @@ def prepare_inputs_to_capture( attn_metadata = model_state.prepare_attn( input_batch, + CUDAGraphMode.NONE, input_block_tables, slot_mappings, attn_groups, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 41c2f3704..58ff78b12 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): assert block_tables is not None attn_metadata = self.model_state.prepare_attn( input_batch, + batch_desc.cg_mode, block_tables, slot_mappings, self.attn_groups, diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 770c65049..6d24c3663 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from vllm.config import VllmConfig +from vllm.config.compilation import CUDAGraphMode from vllm.v1.core.sched.output import NewRequestData from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_attn_metadata @@ -140,14 +141,20 @@ class DefaultModelState(ModelState): def prepare_attn( self, input_batch: InputBatch, + cudagraph_mode: CUDAGraphMode, block_tables: tuple[torch.Tensor, ...], slot_mappings: torch.Tensor, attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, ) -> dict[str, Any]: - # Use padded sizes - padding is handled by model_runner.prepare_attn. - num_reqs = input_batch.num_reqs_after_padding - num_tokens = input_batch.num_tokens_after_padding + if cudagraph_mode == CUDAGraphMode.FULL: + # Use padded sizes - padding is handled by model_runner.prepare_attn. + num_reqs = input_batch.num_reqs_after_padding + num_tokens = input_batch.num_tokens_after_padding + else: + # For piecewise cudagraphs and eager, use unpadded sizes. + num_reqs = input_batch.num_reqs + num_tokens = input_batch.num_tokens query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np) max_query_len = input_batch.num_scheduled_tokens.max().item() attn_metadata = build_attn_metadata( diff --git a/vllm/v1/worker/gpu/model_states/interface.py b/vllm/v1/worker/gpu/model_states/interface.py index d5a25710c..064cfa195 100644 --- a/vllm/v1/worker/gpu/model_states/interface.py +++ b/vllm/v1/worker/gpu/model_states/interface.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn from vllm.config import VllmConfig +from vllm.config.compilation import CUDAGraphMode from vllm.v1.core.sched.output import NewRequestData from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.input_batch import InputBatch @@ -59,6 +60,7 @@ class ModelState(ABC): def prepare_attn( self, input_batch: InputBatch, + cudagraph_mode: CUDAGraphMode, block_tables: tuple[torch.Tensor, ...], slot_mappings: torch.Tensor, attn_groups: list[list[AttentionGroup]],