[Model Runner V2] Use unpadded num_tokens for piecewise (PW) CUDA graph attention metadata (#36626)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-10 09:30:56 -07:00
committed by GitHub
parent f83b933b84
commit f088a831dd
4 changed files with 14 additions and 3 deletions

View File

@@ -384,6 +384,7 @@ def prepare_inputs_to_capture(
 attn_metadata = model_state.prepare_attn(
     input_batch,
+    CUDAGraphMode.NONE,
     input_block_tables,
     slot_mappings,
     attn_groups,

View File

@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 assert block_tables is not None
 attn_metadata = self.model_state.prepare_attn(
     input_batch,
+    batch_desc.cg_mode,
     block_tables,
     slot_mappings,
     self.attn_groups,

View File

@@ -6,6 +6,7 @@ import torch
 import torch.nn as nn
 from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
 from vllm.v1.core.sched.output import NewRequestData
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
@@ -140,14 +141,20 @@ class DefaultModelState(ModelState):
 def prepare_attn(
     self,
     input_batch: InputBatch,
+    cudagraph_mode: CUDAGraphMode,
     block_tables: tuple[torch.Tensor, ...],
     slot_mappings: torch.Tensor,
     attn_groups: list[list[AttentionGroup]],
     kv_cache_config: KVCacheConfig,
 ) -> dict[str, Any]:
-    # Use padded sizes - padding is handled by model_runner.prepare_attn.
-    num_reqs = input_batch.num_reqs_after_padding
-    num_tokens = input_batch.num_tokens_after_padding
+    if cudagraph_mode == CUDAGraphMode.FULL:
+        # Use padded sizes - padding is handled by model_runner.prepare_attn.
+        num_reqs = input_batch.num_reqs_after_padding
+        num_tokens = input_batch.num_tokens_after_padding
+    else:
+        # For piecewise cudagraphs and eager, use unpadded sizes.
+        num_reqs = input_batch.num_reqs
+        num_tokens = input_batch.num_tokens
     query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
     max_query_len = input_batch.num_scheduled_tokens.max().item()
     attn_metadata = build_attn_metadata(

View File

@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 from vllm.config import VllmConfig
+from vllm.config.compilation import CUDAGraphMode
 from vllm.v1.core.sched.output import NewRequestData
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.input_batch import InputBatch
@@ -59,6 +60,7 @@ class ModelState(ABC):
 def prepare_attn(
     self,
     input_batch: InputBatch,
+    cudagraph_mode: CUDAGraphMode,
     block_tables: tuple[torch.Tensor, ...],
     slot_mappings: torch.Tensor,
     attn_groups: list[list[AttentionGroup]],