[Model Runner V2] Use unpadded num_tokens for PW CUDA graph attn metadata (#36626)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -384,6 +384,7 @@ def prepare_inputs_to_capture(
|
||||
|
||||
attn_metadata = model_state.prepare_attn(
|
||||
input_batch,
|
||||
CUDAGraphMode.NONE,
|
||||
input_block_tables,
|
||||
slot_mappings,
|
||||
attn_groups,
|
||||
|
||||
@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert block_tables is not None
|
||||
attn_metadata = self.model_state.prepare_attn(
|
||||
input_batch,
|
||||
batch_desc.cg_mode,
|
||||
block_tables,
|
||||
slot_mappings,
|
||||
self.attn_groups,
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
@@ -140,14 +141,20 @@ class DefaultModelState(ModelState):
|
||||
def prepare_attn(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
block_tables: tuple[torch.Tensor, ...],
|
||||
slot_mappings: torch.Tensor,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
# Use padded sizes - padding is handled by model_runner.prepare_attn.
|
||||
num_reqs = input_batch.num_reqs_after_padding
|
||||
num_tokens = input_batch.num_tokens_after_padding
|
||||
if cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# Use padded sizes - padding is handled by model_runner.prepare_attn.
|
||||
num_reqs = input_batch.num_reqs_after_padding
|
||||
num_tokens = input_batch.num_tokens_after_padding
|
||||
else:
|
||||
# For piecewise cudagraphs and eager, use unpadded sizes.
|
||||
num_reqs = input_batch.num_reqs
|
||||
num_tokens = input_batch.num_tokens
|
||||
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
|
||||
max_query_len = input_batch.num_scheduled_tokens.max().item()
|
||||
attn_metadata = build_attn_metadata(
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
@@ -59,6 +60,7 @@ class ModelState(ABC):
|
||||
def prepare_attn(
|
||||
self,
|
||||
input_batch: InputBatch,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
block_tables: tuple[torch.Tensor, ...],
|
||||
slot_mappings: torch.Tensor,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
|
||||
Reference in New Issue
Block a user