[V1] Add flag to disable cascade attention (#15243)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -127,6 +127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
+        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

         # Multi-modal data support
         self.input_registry = INPUT_REGISTRY
@@ -565,11 +566,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)

-        # Prepare for cascade attention if needed.
-        common_prefix_len = self._compute_cascade_attn_prefix_len(
-            num_scheduled_tokens,
-            scheduler_output.num_common_prefix_blocks,
-        )
+        # Prepare for cascade attention if enabled & beneficial.
+        common_prefix_len = 0
+        if self.cascade_attn_enabled:
+            common_prefix_len = self._compute_cascade_attn_prefix_len(
+                num_scheduled_tokens,
+                scheduler_output.num_common_prefix_blocks,
+            )
+
         attn_metadata = self.attn_metadata_builder.build(
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
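For context: cascade attention computes attention over a prompt prefix shared by a batch of requests once, instead of once per request, and a common_prefix_len of 0 routes every request through the regular attention path. The standalone sketch below replays the guard added in the second hunk; Config, compute_prefix_len, and prefix_len_for_step are illustrative stand-ins for the runner's config and _compute_cascade_attn_prefix_len, not vLLM APIs, and the 16-token block size is an assumption.

from dataclasses import dataclass


@dataclass
class Config:
    # Mirrors model_config.disable_cascade_attn from the diff above;
    # everything else in this sketch is hypothetical, not vLLM code.
    disable_cascade_attn: bool = False


def compute_prefix_len(num_common_prefix_blocks: int,
                       block_size: int = 16) -> int:
    # Placeholder for _compute_cascade_attn_prefix_len: the real method
    # also decides whether cascade attention would actually benefit the
    # current batch; here we only convert KV-cache blocks to tokens.
    return num_common_prefix_blocks * block_size


def prefix_len_for_step(cfg: Config, num_common_prefix_blocks: int) -> int:
    # The guard from the diff: returning 0 means the attention metadata
    # builder never takes the cascade path for this step.
    if cfg.disable_cascade_attn:
        return 0
    return compute_prefix_len(num_common_prefix_blocks)


print(prefix_len_for_step(Config(disable_cascade_attn=True), 4))  # 0
print(prefix_len_for_step(Config(), 4))  # 64, given the assumed block size

Caching the negated flag as self.cascade_attn_enabled in __init__ (first hunk) keeps the per-step hot path to a single attribute check rather than a config lookup on every forward pass.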