From f296a1966dca96cd69e5c1fa1264edbf611a1bd6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 13 Mar 2026 07:09:39 +0100 Subject: [PATCH] [Bugfix] Fix FlashInfer GDN warmup ValueError on SM90 GPUs (#36876) --- vllm/model_executor/models/qwen3_next.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 451b332ed..cfd4c7a56 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -137,7 +137,7 @@ def fi_chunk_gated_delta_rule( fi_state = initial_state.to(torch.float32) fi_g = g.to(torch.float32) fi_beta = beta.to(torch.float32) - output, final_state = chunk_gated_delta_rule_fi( + result = chunk_gated_delta_rule_fi( q=q, k=k, v=v, @@ -147,8 +147,14 @@ def fi_chunk_gated_delta_rule( output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) + # FlashInfer returns (output, state) when output_final_state=True, + # or just output when output_final_state=False. # Unsqueeze back to 4D (1, L, H, D) to match fla output format - return output.unsqueeze(0), final_state + if output_final_state: + output, final_state = result + return output.unsqueeze(0), final_state + else: + return result.unsqueeze(0), None @CustomOp.register("chunk_gated_delta_rule")