Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -7,10 +7,10 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen)
mamba_chunk_scan_combined_varlen,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
compute_varlen_chunk_metadata)
from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata
# Added by the IBM Team, 2024
@@ -22,12 +22,10 @@ def segsum(x):
"""Calculates segment sum."""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
diagonal=-1)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
diagonal=0)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
@@ -46,8 +44,9 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
assert X.shape[1] % block_len == 0
# Rearrange into blocks/chunks
X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
for x in (X, A, B, C))
X, A, B, C = (
rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
)
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
@@ -74,7 +73,7 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms
# (diagonal and off-diagonal blocks)
@@ -82,42 +81,31 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
return Y, final_state
def generate_random_inputs(batch_size,
seqlen,
n_heads,
d_head,
itype,
device='cuda'):
def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):
current_platform.seed_everything(0)
A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
dt = F.softplus(
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
4)
X = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
B = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
C = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4
)
X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
return A, dt, X, B, C
def generate_continuous_batched_examples(example_lens_by_batch,
num_examples,
full_length,
last_taken,
exhausted,
n_heads,
d_head,
itype,
device='cuda',
return_naive_ref=True):
def generate_continuous_batched_examples(
example_lens_by_batch,
num_examples,
full_length,
last_taken,
exhausted,
n_heads,
d_head,
itype,
device="cuda",
return_naive_ref=True,
):
# this function generates a random examples of certain length
# and then cut according to "example_lens_by_batch" and feed
# them in continuous batches to the kernels.
@@ -126,23 +114,20 @@ def generate_continuous_batched_examples(example_lens_by_batch,
# reference output.
# generate the full-length example
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
d_head, itype)
A, dt, X, B, C = generate_random_inputs(
num_examples, full_length, n_heads, d_head, itype
)
if return_naive_ref:
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
A * dt,
B,
C,
block_len=full_length //
4)
Y_min, final_state_min = ssd_minimal_discrete(
X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4
)
# internal function that outputs a cont batch of examples
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
def get_continuous_batch(example_lens: tuple[int, ...]):
indices = []
for i, x in enumerate(example_lens):
c = last_taken.get(i, 0)
@@ -150,8 +135,10 @@ def generate_continuous_batched_examples(example_lens_by_batch,
last_taken[i] = (c + x) % full_length
exhausted[i] = last_taken[i] == 0
return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
]).unsqueeze(0) for x in (dt, X, B, C))
return (
torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0)
for x in (dt, X, B, C)
)
# internal function that maps "n" to the appropriate right boundary
# value when forming continuous batches from examples of length given
@@ -163,19 +150,20 @@ def generate_continuous_batched_examples(example_lens_by_batch,
IND_E = None
for spec in example_lens_by_batch:
# get the (maybe partial) example seen in this cont batch
dt2, X2, B2, C2 = get_continuous_batch(spec)
# get the metadata
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
seq_idx = torch.zeros(cu_seqlens[-1],
dtype=torch.int32,
device=cu_seqlens.device)
for i, (srt, end) in enumerate(zip(
cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0)
seq_idx = torch.zeros(
cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device
)
for i, (srt, end) in enumerate(
zip(
cu_seqlens,
cu_seqlens[1:],
)):
)
):
seq_idx[srt:end] = i
# for cont batch
@@ -190,19 +178,21 @@ def generate_continuous_batched_examples(example_lens_by_batch,
X2 = X2.squeeze(0)
B2 = B2.squeeze(0)
C2 = C2.squeeze(0)
yield ([Y_min[s, IND_S[s]:IND_E[s]]
for s in range(num_examples)] if return_naive_ref else None,
cu_seqlens, seq_idx, (A, dt2, X2, B2, C2))
yield (
[Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)]
if return_naive_ref
else None,
cu_seqlens,
seq_idx,
(A, dt2, X2, B2, C2),
)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
itype):
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
# this tests the kernels on a single example (bs=1)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
@@ -219,15 +209,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# it is not an operational limitation.
seqlen, chunk_size = seq_len_chunk_size
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
d_head, itype)
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype)
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)
Y_min, final_state_min = ssd_minimal_discrete(
X * dt.unsqueeze(-1), A * dt, B, C, chunk_size
)
cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
)
# varlen has implicit batch=1
X = X.squeeze(0)
dt = dt.squeeze(0)
@@ -255,10 +246,12 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.testing.assert_close(final_state[:, -1].to(torch.float32),
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol)
torch.testing.assert_close(
final_state[:, -1].to(torch.float32),
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol,
)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@@ -267,32 +260,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[
# small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
(4, 4)]), # chunk_size larger than cont batches
(64, 8, 5, [
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
]), # mode examples with varied lengths
(
64,
8,
2,
[(4, 4), (4, 4), (4, 4), (4, 4)],
), # chunk_size larger than cont batches
(
64,
8,
5,
[
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
],
), # mode examples with varied lengths
# large-ish chunk_size (256)
(64, 256, 1, [(5, ), (1, ), (1, ),
(1, )]), # irregular sizes with small sequences
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences
(64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences
(
64,
256,
2,
[(5, 30), (1, 2), (1, 2), (1, 2)],
), # irregular sizes with small sequences
# we also need to test some large seqlen
# to catch errors with init states decay
(768, 128, 2, [(138, 225), (138, 225)]),
])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype):
],
)
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
# this test with multiple examples in a continuous batch
# (i.e. chunked prefill)
@@ -311,12 +312,17 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
states = None
for Y_min, cu_seqlens, _token_seq_idx, (
A, dt, X, B, C) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads,
d_head, itype):
A,
dt,
X,
B,
C,
) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype
):
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
)
Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined_varlen(
@@ -337,9 +343,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# just test the last in sequence
for i in range(num_examples):
# just test one dim and dstate
Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_eg = Y[cu_seqlens[i] : cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
@@ -347,18 +352,20 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
states = new_states
for i, clear in exhausted.items():
if clear:
states[i].fill_(0.)
states[i].fill_(0.0)
exhausted[i] = False
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize("seqlens", [
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
])
@pytest.mark.parametrize(
"seqlens",
[
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
],
)
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# This test verifies the correctness of the chunked prefill implementation
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
# dimension) of chunked results with the full sequence result.
@@ -387,21 +394,25 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
_, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
generate_continuous_batched_examples([seqlens],
num_sequences,
max_seqlen,
last_taken,
exhausted,
n_heads,
d_head,
itype,
return_naive_ref=False))
generate_continuous_batched_examples(
[seqlens],
num_sequences,
max_seqlen,
last_taken,
exhausted,
n_heads,
d_head,
itype,
return_naive_ref=False,
)
)
seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
device = X.device
## full seqlen computation
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(cu_seqlens, chunk_size)
)
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined_varlen(
X,
@@ -422,11 +433,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
## chunked seqlen computation
# first chunk
chunked_seqlens = seqlens // 2
chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(chunked_seqlens, dim=0)
],
dim=0)
chunked_cu_seqlens = torch.cat(
[torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0
)
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
@@ -443,7 +452,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# fmt: on
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
)
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined_varlen(
X_chunked,
@@ -463,11 +473,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# remaining chunk
remaining_chunked_seqlens = seqlens - chunked_seqlens
remaining_chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(remaining_chunked_seqlens, dim=0)
],
dim=0)
remaining_chunked_cu_seqlens = torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(remaining_chunked_seqlens, dim=0),
],
dim=0,
)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
@@ -497,8 +509,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
chunk_size))
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens, chunk_size)
)
Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined_varlen(
@@ -520,20 +532,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
# kernel chunked is same as kernel overall
for i in range(num_sequences):
Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_seq = Y[cu_seqlens[i] : cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[cu_seqlens[i] : cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:chunked_seqlens[i], ...],
Y_ref_seq[:chunked_seqlens[i], ...],
Y_seq[: chunked_seqlens[i], ...],
Y_ref_seq[: chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
msg=lambda x: f"seq{i} output part1 " + x,
) # noqa: B023
torch.testing.assert_close(
Y_seq[chunked_seqlens[i]:, ...],
Y_ref_seq[chunked_seqlens[i]:, ...],
Y_seq[chunked_seqlens[i] :, ...],
Y_ref_seq[chunked_seqlens[i] :, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023
msg=lambda x: f"seq{i} output part2 " + x,
) # noqa: B023
state_seq = state_chunked[i]
state_seq_ref = state_ref[i]
@@ -542,4 +556,5 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
state_seq_ref,
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} state " + x) # noqa: B023
msg=lambda x: f"seq{i} state " + x,
) # noqa: B023