[Model] Support Step1 Model (#32511)

Signed-off-by: xieli <xieli@stepfun.com>
This commit is contained in:
Li Xie
2026-01-18 18:20:46 +08:00
committed by GitHub
parent fe36bf5e80
commit c826c72a96
9 changed files with 472 additions and 6 deletions

View File

@@ -331,6 +331,10 @@ class TritonAttentionBackend(AttentionBackend):
AttentionType.ENCODER_DECODER,
)
@classmethod
def supports_alibi_sqrt(cls) -> bool:
return True
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
@@ -353,6 +357,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
sinks: torch.Tensor | None = None,
use_alibi_sqrt: bool = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -386,7 +391,7 @@ class TritonAttentionImpl(AttentionImpl):
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}."
)
self.use_alibi_sqrt = use_alibi_sqrt
self.supports_quant_query_input = current_platform.is_cuda()
def forward(
@@ -513,6 +518,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
use_alibi_sqrt=self.use_alibi_sqrt,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,