[Bugfix gpt-oss] Fix float32 convert for flashinfer sink support (#23016)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-08-16 14:16:00 -04:00
committed by GitHub
parent 68373d3126
commit 000cceca8c
2 changed files with 9 additions and 3 deletions

View File

@@ -642,9 +642,6 @@ class FlashInferImpl(AttentionImpl):
f"heads in the layer. Expected {num_heads}, but got "
f"{sinks.shape[0]}."
)
# Cast sinks to float32 if needed (FlashInfer requirement)
if sinks.dtype != torch.float32:
sinks = sinks.to(torch.float32)
self.sinks = sinks
def forward(