[Model Runner V2] Use unpadded num_tokens for PW CUDA graph attn metadata (#36626)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-10 09:30:56 -07:00
committed by GitHub
parent f83b933b84
commit f088a831dd
4 changed files with 14 additions and 3 deletions

View File

@@ -384,6 +384,7 @@ def prepare_inputs_to_capture(
attn_metadata = model_state.prepare_attn(
input_batch,
CUDAGraphMode.NONE,
input_block_tables,
slot_mappings,
attn_groups,

View File

@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert block_tables is not None
attn_metadata = self.model_state.prepare_attn(
input_batch,
batch_desc.cg_mode,
block_tables,
slot_mappings,
self.attn_groups,

View File

@@ -6,6 +6,7 @@ import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
@@ -140,14 +141,20 @@ class DefaultModelState(ModelState):
def prepare_attn(
self,
input_batch: InputBatch,
cudagraph_mode: CUDAGraphMode,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
if cudagraph_mode == CUDAGraphMode.FULL:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
else:
# For piecewise cudagraphs and eager, use unpadded sizes.
num_reqs = input_batch.num_reqs
num_tokens = input_batch.num_tokens
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata(

View File

@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
@@ -59,6 +60,7 @@ class ModelState(ABC):
def prepare_attn(
self,
input_batch: InputBatch,
cudagraph_mode: CUDAGraphMode,
block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],