diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index de97daccf..d0c13dd49 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -135,7 +135,7 @@ def fi_chunk_gated_delta_rule( fi_state = initial_state.to(torch.float32) fi_g = g.to(torch.float32) fi_beta = beta.to(torch.float32) - return chunk_gated_delta_rule_fi( + output, final_state = chunk_gated_delta_rule_fi( q=q, k=k, v=v, @@ -145,6 +145,8 @@ def fi_chunk_gated_delta_rule( output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) + # Unsqueeze back to 4D (1, L, H, D) to match fla output format + return output.unsqueeze(0), final_state @CustomOp.register("chunk_gated_delta_rule")