diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 2fd2ae08a..5ab8496fa 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -452,9 +452,10 @@ th {
 | `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ |
 | `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ |
 | `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ |
+| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ |
 | `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | |
 | `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ |
-| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ |
+| `Step1ForCausalLM` | Step-Audio | `stepfun-ai/Step-Audio-EditX`, etc. | ✅︎ | ✅︎ |
 | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
 | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
 | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 0b7d50725..aa0c3dd0b 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -472,6 +472,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
         "ByteDance-Seed/Seed-OSS-36B-Instruct",
         trust_remote_code=True,
     ),
+    "Step1ForCausalLM": _HfExamplesInfo(
+        "stepfun-ai/Step-Audio-EditX", trust_remote_code=True
+    ),
     "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"),
     "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 3efa504c7..0e5272d50 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -115,8 +115,11 @@ def can_initialize(
         # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
         # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
         # L4 supports FA3.
+        # Step1ForCausalLM also uses TRITON_ATTN, but because it needs use_alibi_sqrt.
         attention_config = (
-            {"backend": "TRITON_ATTN"} if model_arch == "GptOssForCausalLM" else None
+            {"backend": "TRITON_ATTN"}
+            if model_arch in ("GptOssForCausalLM", "Step1ForCausalLM")
+            else None
         )
         if model_arch == "WhisperForConditionalGeneration":
             m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 92795188c..8087dc708 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -162,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
         scale: float,
         num_kv_heads: int | None = None,
         alibi_slopes: list[float] | None = None,
+        use_alibi_sqrt: bool | None = None,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
         logits_soft_cap: float | None = None,
@@ -243,7 +244,20 @@ class Attention(nn.Module, AttentionLayerBase):
             )
         else:
             self.attn_backend = attn_backend
-
+        # Square-root ALiBi is opt-in (None means off). Fail fast when the selected
+        # backend does not implement it.
+        self.use_alibi_sqrt = bool(use_alibi_sqrt)
+        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
+        if self.use_alibi_sqrt and not backend_supports_alibi_sqrt:
+            raise ValueError(
+                "use_alibi_sqrt is not supported by backend "
+                f"{self.attn_backend.get_name()}."
+            )
+        # Only pass the flag to impls whose backend declares support; other impls
+        # are not guaranteed to accept the keyword.
+        if backend_supports_alibi_sqrt:
+            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
+
         # prefix caching + batch invariance is currently not supported for
         # FLASHINFER and TRITON_MLA.
         if (
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index ab80feea1..fe8c6cf8a 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -185,6 +185,7 @@ _TEXT_GENERATION_MODELS = {
     "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
+    "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
     "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
diff --git a/vllm/model_executor/models/step1.py b/vllm/model_executor/models/step1.py
new file mode 100644
index 000000000..8e655c691
--- /dev/null
+++ b/vllm/model_executor/models/step1.py
@@ -0,0 +1,433 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Shared Step decoder blocks and the Step1 text model."""
+
+from __future__ import annotations
+
+import math
+from collections.abc import Iterable
+
+import torch
+from torch import nn
+
+from vllm.attention.layer import Attention, AttentionType
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.utils import (
+    AutoWeightsLoader,
+    PPMissingLayer,
+    is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory,
+    make_layers,
+    maybe_prefix,
+)
+from vllm.sequence import IntermediateTensors
+
+STEP_PACKED_MODULES_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"],
+}
+
+
+def _get_step_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    """Reference ALiBi slopes used by Step models.
+
+    Returns a float32 tensor with one slope per head: powers of one base for the
+    largest power-of-two head count, then odd powers of a second base for any
+    remaining heads.
+    """
+    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2 ** (-8.0 / closest_power_of_2),
+        dtype=torch.float32,
+    )
+    slopes = torch.pow(
+        base,
+        torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32),
+    )
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2 ** (-4.0 / closest_power_of_2),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = total_num_heads - closest_power_of_2
+        extra_powers = torch.arange(
+            1,
+            1 + 2 * num_remaining_heads,
+            2,
+            dtype=torch.int32,
+        )
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)],
+            dim=0,
+        )
+    return slopes
+
+
+class StepAttention(nn.Module):
+    """Self-attention with per-head ALiBi slopes and the sqrt-distance bias."""
+
+    def __init__(
+        self,
+        config,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+
+        total_num_kv_heads = getattr(
+            config, "num_attention_groups", getattr(config, "num_key_value_heads", 1)
+        )
+        if total_num_kv_heads is None or total_num_kv_heads <= 0:
+            total_num_kv_heads = 1
+        self.total_num_kv_heads = total_num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=self.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=getattr(config, "attention_bias", False),
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=self.hidden_size,
+            bias=getattr(config, "attention_bias", False),
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        # Each TP rank only needs the slopes for its own contiguous slice of heads.
+        tp_rank = get_tensor_model_parallel_rank()
+        head_start = tp_rank * self.num_heads
+        head_end = (tp_rank + 1) * self.num_heads
+        alibi_slopes = _get_step_alibi_slopes(self.total_num_heads)[head_start:head_end]
+        alibi_slopes = alibi_slopes.tolist()
+
+        self.scale = self.head_dim**-0.5
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            alibi_slopes=alibi_slopes,
+            prefix=f"{prefix}.attn",
+            use_alibi_sqrt=True,
+            attn_type=AttentionType.DECODER,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class StepMLP(nn.Module):
+    """Gated MLP: fused gate/up projection, SiluAndMul, then down projection."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        bias: bool = False,
+    ):
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[intermediate_size, intermediate_size],
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class StepDecoderLayer(nn.Module):
+    """One decoder layer: RMSNorm -> attention -> RMSNorm -> MLP with residuals."""
+
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
+        self.hidden_size = config.hidden_size
+        self.self_attn = StepAttention(
+            config=config,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = StepMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+            bias=getattr(config, "mlp_bias", False),
+        )
+        self.input_layernorm = RMSNorm(
+            self.hidden_size,
+            eps=config.rms_norm_eps,
+        )
+        self.post_attention_layernorm = RMSNorm(
+            self.hidden_size,
+            eps=config.rms_norm_eps,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load this layer's weights, packing q/k/v and gate/up into fused params."""
+        stacked_params_mapping = [
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Not a fused shard: load the tensor under its own name.
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class StepDecoderModel(nn.Module):
+    """Embedding, pipeline-parallel stack of StepDecoderLayer, and final RMSNorm."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
+        self.config = config
+        self.quant_config = quant_config
+        # Need embed_tokens on first rank, and also on last rank if tie_word_embeddings
+        if get_pp_group().is_first_rank or (
+            getattr(config, "tie_word_embeddings", True) and get_pp_group().is_last_rank
+        ):
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: StepDecoderLayer(vllm_config=vllm_config, prefix=prefix),
+            prefix=maybe_prefix(prefix, "layers"),
+        )
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+
+        self.aux_hidden_state_layers: tuple[int, ...] = getattr(
+            config, "aux_hidden_state_layers", ()
+        )
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"],
+            config.hidden_size,
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor | None,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                assert input_ids is not None
+                hidden_states = self.embed_input_ids(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
+            if idx in self.aux_hidden_state_layers:
+                if residual is None:
+                    aux_hidden_states.append(hidden_states)
+                else:
+                    aux_hidden_states.append(hidden_states + residual)
+            hidden_states, residual = layer(positions, hidden_states, residual)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        if aux_hidden_states:
+            return hidden_states, aux_hidden_states
+        return hidden_states
+
+
+class Step1ForCausalLM(nn.Module, SupportsPP):
+    """Step1 causal LM: StepDecoderModel plus an (optionally tied) LM head."""
+
+    packed_modules_mapping = STEP_PACKED_MODULES_MAPPING
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
+        self.config = config
+        self.quant_config = quant_config
+        self.model = StepDecoderModel(
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "model"),
+        )
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+            if getattr(config, "tie_word_embeddings", True):
+                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
+            self.logits_processor = LogitsProcessor(config.vocab_size)
+        else:
+            self.lm_head = PPMissingLayer()
+            self.logits_processor = None  # type: ignore[assignment]
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor | None,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
+        return self.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        if not get_pp_group().is_last_rank:
+            return None
+        return self.logits_processor(self.lm_head, hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index 6c6bb808b..5ea8f0e62 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -172,6 +172,11 @@ class AttentionBackend(ABC):
     def supports_sink(cls) -> bool:
         return False
 
+    @classmethod
+    def supports_alibi_sqrt(cls) -> bool:
+        """Whether the backend implements the square-root ALiBi bias variant."""
+        return False
+
     @classmethod
     def supports_mm_prefix(cls) -> bool:
         return False
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 4cc438d9f..06cb17211 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -331,6 +331,10 @@ class TritonAttentionBackend(AttentionBackend):
             AttentionType.ENCODER_DECODER,
         )
 
+    @classmethod
+    def supports_alibi_sqrt(cls) -> bool:
+        return True
+
     @classmethod
     def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
         return True
@@ -353,6 +357,7 @@ class TritonAttentionImpl(AttentionImpl):
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: int | None = None,
         sinks: torch.Tensor | None = None,
+        use_alibi_sqrt: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -386,7 +391,7 @@ class TritonAttentionImpl(AttentionImpl):
                 f"heads in the layer. Sinks shape: {sinks.shape}, "
                 f"num_heads: {num_heads}."
             )
-
+        self.use_alibi_sqrt = use_alibi_sqrt
         self.supports_quant_query_input = current_platform.is_cuda()
 
     def forward(
@@ -513,6 +518,7 @@ class TritonAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
+                use_alibi_sqrt=self.use_alibi_sqrt,
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
diff --git a/vllm/v1/attention/ops/triton_unified_attention.py b/vllm/v1/attention/ops/triton_unified_attention.py
index 345889969..6855233ee 100644
--- a/vllm/v1/attention/ops/triton_unified_attention.py
+++ b/vllm/v1/attention/ops/triton_unified_attention.py
@@ -82,6 +82,7 @@ def kernel_unified_attention_2d(
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_ALIBI_SQRT: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
@@ -325,7 +326,18 @@ def kernel_unified_attention_2d(
             )
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            if USE_ALIBI_SQRT:
+                # sqrt is not shift-invariant per query row, so (unlike the linear
+                # branch) the offset must include the per-row query_pos term.
+                relative_pos = seq_offset - (context_len + query_pos[:, None])
+                alibi_offset = tl.where(
+                    relative_pos <= 0,
+                    -tl.sqrt((-relative_pos).to(tl.float32)),
+                    0.0,
+                )
+            else:
+                alibi_offset = seq_offset - context_len
+            S += alibi_slope[:, None] * alibi_offset
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
@@ -420,6 +432,7 @@ def kernel_unified_attention_3d(
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_ALIBI_SQRT: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
@@ -669,7 +682,18 @@ def kernel_unified_attention_3d(
             )
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            if USE_ALIBI_SQRT:
+                # sqrt is not shift-invariant per query row, so (unlike the linear
+                # branch) the offset must include the per-row query_pos term.
+                relative_pos = seq_offset - (context_len + query_pos[:, None])
+                alibi_offset = tl.where(
+                    relative_pos <= 0,
+                    -tl.sqrt((-relative_pos).to(tl.float32)),
+                    0.0,
+                )
+            else:
+                alibi_offset = seq_offset - context_len
+            S += alibi_slope[:, None] * alibi_offset
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
@@ -888,6 +912,7 @@ def unified_attention(
     sinks=None,
     # Optional tensor for prefix lengths (PrefixLM support)
     mm_prefix_range=None,
+    use_alibi_sqrt=False,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -994,6 +1019,7 @@ def unified_attention(
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
+            USE_ALIBI_SQRT=use_alibi_sqrt,
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
             USE_SINKS=(sinks is not None),
@@ -1045,6 +1071,7 @@ def unified_attention(
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
+            USE_ALIBI_SQRT=use_alibi_sqrt,
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
             USE_SINKS=(sinks is not None),