[Model] Support Step1 Model (#32511)
Signed-off-by: xieli <xieli@stepfun.com>
@@ -172,6 +172,10 @@ class AttentionBackend(ABC):
     def supports_sink(cls) -> bool:
         return False
 
+    @classmethod
+    def supports_alibi_sqrt(cls) -> bool:
+        return False
+
     @classmethod
     def supports_mm_prefix(cls) -> bool:
         return False
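The new supports_alibi_sqrt() classmethod follows the same capability-flag pattern as supports_sink() and supports_mm_prefix(): the abstract base defaults to False and each backend opts in explicitly. A minimal sketch of that pattern follows; the SqrtAlibiCapableBackend subclass and require_alibi_sqrt helper are hypothetical and only illustrate how a caller could gate on the flag, not vLLM's actual backend-selection logic.

# Minimal sketch of the capability-flag pattern added above. The subclass and
# the require_alibi_sqrt helper are hypothetical, for illustration only.
from abc import ABC


class AttentionBackend(ABC):
    @classmethod
    def supports_alibi_sqrt(cls) -> bool:
        # Backends must opt in explicitly; the base class defaults to False.
        return False


class SqrtAlibiCapableBackend(AttentionBackend):
    @classmethod
    def supports_alibi_sqrt(cls) -> bool:
        return True


def require_alibi_sqrt(backend_cls: type[AttentionBackend]) -> None:
    # Hypothetical check a model needing sqrt-decay ALiBi could perform.
    if not backend_cls.supports_alibi_sqrt():
        raise ValueError(f"{backend_cls.__name__} does not support ALiBi-sqrt")


require_alibi_sqrt(SqrtAlibiCapableBackend)  # passes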
@@ -331,6 +331,10 @@ class TritonAttentionBackend(AttentionBackend):
         AttentionType.ENCODER_DECODER,
     )
 
+    @classmethod
+    def supports_alibi_sqrt(cls) -> bool:
+        return True
+
     @classmethod
     def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
         return True
@@ -353,6 +357,7 @@ class TritonAttentionImpl(AttentionImpl):
         attn_type: AttentionType = AttentionType.DECODER,
         kv_sharing_target_layer_name: int | None = None,
         sinks: torch.Tensor | None = None,
+        use_alibi_sqrt: bool = False,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -386,7 +391,7 @@ class TritonAttentionImpl(AttentionImpl):
                 f"heads in the layer. Sinks shape: {sinks.shape}, "
                 f"num_heads: {num_heads}."
             )
 
+        self.use_alibi_sqrt = use_alibi_sqrt
         self.supports_quant_query_input = current_platform.is_cuda()
 
     def forward(
@@ -513,6 +518,7 @@ class TritonAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
+                use_alibi_sqrt=self.use_alibi_sqrt,
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
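Taken together, the impl-side hunks above just thread the flag through: TritonAttentionImpl stores use_alibi_sqrt at construction and forwards it unchanged into the unified_attention() call in forward(). A reduced sketch of that plumbing, with everything else elided:

# Reduced sketch of the plumbing shown in the hunks above; only the
# use_alibi_sqrt flag is kept, all other state and arguments are elided.
class TritonAttentionImplSketch:
    def __init__(self, use_alibi_sqrt: bool = False) -> None:
        # Stored once at construction (alongside alibi_slopes, sliding_window, ...).
        self.use_alibi_sqrt = use_alibi_sqrt

    def forward(self, attention_fn, **kwargs):
        # Forwarded unchanged into the unified attention kernel wrapper.
        return attention_fn(use_alibi_sqrt=self.use_alibi_sqrt, **kwargs)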
@@ -82,6 +82,7 @@ def kernel_unified_attention_2d(
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_ALIBI_SQRT: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
@@ -325,7 +326,16 @@ def kernel_unified_attention_2d(
             )
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            if USE_ALIBI_SQRT:
+                relative_pos = seq_offset - (context_len + query_pos[:, None])
+                alibi_offset = tl.where(
+                    relative_pos <= 0,
+                    -tl.sqrt((-relative_pos).to(tl.float32)),
+                    0.0,
+                )
+            else:
+                alibi_offset = seq_offset - context_len
+            S += alibi_slope[:, None] * alibi_offset
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
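The USE_ALIBI_SQRT branch replaces ALiBi's linear penalty with a square-root decay. The pre-existing linear path can use a key-position-only offset (seq_offset - context_len) because the per-query constant separating it from the true distance is cancelled by the softmax; the sqrt variant is nonlinear, so the new branch computes the actual query-key distance via relative_pos and zeroes the bias for keys after the query. Below is a standalone sketch of the same bias in plain PyTorch; the toy sizes and scalar slope are assumptions, and the kernel's tile/block layout is ignored.

# Standalone sketch of the bias the USE_ALIBI_SQRT branch computes, next to the
# linear ALiBi offset it replaces. Toy sizes; not the kernel's tile layout.
import torch

slope = 0.5          # per-head ALiBi slope
context_len = 4      # tokens already in the KV cache for this sequence
num_queries = 2      # query tokens processed in this step

key_pos = torch.arange(context_len + num_queries)        # absolute key positions
query_pos = torch.arange(num_queries).unsqueeze(1)       # position within this step

# Linear ALiBi (pre-existing branch): key-position-only offset; the per-row
# constant separating it from -distance is absorbed by the softmax.
linear_bias = slope * (key_pos - context_len)

# Sqrt ALiBi (new branch): needs the true query-key distance, so query_pos
# enters the formula; keys after the query get a zero bias.
relative_pos = key_pos - (context_len + query_pos)        # <= 0 for visible keys
sqrt_bias = slope * torch.where(
    relative_pos <= 0,
    -torch.sqrt((-relative_pos).clamp(min=0).float()),
    torch.zeros_like(relative_pos, dtype=torch.float32),
)
print(sqrt_bias[0])  # [-0.5*sqrt(4), -0.5*sqrt(3), -0.5*sqrt(2), -0.5, 0.0, 0.0]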
@@ -420,6 +430,7 @@ def kernel_unified_attention_3d(
     HEAD_SIZE: tl.constexpr,  # int
     HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
     USE_ALIBI_SLOPES: tl.constexpr,  # bool
+    USE_ALIBI_SQRT: tl.constexpr,  # bool
     USE_QQ_BIAS: tl.constexpr,  # bool
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
@@ -669,7 +680,16 @@ def kernel_unified_attention_3d(
             )
 
         if USE_ALIBI_SLOPES:
-            S += alibi_slope[:, None] * (seq_offset - context_len)
+            if USE_ALIBI_SQRT:
+                relative_pos = seq_offset - (context_len + query_pos[:, None])
+                alibi_offset = tl.where(
+                    relative_pos <= 0,
+                    -tl.sqrt((-relative_pos).to(tl.float32)),
+                    0.0,
+                )
+            else:
+                alibi_offset = seq_offset - context_len
+            S += alibi_slope[:, None] * alibi_offset
 
         if USE_QQ_BIAS:
             # compute key positions relative to query section
@@ -888,6 +908,7 @@ def unified_attention(
     sinks=None,
     # Optional tensor for prefix lengths (PrefixLM support)
     mm_prefix_range=None,
+    use_alibi_sqrt=False,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
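With use_alibi_sqrt exposed on the unified_attention() wrapper, one way to sanity-check the fused path is against a dense reference that applies the same sqrt-decay bias as an additive attention mask. The sketch below is such a reference under assumed toy shapes; build_sqrt_alibi_bias is a hypothetical helper written for this illustration, not part of this change.

# Hedged dense reference for the sqrt-decay ALiBi bias, usable as a sanity
# check against the fused kernel. Shapes are toy values and the
# build_sqrt_alibi_bias helper is hypothetical (not part of this change).
import torch
import torch.nn.functional as F


def build_sqrt_alibi_bias(slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # slopes: (num_heads,) -> bias: (num_heads, seq_len, seq_len)
    pos = torch.arange(seq_len)
    distance = (pos.unsqueeze(1) - pos.unsqueeze(0)).clamp(min=0)  # i - j for j <= i
    return -slopes.view(-1, 1, 1) * distance.float().sqrt()


num_heads, seq_len, head_dim = 4, 8, 16
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(num_heads)])
q, k, v = (torch.randn(1, num_heads, seq_len, head_dim) for _ in range(3))

bias = build_sqrt_alibi_bias(slopes, seq_len).unsqueeze(0)       # (1, H, S, S)
causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
bias = bias.masked_fill(~causal, float("-inf"))                  # causal=True in the kernel

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)    # (1, H, S, head_dim)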
@@ -994,6 +1015,7 @@ def unified_attention(
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
+           USE_ALIBI_SQRT=use_alibi_sqrt,
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
            USE_SINKS=(sinks is not None),
@@ -1045,6 +1067,7 @@ def unified_attention(
            HEAD_SIZE=head_size,
            HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
            USE_ALIBI_SLOPES=use_alibi_slopes,
+           USE_ALIBI_SQRT=use_alibi_sqrt,
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
            USE_SINKS=(sinks is not None),