[Bugfix] Remove false-positive format mismatch warnings in FLA ops (#38255)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2026-03-30 14:32:26 +02:00
committed by GitHub
parent 6557f4937f
commit 7c3f88b2a8
2 changed files with 0 additions and 17 deletions

View File

@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
import torch
@@ -184,13 +183,6 @@ def chunk_gated_delta_rule(
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
)
assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
if q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(

View File

@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
import torch
@@ -252,14 +251,6 @@ def chunk_local_cumsum(
output_dtype: torch.dtype | None = torch.float,
**kwargs,
) -> torch.Tensor:
if not head_first and g.shape[1] < g.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None:
assert g.shape[0] == 1, (
"Only batch size 1 is supported when cu_seqlens are provided"