[V1] Add flag to disable cascade attention (#15243)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-03-20 15:24:16 -07:00
Committed by: GitHub
Parent: d8e82bc06d
Commit: 2b22290ce0
3 changed files with 23 additions and 5 deletions


@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
+        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
 
         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
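
The flag read in the hunk above comes from the model config; the field itself is added in one of the other two changed files, which this view does not show. A minimal sketch of what that field plausibly looks like, assuming a dataclass-style config (only the name disable_cascade_attn is taken from this diff; the rest is illustrative):

from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Illustrative stand-in for vLLM's much larger ModelConfig.
    # When True, GPUModelRunner skips the cascade-attention prefix
    # computation entirely (see the second hunk below).
    disable_cascade_attn: bool = False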
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
 
-        # Prepare for cascade attention if needed.
-        common_prefix_len = self._compute_cascade_attn_prefix_len(
-            num_scheduled_tokens,
-            scheduler_output.num_common_prefix_blocks,
-        )
+        # Prepare for cascade attention if enabled & beneficial.
+        common_prefix_len = 0
+        if self.cascade_attn_enabled:
+            common_prefix_len = self._compute_cascade_attn_prefix_len(
+                num_scheduled_tokens,
+                scheduler_output.num_common_prefix_blocks,
+            )
         attn_metadata = self.attn_metadata_builder.build(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
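
With the flag plumbed through, disabling cascade attention from the offline API might look like the sketch below. This assumes disable_cascade_attn is exposed as an engine argument that LLM(...) forwards to the model config; only the field name is confirmed by this diff, and the model name is just an example.

from vllm import LLM, SamplingParams

# disable_cascade_attn=True is assumed to reach ModelConfig via EngineArgs.
llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)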