[CI] fix mamba kernel test (#26250)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu
2025-10-06 02:26:59 +08:00
committed by GitHub
parent 512b8affa4
commit 9c3c21c519
2 changed files with 12 additions and 1 deletions

View File

@@ -165,7 +165,17 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device)
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices,
)
out_ref = causal_conv1d_update_ref(
x_ref, conv_state_ref, weight, bias, activation=activation
)