[CI] Bump mypy version (#34950)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -298,13 +298,13 @@ def test_selective_scan(
|
||||
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
|
||||
C_ref = C.clone()
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
|
||||
D_ref = D.clone()
|
||||
D_ref = D.clone() if D is not None else None
|
||||
z = (
|
||||
torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
|
||||
if has_z
|
||||
else None
|
||||
)
|
||||
z_ref = z.clone() if has_z else None
|
||||
z_ref = z.clone() if z is not None else None
|
||||
delta_bias = (
|
||||
(0.5 * torch.rand(dim, device=device, dtype=torch.float32))
|
||||
if has_delta_bias
|
||||
@@ -493,7 +493,7 @@ def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
|
||||
B[idx : idx + 1],
|
||||
C[idx : idx + 1],
|
||||
D=D,
|
||||
z=z[idx : idx + 1] if has_z else None,
|
||||
z=z[idx : idx + 1] if z is not None else None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
@@ -578,7 +578,7 @@ def test_selective_scan_varlen(
|
||||
C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
|
||||
C_ref = C.clone()
|
||||
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
|
||||
D_ref = D.clone()
|
||||
D_ref = D.clone() if D is not None else None
|
||||
z = torch.randn(dim, seqlen, device=device, dtype=itype)
|
||||
z_ref = z.clone()
|
||||
delta_bias = (
|
||||
@@ -750,7 +750,7 @@ def test_selective_state_update_with_batch_indices(
|
||||
B[:batch_size],
|
||||
C[:batch_size],
|
||||
D=D,
|
||||
z=z[:batch_size],
|
||||
z=z[:batch_size] if z is not None else None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
@@ -934,7 +934,7 @@ def test_selective_state_update_with_num_accepted_tokens(
|
||||
B[global_idx : global_idx + 1],
|
||||
C[global_idx : global_idx + 1],
|
||||
D=D,
|
||||
z=z[global_idx : global_idx + 1] if has_z else None,
|
||||
z=z[global_idx : global_idx + 1] if z is not None else None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
@@ -1061,7 +1061,7 @@ def test_selective_state_update_varlen_with_num_accepted(
|
||||
B[global_idx : global_idx + 1],
|
||||
C[global_idx : global_idx + 1],
|
||||
D=D,
|
||||
z=z[global_idx : global_idx + 1] if has_z else None,
|
||||
z=z[global_idx : global_idx + 1] if z is not None else None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user