[Model] Support Step1 Model (#32511)

Signed-off-by: xieli <xieli@stepfun.com>
Author: Li Xie
Date: 2026-01-18 18:20:46 +08:00
Committed by: GitHub
Parent: fe36bf5e80
Commit: c826c72a96
9 changed files with 472 additions and 6 deletions

View File

@@ -172,6 +172,10 @@ class AttentionBackend(ABC):
     def supports_sink(cls) -> bool:
         return False
 
+    @classmethod
+    def supports_alibi_sqrt(cls) -> bool:
+        return False
+
     @classmethod
     def supports_mm_prefix(cls) -> bool:
         return False
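
Note: a minimal, self-contained sketch of the capability-flag pattern this hunk introduces.
The classes below are stand-ins that only mirror the names in the diff, not vLLM's actual code:

    from abc import ABC

    class AttentionBackend(ABC):
        @classmethod
        def supports_alibi_sqrt(cls) -> bool:
            # Backends that cannot apply the sqrt-shaped ALiBi bias keep the default.
            return False

    class TritonAttentionBackend(AttentionBackend):
        @classmethod
        def supports_alibi_sqrt(cls) -> bool:
            # The Triton unified-attention kernel gains a USE_ALIBI_SQRT path below.
            return True

    # Hypothetical caller: only request the feature when the selected backend opts in.
    use_alibi_sqrt = TritonAttentionBackend.supports_alibi_sqrt()  # True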

View File

@@ -331,6 +331,10 @@ class TritonAttentionBackend(AttentionBackend):
         AttentionType.ENCODER_DECODER,
     )
 
+    @classmethod
+    def supports_alibi_sqrt(cls) -> bool:
+        return True
+
     @classmethod
     def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
         return True
@@ -353,6 +357,7 @@ class TritonAttentionImpl(AttentionImpl):
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: int | None = None,
         sinks: torch.Tensor | None = None,
+        use_alibi_sqrt: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -386,7 +391,7 @@ class TritonAttentionImpl(AttentionImpl):
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}."
)
self.use_alibi_sqrt = use_alibi_sqrt
self.supports_quant_query_input = current_platform.is_cuda()
@@ -513,6 +518,7 @@ def forward(
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
+                use_alibi_sqrt=self.use_alibi_sqrt,
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,

View File

@@ -82,6 +82,7 @@ def kernel_unified_attention_2d(
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_ALIBI_SQRT: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
@@ -325,7 +326,16 @@ def kernel_unified_attention_2d(
         )
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            if USE_ALIBI_SQRT:
+                relative_pos = seq_offset - (context_len + query_pos[:, None])
+                alibi_offset = tl.where(
+                    relative_pos <= 0,
+                    -tl.sqrt((-relative_pos).to(tl.float32)),
+                    0.0,
+                )
+            else:
+                alibi_offset = seq_offset - context_len
+            S += alibi_slope[:, None] * alibi_offset
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
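
Note: for reference, a small PyTorch sketch of what the new USE_ALIBI_SQRT branch adds to the
scores. Names mirror the kernel above; the helper itself is illustrative and not part of the PR:

    import torch

    def alibi_sqrt_offset(key_pos: torch.Tensor, query_pos_abs: torch.Tensor) -> torch.Tensor:
        # key_pos: absolute key positions (seq_offset above); query_pos_abs: context_len + query_pos.
        relative_pos = key_pos[None, :] - query_pos_abs[:, None]    # (num_queries, num_keys)
        return torch.where(
            relative_pos <= 0,
            -torch.sqrt((-relative_pos).float()),                   # key at or before the query
            torch.zeros_like(relative_pos, dtype=torch.float32),    # future key, masked elsewhere
        )

    # Example: a query at position 10 with keys at positions 10, 9, 6 and 1 gets offsets
    # 0, -1, -2 and -3, and the kernel adds alibi_slope * offset to S. The linear-ALiBi
    # branch instead keeps the original seq_offset - context_len term.
    offsets = alibi_sqrt_offset(torch.tensor([10, 9, 6, 1]), torch.tensor([10]))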
@@ -420,6 +430,7 @@ def kernel_unified_attention_3d(
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_ALIBI_SQRT: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
@@ -669,7 +680,16 @@ def kernel_unified_attention_3d(
         )
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            if USE_ALIBI_SQRT:
+                relative_pos = seq_offset - (context_len + query_pos[:, None])
+                alibi_offset = tl.where(
+                    relative_pos <= 0,
+                    -tl.sqrt((-relative_pos).to(tl.float32)),
+                    0.0,
+                )
+            else:
+                alibi_offset = seq_offset - context_len
+            S += alibi_slope[:, None] * alibi_offset
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
@@ -888,6 +908,7 @@ def unified_attention(
     sinks=None,
     # Optional tensor for prefix lengths (PrefixLM support)
     mm_prefix_range=None,
+    use_alibi_sqrt=False,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -994,6 +1015,7 @@ def unified_attention(
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
+            USE_ALIBI_SQRT=use_alibi_sqrt,
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
             USE_SINKS=(sinks is not None),
@@ -1045,6 +1067,7 @@ def unified_attention(
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
+            USE_ALIBI_SQRT=use_alibi_sqrt,
             USE_QQ_BIAS=use_qq_bias,
             USE_SOFTCAP=(softcap > 0),
             USE_SINKS=(sinks is not None),
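
Note: tying the pieces together, an end-to-end reference (single head, no paged KV cache, no
sliding window) of the attention the kernels compute when both USE_ALIBI_SLOPES and
USE_ALIBI_SQRT are set. This is an illustrative sketch, not the PR's code path:

    import torch

    def ref_attention_alibi_sqrt(q, k, v, alibi_slope: float, scale: float) -> torch.Tensor:
        # q: (Q, D); k, v: (K, D); the Q queries are assumed to be the last Q of the K positions.
        num_q, num_k = q.shape[0], k.shape[0]
        q_pos = torch.arange(num_k - num_q, num_k)             # absolute query positions
        k_pos = torch.arange(num_k)                            # absolute key positions
        relative_pos = k_pos[None, :] - q_pos[:, None]         # (Q, K)
        bias = alibi_slope * torch.where(
            relative_pos <= 0,
            -torch.sqrt((-relative_pos).float()),
            torch.zeros_like(relative_pos, dtype=torch.float32),
        )
        scores = scale * (q @ k.T) + bias
        scores = scores.masked_fill(relative_pos > 0, float("-inf"))   # causal mask
        return torch.softmax(scores, dim=-1) @ v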