[Model] Support Step1 Model (#32511)
Signed-off-by: xieli <xieli@stepfun.com>
This commit is contained in:
@@ -331,6 +331,10 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def supports_alibi_sqrt(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
@@ -353,6 +357,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: int | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
use_alibi_sqrt: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@@ -386,7 +391,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||
f"num_heads: {num_heads}."
|
||||
)
|
||||
|
||||
self.use_alibi_sqrt = use_alibi_sqrt
|
||||
self.supports_quant_query_input = current_platform.is_cuda()
|
||||
|
||||
def forward(
|
||||
@@ -513,6 +518,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
use_alibi_sqrt=self.use_alibi_sqrt,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
|
||||
Reference in New Issue
Block a user