[CI] Bump mypy version (#34950)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -94,12 +94,9 @@ def test_rotary_embedding(
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# slice tensor if required, noop otherwise
|
||||
query = query[..., :head_size]
|
||||
key = key[..., :head_size] if use_key else None
|
||||
query = torch.randn(query_shape, dtype=dtype)[..., :head_size]
|
||||
key = torch.randn_like(query)[..., :head_size] if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
|
||||
Reference in New Issue
Block a user