[Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix (#21588)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -25,6 +25,7 @@ from torch import nn
 from transformers import Llama4TextConfig
 
 from vllm.attention import Attention
+from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -194,17 +195,18 @@ class Llama4Attention(nn.Module):
             is_neox_style=is_neox_style,
         ) if not self.nope else None
 
-        self.attn = Attention(
+        attn_cls = Attention if self.nope else ChunkedLocalAttention
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
-            per_layer_sliding_window=None,
-            use_irope=not self.nope,
             prefix=f"{prefix}.attn",
-        )
+            **({
+                "attention_chunk_size": config.attention_chunk_size
+            } if not self.nope else {}))
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
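For context on what attention_chunk_size controls here: ChunkedLocalAttention restricts each token to attending within its own fixed-size chunk rather than the full causal prefix. Below is a minimal sketch (not vLLM code) of the mask this implies, assuming attention is causal and never crosses a chunk boundary:

import torch

def chunked_causal_mask(seq_len: int, chunk_size: int) -> torch.Tensor:
    # True where attention is allowed: query i may attend to key j iff
    # j <= i (causal) and both positions fall in the same chunk.
    pos = torch.arange(seq_len)
    causal = pos[None, :] <= pos[:, None]
    same_chunk = (pos[None, :] // chunk_size) == (pos[:, None] // chunk_size)
    return causal & same_chunk

# With chunk_size=4, token 5 attends only to tokens 4 and 5,
# since token 4 starts a fresh chunk:
print(chunked_causal_mask(8, 4).int())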