[V1] [Hybrid] Validate compatibility of attention backend batch reordering at init time (#21557)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, ClassVar, Optional, Union
|
from typing import ClassVar, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||||
@@ -21,17 +21,17 @@ from vllm.logger import init_logger
|
|||||||
from vllm.utils import cdiv, is_pin_memory_available
|
from vllm.utils import cdiv, is_pin_memory_available
|
||||||
from vllm.utils.flashinfer import use_trtllm_decode_attention
|
from vllm.utils.flashinfer import use_trtllm_decode_attention
|
||||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||||
from vllm.v1.attention.backends.utils import (
|
# yapf conflicts with isort for this block
|
||||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
# yapf: disable
|
||||||
get_kv_cache_layout, get_per_layer_parameters,
|
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||||
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
|
AttentionMetadataBuilder,
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
get_kv_cache_layout,
|
||||||
|
get_per_layer_parameters,
|
||||||
|
infer_global_hyperparameters,
|
||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
|
||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
||||||
|
|
||||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -179,6 +179,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||||
AttentionCGSupport.PURE_DECODE_ONLY
|
AttentionCGSupport.PURE_DECODE_ONLY
|
||||||
|
|
||||||
|
reorder_batch_threshold: ClassVar[int] = 1
|
||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -239,12 +241,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: InputBatch,
|
|
||||||
scheduler_output: SchedulerOutput) -> bool:
|
|
||||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
|
||||||
scheduler_output,
|
|
||||||
decode_threshold=1)
|
|
||||||
|
|
||||||
def _get_workspace_buffer(self):
|
def _get_workspace_buffer(self):
|
||||||
if self._workspace_buffer is None:
|
if self._workspace_buffer is None:
|
||||||
self._workspace_buffer = torch.empty(
|
self._workspace_buffer = torch.empty(
|
||||||
|
|||||||
@@ -2,21 +2,17 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
|
||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
||||||
|
|
||||||
|
|
||||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
@@ -87,6 +83,8 @@ class Mamba2AttentionMetadata:
|
|||||||
class Mamba2AttentionMetadataBuilder(
|
class Mamba2AttentionMetadataBuilder(
|
||||||
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
|
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
|
||||||
|
|
||||||
|
reorder_batch_threshold: ClassVar[int] = 1
|
||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
vllm_config: VllmConfig, device: torch.device):
|
||||||
assert isinstance(kv_cache_spec, MambaSpec)
|
assert isinstance(kv_cache_spec, MambaSpec)
|
||||||
@@ -95,12 +93,6 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
assert self.chunk_size is not None, (
|
assert self.chunk_size is not None, (
|
||||||
"chunk_size needs to be set in the model config for Mamba2 models")
|
"chunk_size needs to be set in the model config for Mamba2 models")
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "InputBatch",
|
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
|
||||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
|
||||||
scheduler_output,
|
|
||||||
decode_threshold=1)
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ return curr_o @ W_O
|
|||||||
import functools
|
import functools
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
|
from typing import ClassVar, Generic, Optional, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -210,10 +210,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import cdiv, round_down
|
from vllm.utils import cdiv, round_down
|
||||||
from vllm.utils.flashinfer import has_nvidia_artifactory
|
from vllm.utils.flashinfer import has_nvidia_artifactory
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
get_per_layer_parameters, infer_global_hyperparameters,
|
get_per_layer_parameters,
|
||||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
infer_global_hyperparameters,
|
||||||
|
split_decodes_and_prefills)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -233,10 +234,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
flashinfer_available = False
|
flashinfer_available = False
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
|
||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
CUDNN_WORKSPACE_SIZE = 12800
|
CUDNN_WORKSPACE_SIZE = 12800
|
||||||
@@ -403,6 +400,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
NOTE: Please read the comment at the top of the file before trying to
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
understand this class
|
understand this class
|
||||||
"""
|
"""
|
||||||
|
reorder_batch_threshold: ClassVar[int] = 1
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
kv_cache_spec: AttentionSpec,
|
kv_cache_spec: AttentionSpec,
|
||||||
@@ -559,12 +557,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
prefill.prefill_main = self._fi_prefill_main
|
prefill.prefill_main = self._fi_prefill_main
|
||||||
prefill.prefill_chunks = self._fi_prefill_chunks
|
prefill.prefill_chunks = self._fi_prefill_chunks
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "InputBatch",
|
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
|
||||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
|
||||||
scheduler_output,
|
|
||||||
decode_threshold=1)
|
|
||||||
|
|
||||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||||
seq_lens: torch.Tensor):
|
seq_lens: torch.Tensor):
|
||||||
return MLACommonDecodeMetadata(
|
return MLACommonDecodeMetadata(
|
||||||
|
|||||||
@@ -251,9 +251,6 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||||
self.total_tokens: int = 0
|
self.total_tokens: int = 0
|
||||||
|
|
||||||
def reorder_batch(self, input_batch, scheduler_output) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
def build_for_cudagraph_capture(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata):
|
self, common_attn_metadata: CommonAttentionMetadata):
|
||||||
self.total_tokens = self.model_config.max_model_len \
|
self.total_tokens = self.model_config.max_model_len \
|
||||||
|
|||||||
@@ -167,6 +167,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
# Does this backend/builder support CUDA Graphs for attention.
|
# Does this backend/builder support CUDA Graphs for attention.
|
||||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||||
AttentionCGSupport.NEVER
|
AttentionCGSupport.NEVER
|
||||||
|
# Does this backend/builder reorder the batch?
|
||||||
|
# If not, set this to None. Otherwise set it to the query
|
||||||
|
# length that will be pulled into the front of the batch.
|
||||||
|
reorder_batch_threshold: ClassVar[Optional[int]] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||||
@@ -221,14 +225,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "InputBatch",
|
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
|
||||||
"""
|
|
||||||
This method can reorder the batch if desired by the backend.
|
|
||||||
:return: Has the batch been reordered (default False).
|
|
||||||
"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def get_kv_cache_layout():
|
def get_kv_cache_layout():
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -9,8 +9,12 @@ import torch.nn as nn
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
|
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
|
||||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -27,6 +31,34 @@ class CPUModelRunner(GPUModelRunner):
|
|||||||
|
|
||||||
self._postprocess_tenosrs()
|
self._postprocess_tenosrs()
|
||||||
|
|
||||||
|
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
    """
    Let the CPU attention backend reorder the requests in the batch.

    Returns immediately when there are no KV cache groups. Otherwise the
    call is delegated to the ``reorder_batch`` method of the first (and
    only supported) attention metadata builder; its return value is
    ignored.

    Args:
        scheduler_output: The scheduler output.

    Raises:
        ValueError: If more than one KV cache group is configured.
    """
    # Attention free models have zero kv_cache_groups, however models
    # like Mamba are also attention free but use the kv_cache for
    # keeping its internal state. This is why we check the number
    # of kv_cache groups instead of solely checking
    # for self.model_config.is_attention_free.
    num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
    if num_kv_cache_groups == 0:
        return

    if num_kv_cache_groups > 1:
        # Adjacent string literals are concatenated verbatim, so the
        # first piece must end with a space to keep the words apart.
        raise ValueError("Multiple KVCacheGroups is not "
                         "currently supported with CPU model runner.")

    # Exact type check: subclasses of TorchSDPAMetadataBuilderV1 are
    # rejected as well.
    assert type(
        self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1

    self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                 scheduler_output)
|
||||||
|
|
||||||
def _postprocess_tenosrs(self) -> None:
|
def _postprocess_tenosrs(self) -> None:
|
||||||
# Note: replace device tensors with cpu tensors
|
# Note: replace device tensors with cpu tensors
|
||||||
def replace_tensor(obj: Any, cpu_attr_name: str,
|
def replace_tensor(obj: Any, cpu_attr_name: str,
|
||||||
|
|||||||
@@ -49,7 +49,8 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
|
|||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||||
make_kv_sharing_fast_prefill_attention_metadata,
|
make_kv_sharing_fast_prefill_attention_metadata,
|
||||||
make_local_attention_virtual_batches)
|
make_local_attention_virtual_batches,
|
||||||
|
reorder_batch_to_split_decodes_and_prefills)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||||
ChunkedLocalAttentionSpec,
|
ChunkedLocalAttentionSpec,
|
||||||
@@ -329,6 +330,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
|
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
|
||||||
self.max_num_tokens, dtype=torch.int32, device=self.device)
|
self.max_num_tokens, dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
|
self.reorder_batch_threshold: Optional[int] = None
|
||||||
|
|
||||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""
|
"""
|
||||||
Update the order of requests in the batch based on the attention
|
Update the order of requests in the batch based on the attention
|
||||||
@@ -347,20 +350,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
|
if self.reorder_batch_threshold is not None:
|
||||||
scheduler_output)
|
reorder_batch_to_split_decodes_and_prefills(
|
||||||
|
self.input_batch,
|
||||||
# For models with multiple KV cache groups, the groups should agree on
|
scheduler_output,
|
||||||
# the same order of requests. We ensure this by only allowing the first
|
decode_threshold=self.reorder_batch_threshold)
|
||||||
# group to reorder the batch and asserting that all other groups do not
|
|
||||||
# reorder the batch.
|
|
||||||
# TODO(tdoublep): make this more flexible so that any group can
|
|
||||||
# re-order the batch (not only the first).
|
|
||||||
# TODO(tdoublep): verify this during engine init instead of at runtime
|
|
||||||
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
|
|
||||||
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
|
|
||||||
self.input_batch, scheduler_output)
|
|
||||||
assert not batch_reordered
|
|
||||||
|
|
||||||
# Note: used for model runner override.
|
# Note: used for model runner override.
|
||||||
def _init_device_properties(self) -> None:
|
def _init_device_properties(self) -> None:
|
||||||
@@ -2654,6 +2648,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.attn_backends.append(attn_backend_i)
|
self.attn_backends.append(attn_backend_i)
|
||||||
self.attn_metadata_builders.append(attn_metadata_builder_i)
|
self.attn_metadata_builders.append(attn_metadata_builder_i)
|
||||||
|
|
||||||
|
# Calculate reorder batch threshold (if needed)
|
||||||
|
self.calculate_reorder_batch_threshold()
|
||||||
|
|
||||||
if len(self.attn_backends) > 0:
|
if len(self.attn_backends) > 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -2688,6 +2685,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.attn_metadata_builders.append(attn_metadata_builder)
|
self.attn_metadata_builders.append(attn_metadata_builder)
|
||||||
self.is_encoder_only_model = True
|
self.is_encoder_only_model = True
|
||||||
|
|
||||||
|
def calculate_reorder_batch_threshold(self) -> None:
    """
    Derive the batch reordering threshold shared by all attention backends.

    Iterates over ``self.attn_metadata_builders``. Builders whose
    ``reorder_batch_threshold`` is None are ignored. The first non-None
    threshold is stored in ``self.reorder_batch_threshold`` when that
    attribute is still None; every later non-None threshold must equal
    the stored value. A value already present in
    ``self.reorder_batch_threshold`` before the call is kept and compared
    against, not reset.

    Raises:
        ValueError: If a builder's threshold differs from the stored one.
    """
    for attn_metadata_builder_i in self.attn_metadata_builders:
        reorder_batch_threshold_i = (
            attn_metadata_builder_i.reorder_batch_threshold)

        # This builder does not ask for any reordering.
        if reorder_batch_threshold_i is None:
            continue

        # First builder that reorders: adopt its threshold.
        if self.reorder_batch_threshold is None:
            self.reorder_batch_threshold = reorder_batch_threshold_i
            continue

        if reorder_batch_threshold_i != self.reorder_batch_threshold:
            raise ValueError(
                "Attention backend reorders decodes with "
                f"threshold {reorder_batch_threshold_i} but other "
                "backend uses threshold "
                f"{self.reorder_batch_threshold}")
|
||||||
|
|
||||||
def may_reinitialize_input_batch(self,
|
def may_reinitialize_input_batch(self,
|
||||||
kv_cache_config: KVCacheConfig) -> None:
|
kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user