Remove all cases of fmt: on/off (#26253)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -442,14 +442,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...]
|
||||
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...]
|
||||
for i in range(num_sequences):
|
||||
# fmt: off
|
||||
chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
|
||||
chunk_f = lambda x, i: x[
|
||||
cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ...
|
||||
]
|
||||
|
||||
X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
|
||||
dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
|
||||
B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
|
||||
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
|
||||
# fmt: on
|
||||
X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
X, i
|
||||
)
|
||||
dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
dt, i
|
||||
)
|
||||
B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
B, i
|
||||
)
|
||||
C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f(
|
||||
C, i
|
||||
)
|
||||
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)
|
||||
@@ -481,27 +489,42 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
dim=0,
|
||||
)
|
||||
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
|
||||
# fmt: off
|
||||
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
|
||||
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501
|
||||
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501
|
||||
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501
|
||||
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...]
|
||||
remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...]
|
||||
for i in range(num_sequences):
|
||||
remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
|
||||
remaining_chunk_f = lambda x, i: x[
|
||||
cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ...
|
||||
]
|
||||
|
||||
remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
|
||||
remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
|
||||
remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
|
||||
remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
|
||||
remaining_X_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(X, i)
|
||||
remaining_dt_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(dt, i)
|
||||
remaining_B_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(B, i)
|
||||
remaining_C_chunked[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ...
|
||||
] = remaining_chunk_f(C, i)
|
||||
|
||||
# assert input chunking is correct
|
||||
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
|
||||
pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
|
||||
pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
|
||||
concat_chunk_f = lambda pt1, pt2, i: torch.cat(
|
||||
[
|
||||
pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...],
|
||||
pt2[
|
||||
remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1],
|
||||
...,
|
||||
],
|
||||
],
|
||||
dim=0)
|
||||
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501
|
||||
# fmt: on
|
||||
dim=0,
|
||||
)
|
||||
concat_batch_f = lambda pt1, pt2: torch.cat(
|
||||
[concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0
|
||||
)
|
||||
|
||||
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
|
||||
assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
|
||||
|
||||
Reference in New Issue
Block a user