[Model Runner V2] Use unpadded num_tokens for PW CUDA graph attn metadata (#36626)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -384,6 +384,7 @@ def prepare_inputs_to_capture(
|
|||||||
|
|
||||||
attn_metadata = model_state.prepare_attn(
|
attn_metadata = model_state.prepare_attn(
|
||||||
input_batch,
|
input_batch,
|
||||||
|
CUDAGraphMode.NONE,
|
||||||
input_block_tables,
|
input_block_tables,
|
||||||
slot_mappings,
|
slot_mappings,
|
||||||
attn_groups,
|
attn_groups,
|
||||||
|
|||||||
@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
assert block_tables is not None
|
assert block_tables is not None
|
||||||
attn_metadata = self.model_state.prepare_attn(
|
attn_metadata = self.model_state.prepare_attn(
|
||||||
input_batch,
|
input_batch,
|
||||||
|
batch_desc.cg_mode,
|
||||||
block_tables,
|
block_tables,
|
||||||
slot_mappings,
|
slot_mappings,
|
||||||
self.attn_groups,
|
self.attn_groups,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.compilation import CUDAGraphMode
|
||||||
from vllm.v1.core.sched.output import NewRequestData
|
from vllm.v1.core.sched.output import NewRequestData
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||||
@@ -140,14 +141,20 @@ class DefaultModelState(ModelState):
|
|||||||
def prepare_attn(
|
def prepare_attn(
|
||||||
self,
|
self,
|
||||||
input_batch: InputBatch,
|
input_batch: InputBatch,
|
||||||
|
cudagraph_mode: CUDAGraphMode,
|
||||||
block_tables: tuple[torch.Tensor, ...],
|
block_tables: tuple[torch.Tensor, ...],
|
||||||
slot_mappings: torch.Tensor,
|
slot_mappings: torch.Tensor,
|
||||||
attn_groups: list[list[AttentionGroup]],
|
attn_groups: list[list[AttentionGroup]],
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
# Use padded sizes - padding is handled by model_runner.prepare_attn.
|
if cudagraph_mode == CUDAGraphMode.FULL:
|
||||||
num_reqs = input_batch.num_reqs_after_padding
|
# Use padded sizes - padding is handled by model_runner.prepare_attn.
|
||||||
num_tokens = input_batch.num_tokens_after_padding
|
num_reqs = input_batch.num_reqs_after_padding
|
||||||
|
num_tokens = input_batch.num_tokens_after_padding
|
||||||
|
else:
|
||||||
|
# For piecewise cudagraphs and eager, use unpadded sizes.
|
||||||
|
num_reqs = input_batch.num_reqs
|
||||||
|
num_tokens = input_batch.num_tokens
|
||||||
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
|
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
|
||||||
max_query_len = input_batch.num_scheduled_tokens.max().item()
|
max_query_len = input_batch.num_scheduled_tokens.max().item()
|
||||||
attn_metadata = build_attn_metadata(
|
attn_metadata = build_attn_metadata(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.config.compilation import CUDAGraphMode
|
||||||
from vllm.v1.core.sched.output import NewRequestData
|
from vllm.v1.core.sched.output import NewRequestData
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||||
@@ -59,6 +60,7 @@ class ModelState(ABC):
|
|||||||
def prepare_attn(
|
def prepare_attn(
|
||||||
self,
|
self,
|
||||||
input_batch: InputBatch,
|
input_batch: InputBatch,
|
||||||
|
cudagraph_mode: CUDAGraphMode,
|
||||||
block_tables: tuple[torch.Tensor, ...],
|
block_tables: tuple[torch.Tensor, ...],
|
||||||
slot_mappings: torch.Tensor,
|
slot_mappings: torch.Tensor,
|
||||||
attn_groups: list[list[AttentionGroup]],
|
attn_groups: list[list[AttentionGroup]],
|
||||||
|
|||||||
Reference in New Issue
Block a user