[Attention] Clean up iRoPE in V1 (#21188)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -137,6 +137,13 @@ class Attention(nn.Module):
|
|||||||
self.num_kv_heads = num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
self.sliding_window = sliding_window
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
# For v1 we have backend agnostic iRoPE (local chunked attention)
|
||||||
|
# we have to store the flag on the layer so gpu model runner can
|
||||||
|
# set KVSpec appropriately (and pop it so it doesn't get passed to
|
||||||
|
# the backends)
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
self.use_irope = extra_impl_args.pop("use_irope", False)
|
||||||
|
|
||||||
quant_method = quant_config.get_quant_method(
|
quant_method = quant_config.get_quant_method(
|
||||||
self, prefix=prefix) if quant_config else None
|
self, prefix=prefix) if quant_config else None
|
||||||
if quant_method is not None and not isinstance(
|
if quant_method is not None and not isinstance(
|
||||||
|
|||||||
@@ -446,17 +446,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: Optional[str] = None,
|
kv_sharing_target_layer_name: Optional[str] = None,
|
||||||
use_irope: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if kv_sharing_target_layer_name is not None:
|
if kv_sharing_target_layer_name is not None:
|
||||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||||
if logits_soft_cap is not None:
|
if logits_soft_cap is not None:
|
||||||
logger.warning_once("Torch SPDA does not support logits soft cap. "
|
logger.warning_once("Torch SPDA does not support logits soft cap. "
|
||||||
"Outputs may be slightly off.")
|
"Outputs may be slightly off.")
|
||||||
if use_irope:
|
|
||||||
logger.warning_once(
|
|
||||||
"Using irope in Torch SPDA is not supported yet, it will fall"
|
|
||||||
" back to global attention for long context.")
|
|
||||||
self.paged_attn_impl = _get_paged_attn_impl()
|
self.paged_attn_impl = _get_paged_attn_impl()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
|
|||||||
@@ -352,7 +352,6 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: Optional[str] = None,
|
kv_sharing_target_layer_name: Optional[str] = None,
|
||||||
use_irope: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@@ -381,7 +380,6 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
"encoder/decoder cross-attention "
|
"encoder/decoder cross-attention "
|
||||||
"are not implemented for "
|
"are not implemented for "
|
||||||
"FlashAttentionImpl")
|
"FlashAttentionImpl")
|
||||||
self.use_irope = use_irope
|
|
||||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||||
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
||||||
and not flash_attn_supports_fp8():
|
and not flash_attn_supports_fp8():
|
||||||
|
|||||||
@@ -493,7 +493,6 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: Optional[int] = None,
|
kv_sharing_target_layer_name: Optional[int] = None,
|
||||||
use_irope: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@@ -509,7 +508,6 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
self.use_irope = use_irope
|
|
||||||
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
|
|||||||
@@ -148,12 +148,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: Optional[int] = None,
|
kv_sharing_target_layer_name: Optional[int] = None,
|
||||||
use_irope: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if use_irope:
|
|
||||||
logger.warning_once(
|
|
||||||
"Using irope in Pallas is not supported yet, it will fall back "
|
|
||||||
"to global attention for long context.")
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
|
|||||||
@@ -337,7 +337,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: Optional[int] = None,
|
kv_sharing_target_layer_name: Optional[int] = None,
|
||||||
use_irope: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@@ -367,7 +366,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
"encoder/decoder cross-attention "
|
"encoder/decoder cross-attention "
|
||||||
"are not implemented for "
|
"are not implemented for "
|
||||||
"FlashAttentionImpl")
|
"FlashAttentionImpl")
|
||||||
self.use_irope = use_irope
|
|
||||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"AiterFlashAttention does not support fp8 kv-cache on this "
|
"AiterFlashAttention does not support fp8 kv-cache on this "
|
||||||
|
|||||||
@@ -72,9 +72,6 @@ class TritonAttentionMetadataBuilder(
|
|||||||
vllm_config.parallel_config)
|
vllm_config.parallel_config)
|
||||||
self.headdim = model_config.get_head_size()
|
self.headdim = model_config.get_head_size()
|
||||||
|
|
||||||
self.attention_chunk_size = getattr(vllm_config.scheduler_config,
|
|
||||||
'attention_chunk_size', None)
|
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
def build_for_cudagraph_capture(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata
|
self, common_attn_metadata: CommonAttentionMetadata
|
||||||
) -> TritonAttentionMetadata:
|
) -> TritonAttentionMetadata:
|
||||||
@@ -208,7 +205,6 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: Optional[int] = None,
|
kv_sharing_target_layer_name: Optional[int] = None,
|
||||||
use_irope: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@@ -228,8 +224,6 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
|
|
||||||
self.use_irope = use_irope
|
|
||||||
|
|
||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
TritonAttentionBackend.validate_head_size(head_size)
|
TritonAttentionBackend.validate_head_size(head_size)
|
||||||
|
|||||||
@@ -2702,8 +2702,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# TODO: Support other attention modules, e.g., cross-attention
|
# TODO: Support other attention modules, e.g., cross-attention
|
||||||
if attn_module.attn_type == AttentionType.DECODER:
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
use_local_attention = (self.attention_chunk_size is not None
|
use_local_attention = (self.attention_chunk_size is not None
|
||||||
and getattr(attn_module.impl,
|
and attn_module.use_irope)
|
||||||
"use_irope", False))
|
|
||||||
if attn_module.sliding_window is not None:
|
if attn_module.sliding_window is not None:
|
||||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
@@ -2716,13 +2715,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
"attention module can not be with ",
|
"attention module can not be with ",
|
||||||
"both local attention and sliding window")
|
"both local attention and sliding window")
|
||||||
elif use_local_attention:
|
elif use_local_attention:
|
||||||
kv_cache_spec[layer_name] = (ChunkedLocalAttentionSpec(
|
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
attention_chunk_size=self.attention_chunk_size,
|
attention_chunk_size=self.attention_chunk_size,
|
||||||
use_mla=use_mla))
|
use_mla=use_mla)
|
||||||
else:
|
else:
|
||||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
|
|||||||
@@ -519,6 +519,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if attn_module.attn_type == AttentionType.DECODER:
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
|
if attn_module.use_irope:
|
||||||
|
logger.warning_once(
|
||||||
|
"Using irope in Pallas is not supported yet, it "
|
||||||
|
"will fall back to global attention for long context.")
|
||||||
if attn_module.sliding_window is not None:
|
if attn_module.sliding_window is not None:
|
||||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user