Add logging for cudagraph related info (#29825)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
from collections import Counter
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
    """Immutable record of the padding and dispatch mode chosen for one batch.

    The dataclass is frozen, so instances are hashable and records with
    identical field values compare equal; this allows them to be tallied
    by value (e.g. with a ``collections.Counter``).
    """

    # Number of tokens in the batch before any padding was applied.
    num_unpadded_tokens: int
    # Number of tokens in the batch after padding.
    num_padded_tokens: int
    # Number of padding tokens added (padded minus unpadded).
    num_paddings: int
    # String form of the CUDA graph runtime mode the batch was dispatched with.
    runtime_mode: str
|
||||||
|
|
||||||
|
|
||||||
|
class CUDAGraphLogging:
    """Aggregate CUDA graph stats and log them as a Markdown table.

    Stats are buffered by `observe()`, rendered by `generate_metric_table()`
    and emitted by `log()`, which also clears the buffer so that every log
    call reports only what was observed since the previous one.
    """

    COLUMN_HEADERS = [
        "Unpadded Tokens",
        "Padded Tokens",
        "Num Paddings",
        "Runtime Mode",
        "Count",
    ]

    def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
        """Store the configured mode and capture sizes for the report header.

        Args:
            cg_mode: Configured CUDA graph mode; only its `str()` form is kept.
            cg_capture_sizes: Configured capture sizes; `None` (or an empty
                list) is rendered as `[]`.
        """
        self.reset()
        self.cg_mode = str(cg_mode)
        self.cg_capture_sizes = str(cg_capture_sizes or [])

        # The header only depends on the configuration, so build it once.
        self.settings_header = (
            "**CUDAGraph Config Settings:**\n\n"
            f"- Mode: {self.cg_mode}\n"
            f"- Capture sizes: {self.cg_capture_sizes}\n\n"
            "**CUDAGraph Stats:**\n\n"
        )

    def reset(self):
        """Discard all buffered stats."""
        self.stats: list[CUDAGraphStat] = []

    def observe(self, cudagraph_stat: CUDAGraphStat):
        """Buffer one stat record for the next report."""
        self.stats.append(cudagraph_stat)

    def generate_metric_table(self) -> str:
        """Render the buffered stats as a Markdown table.

        Identical records are collapsed into a single row with a count.

        Returns:
            The settings header, followed by the table header, the separator
            line and one line per distinct record, ending with a newline.
        """
        # most_common() yields (stat, count) pairs in descending order of
        # count; records with equal counts keep first-observed order.
        rows = [
            [
                str(stat.num_unpadded_tokens),
                str(stat.num_padded_tokens),
                str(stat.num_paddings),
                stat.runtime_mode,
                str(count),
            ]
            for stat, count in Counter(self.stats).most_common()
        ]

        # Each column is as wide as its header or its widest cell. The list
        # form of max() keeps this valid when there are no rows at all.
        col_widths = [
            max([len(header), *(len(row[i]) for row in rows)])
            for i, header in enumerate(self.COLUMN_HEADERS)
        ]

        padded_headers = (
            header.ljust(width)
            for header, width in zip(self.COLUMN_HEADERS, col_widths)
        )
        table_header = "| " + " | ".join(padded_headers) + " |\n"

        # "+ 2" accounts for the single space on each side of every cell.
        table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"

        # Cells are already strings, so they only need left-justifying.
        data_rows = [
            "| "
            + " | ".join(cell.ljust(width) for cell, width in zip(row, col_widths))
            + " |"
            for row in rows
        ]

        return (
            self.settings_header
            + table_header
            + table_separator
            + "\n".join(data_rows)
            + "\n"
        )

    def log(self, log_fn=logger.info):
        """Emit the metric table through `log_fn`, then clear the buffer.

        Nothing is logged when no stats have been observed.
        """
        if not self.stats:
            return
        log_fn(self.generate_metric_table())
        self.reset()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class CUDAGraphEntry:
|
class CUDAGraphEntry:
|
||||||
batch_descriptor: BatchDescriptor
|
batch_descriptor: BatchDescriptor
|
||||||
|
|||||||
@@ -55,6 +55,10 @@ class ObservabilityConfig:
|
|||||||
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
|
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
|
||||||
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
|
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
|
||||||
|
|
||||||
|
cudagraph_metrics: bool = False
|
||||||
|
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
|
||||||
|
dispatch modes, and their observed frequencies at every logging interval)."""
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def collect_model_forward_time(self) -> bool:
|
def collect_model_forward_time(self) -> bool:
|
||||||
"""Whether to collect model forward time for the request."""
|
"""Whether to collect model forward time for the request."""
|
||||||
|
|||||||
@@ -518,6 +518,7 @@ class EngineArgs:
|
|||||||
kv_cache_metrics_sample: float = get_field(
|
kv_cache_metrics_sample: float = get_field(
|
||||||
ObservabilityConfig, "kv_cache_metrics_sample"
|
ObservabilityConfig, "kv_cache_metrics_sample"
|
||||||
)
|
)
|
||||||
|
cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
|
||||||
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
|
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
|
||||||
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
|
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
|
||||||
|
|
||||||
@@ -1021,6 +1022,10 @@ class EngineArgs:
|
|||||||
"--kv-cache-metrics-sample",
|
"--kv-cache-metrics-sample",
|
||||||
**observability_kwargs["kv_cache_metrics_sample"],
|
**observability_kwargs["kv_cache_metrics_sample"],
|
||||||
)
|
)
|
||||||
|
observability_group.add_argument(
|
||||||
|
"--cudagraph-metrics",
|
||||||
|
**observability_kwargs["cudagraph_metrics"],
|
||||||
|
)
|
||||||
|
|
||||||
# Scheduler arguments
|
# Scheduler arguments
|
||||||
scheduler_kwargs = get_kwargs(SchedulerConfig)
|
scheduler_kwargs = get_kwargs(SchedulerConfig)
|
||||||
@@ -1698,6 +1703,7 @@ class EngineArgs:
|
|||||||
collect_detailed_traces=self.collect_detailed_traces,
|
collect_detailed_traces=self.collect_detailed_traces,
|
||||||
kv_cache_metrics=self.kv_cache_metrics,
|
kv_cache_metrics=self.kv_cache_metrics,
|
||||||
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
|
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
|
||||||
|
cudagraph_metrics=self.cudagraph_metrics,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compilation config overrides
|
# Compilation config overrides
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from collections.abc import Iterable
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm.compilation.cuda_graph import CUDAGraphStat
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.ec_transfer.ec_connector.base import (
|
from vllm.distributed.ec_transfer.ec_connector.base import (
|
||||||
ECConnectorMetadata,
|
ECConnectorMetadata,
|
||||||
@@ -1037,6 +1038,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
pooler_outputs = model_runner_output.pooler_output
|
pooler_outputs = model_runner_output.pooler_output
|
||||||
num_nans_in_logits = model_runner_output.num_nans_in_logits
|
num_nans_in_logits = model_runner_output.num_nans_in_logits
|
||||||
kv_connector_output = model_runner_output.kv_connector_output
|
kv_connector_output = model_runner_output.kv_connector_output
|
||||||
|
cudagraph_stats = model_runner_output.cudagraph_stats
|
||||||
|
|
||||||
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
|
||||||
spec_decoding_stats: SpecDecodingStats | None = None
|
spec_decoding_stats: SpecDecodingStats | None = None
|
||||||
@@ -1219,7 +1221,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
finished_req_ids.clear()
|
finished_req_ids.clear()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
stats := self.make_stats(spec_decoding_stats, kv_connector_stats)
|
stats := self.make_stats(
|
||||||
|
spec_decoding_stats, kv_connector_stats, cudagraph_stats
|
||||||
|
)
|
||||||
) is not None:
|
) is not None:
|
||||||
# Return stats to only one of the front-ends.
|
# Return stats to only one of the front-ends.
|
||||||
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
|
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
|
||||||
@@ -1420,6 +1424,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
self,
|
self,
|
||||||
spec_decoding_stats: SpecDecodingStats | None = None,
|
spec_decoding_stats: SpecDecodingStats | None = None,
|
||||||
kv_connector_stats: KVConnectorStats | None = None,
|
kv_connector_stats: KVConnectorStats | None = None,
|
||||||
|
cudagraph_stats: CUDAGraphStat | None = None,
|
||||||
) -> SchedulerStats | None:
|
) -> SchedulerStats | None:
|
||||||
if not self.log_stats:
|
if not self.log_stats:
|
||||||
return None
|
return None
|
||||||
@@ -1444,6 +1449,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
kv_cache_eviction_events=eviction_events,
|
kv_cache_eviction_events=eviction_events,
|
||||||
spec_decoding_stats=spec_stats,
|
spec_decoding_stats=spec_stats,
|
||||||
kv_connector_stats=connector_stats_payload,
|
kv_connector_stats=connector_stats_payload,
|
||||||
|
cudagraph_stats=cudagraph_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_spec_decoding_stats(
|
def make_spec_decoding_stats(
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import TypeAlias
|
|||||||
from prometheus_client import Counter, Gauge, Histogram
|
from prometheus_client import Counter, Gauge, Histogram
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.compilation.cuda_graph import CUDAGraphLogging
|
||||||
from vllm.config import SupportsMetricsInfo, VllmConfig
|
from vllm.config import SupportsMetricsInfo, VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
KVConnectorLogging,
|
KVConnectorLogging,
|
||||||
@@ -106,6 +107,12 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.spec_decoding_logging = SpecDecodingLogging()
|
self.spec_decoding_logging = SpecDecodingLogging()
|
||||||
kv_transfer_config = self.vllm_config.kv_transfer_config
|
kv_transfer_config = self.vllm_config.kv_transfer_config
|
||||||
self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)
|
self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)
|
||||||
|
self.cudagraph_logging = None
|
||||||
|
if self.vllm_config.observability_config.cudagraph_metrics:
|
||||||
|
self.cudagraph_logging = CUDAGraphLogging(
|
||||||
|
self.vllm_config.compilation_config.cudagraph_mode,
|
||||||
|
self.vllm_config.compilation_config.cudagraph_capture_sizes,
|
||||||
|
)
|
||||||
self.last_prompt_throughput: float = 0.0
|
self.last_prompt_throughput: float = 0.0
|
||||||
self.last_generation_throughput: float = 0.0
|
self.last_generation_throughput: float = 0.0
|
||||||
self.engine_is_idle = False
|
self.engine_is_idle = False
|
||||||
@@ -161,6 +168,11 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
|
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
|
||||||
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
if kv_connector_stats := scheduler_stats.kv_connector_stats:
|
||||||
self.kv_connector_logging.observe(kv_connector_stats)
|
self.kv_connector_logging.observe(kv_connector_stats)
|
||||||
|
if (
|
||||||
|
self.cudagraph_logging is not None
|
||||||
|
and scheduler_stats.cudagraph_stats is not None
|
||||||
|
):
|
||||||
|
self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
|
||||||
if not self.aggregated:
|
if not self.aggregated:
|
||||||
self.last_scheduler_stats = scheduler_stats
|
self.last_scheduler_stats = scheduler_stats
|
||||||
if mm_cache_stats:
|
if mm_cache_stats:
|
||||||
@@ -240,6 +252,8 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
|
|
||||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||||
self.kv_connector_logging.log(log_fn=log_fn)
|
self.kv_connector_logging.log(log_fn=log_fn)
|
||||||
|
if self.cudagraph_logging is not None:
|
||||||
|
self.cudagraph_logging.log(log_fn=log_fn)
|
||||||
|
|
||||||
def log_engine_initialized(self):
|
def log_engine_initialized(self):
|
||||||
if self.vllm_config.cache_config.num_gpu_blocks:
|
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.compilation.cuda_graph import CUDAGraphStat
|
||||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -183,6 +184,8 @@ class SchedulerStats:
|
|||||||
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
|
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
|
||||||
running_lora_adapters: dict[str, int] = field(default_factory=dict)
|
running_lora_adapters: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
cudagraph_stats: CUDAGraphStat | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestStateStats:
|
class RequestStateStats:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.compilation.cuda_graph import CUDAGraphStat
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -169,6 +170,9 @@ class ModelRunnerOutput:
|
|||||||
# req_id -> num_nans_in_logits
|
# req_id -> num_nans_in_logits
|
||||||
num_nans_in_logits: dict[str, int] | None = None
|
num_nans_in_logits: dict[str, int] | None = None
|
||||||
|
|
||||||
|
# information related to cudagraph execution
|
||||||
|
cudagraph_stats: CUDAGraphStat | None = None
|
||||||
|
|
||||||
|
|
||||||
# ModelRunnerOutput wrapper for async scheduling.
|
# ModelRunnerOutput wrapper for async scheduling.
|
||||||
class AsyncModelRunnerOutput(ABC):
|
class AsyncModelRunnerOutput(ABC):
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from vllm.attention.backends.abstract import (
|
|||||||
)
|
)
|
||||||
from vllm.attention.layer import Attention, MLAAttention
|
from vllm.attention.layer import Attention, MLAAttention
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
|
||||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CompilationMode,
|
CompilationMode,
|
||||||
@@ -257,6 +257,7 @@ class ExecuteModelState(NamedTuple):
|
|||||||
sample_hidden_states: torch.Tensor
|
sample_hidden_states: torch.Tensor
|
||||||
aux_hidden_states: list[torch.Tensor] | None
|
aux_hidden_states: list[torch.Tensor] | None
|
||||||
ec_connector_output: ECConnectorOutput | None
|
ec_connector_output: ECConnectorOutput | None
|
||||||
|
cudagraph_stats: CUDAGraphStat | None
|
||||||
|
|
||||||
|
|
||||||
class GPUModelRunner(
|
class GPUModelRunner(
|
||||||
@@ -2755,7 +2756,11 @@ class GPUModelRunner(
|
|||||||
force_uniform_decode: bool | None = None,
|
force_uniform_decode: bool | None = None,
|
||||||
force_has_lora: bool | None = None,
|
force_has_lora: bool | None = None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
|
CUDAGraphMode,
|
||||||
|
BatchDescriptor,
|
||||||
|
UBatchSlices | None,
|
||||||
|
torch.Tensor | None,
|
||||||
|
CUDAGraphStat | None,
|
||||||
]:
|
]:
|
||||||
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
||||||
uniform_decode = (
|
uniform_decode = (
|
||||||
@@ -2820,7 +2825,22 @@ class GPUModelRunner(
|
|||||||
# num_tokens_across_dp will no-longer be valid
|
# num_tokens_across_dp will no-longer be valid
|
||||||
assert batch_descriptor.num_tokens == num_tokens_padded
|
assert batch_descriptor.num_tokens == num_tokens_padded
|
||||||
|
|
||||||
return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
|
cudagraph_stats = None
|
||||||
|
if self.vllm_config.observability_config.cudagraph_metrics:
|
||||||
|
cudagraph_stats = CUDAGraphStat(
|
||||||
|
num_unpadded_tokens=num_tokens,
|
||||||
|
num_padded_tokens=batch_descriptor.num_tokens,
|
||||||
|
num_paddings=batch_descriptor.num_tokens - num_tokens,
|
||||||
|
runtime_mode=str(cudagraph_mode),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
cudagraph_mode,
|
||||||
|
batch_descriptor,
|
||||||
|
ubatch_slices,
|
||||||
|
num_tokens_across_dp,
|
||||||
|
cudagraph_stats,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
@@ -2918,6 +2938,7 @@ class GPUModelRunner(
|
|||||||
batch_desc,
|
batch_desc,
|
||||||
ubatch_slices,
|
ubatch_slices,
|
||||||
num_tokens_across_dp,
|
num_tokens_across_dp,
|
||||||
|
cudagraph_stats,
|
||||||
) = self._determine_batch_execution_and_padding(
|
) = self._determine_batch_execution_and_padding(
|
||||||
num_tokens=num_tokens_unpadded,
|
num_tokens=num_tokens_unpadded,
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
@@ -3067,6 +3088,7 @@ class GPUModelRunner(
|
|||||||
sample_hidden_states,
|
sample_hidden_states,
|
||||||
aux_hidden_states,
|
aux_hidden_states,
|
||||||
ec_connector_output,
|
ec_connector_output,
|
||||||
|
cudagraph_stats,
|
||||||
)
|
)
|
||||||
self.kv_connector_output = kv_connector_output
|
self.kv_connector_output = kv_connector_output
|
||||||
return None
|
return None
|
||||||
@@ -3102,6 +3124,7 @@ class GPUModelRunner(
|
|||||||
sample_hidden_states,
|
sample_hidden_states,
|
||||||
aux_hidden_states,
|
aux_hidden_states,
|
||||||
ec_connector_output,
|
ec_connector_output,
|
||||||
|
cudagraph_stats,
|
||||||
) = self.execute_model_state
|
) = self.execute_model_state
|
||||||
# Clear ephemeral state.
|
# Clear ephemeral state.
|
||||||
self.execute_model_state = None
|
self.execute_model_state = None
|
||||||
@@ -3217,6 +3240,7 @@ class GPUModelRunner(
|
|||||||
if self.supports_mm_inputs
|
if self.supports_mm_inputs
|
||||||
else None,
|
else None,
|
||||||
num_nans_in_logits=num_nans_in_logits,
|
num_nans_in_logits=num_nans_in_logits,
|
||||||
|
cudagraph_stats=cudagraph_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.use_async_scheduling:
|
if not self.use_async_scheduling:
|
||||||
@@ -3937,7 +3961,7 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
|
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
|
||||||
|
|
||||||
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
|
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
|
||||||
self._determine_batch_execution_and_padding(
|
self._determine_batch_execution_and_padding(
|
||||||
num_tokens=num_tokens_unpadded,
|
num_tokens=num_tokens_unpadded,
|
||||||
num_reqs=num_reqs,
|
num_reqs=num_reqs,
|
||||||
|
|||||||
@@ -564,7 +564,7 @@ class Worker(WorkerBase):
|
|||||||
# TODO(lucas): This is pretty gross; ideally we should only ever call
|
# TODO(lucas): This is pretty gross; ideally we should only ever call
|
||||||
# `_determine_batch_execution_and_padding` once (will get called again
|
# `_determine_batch_execution_and_padding` once (will get called again
|
||||||
# in `execute_model`) but this requires a larger refactor of PP.
|
# in `execute_model`) but this requires a larger refactor of PP.
|
||||||
_, batch_desc, _, _ = (
|
_, batch_desc, _, _, _ = (
|
||||||
self.model_runner._determine_batch_execution_and_padding(
|
self.model_runner._determine_batch_execution_and_padding(
|
||||||
num_tokens=num_scheduled_tokens,
|
num_tokens=num_scheduled_tokens,
|
||||||
num_reqs=len(num_scheduled_tokens_np),
|
num_reqs=len(num_scheduled_tokens_np),
|
||||||
|
|||||||
Reference in New Issue
Block a user