[Bugfix] Fix FI kernel `chunk_gated_delta_rule` output shape for Qwen3.5 (#34219)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2026-02-10 02:41:24 -08:00
committed by GitHub
parent cbea11c9f0
commit ae4e280602

View File

@@ -135,7 +135,7 @@ def fi_chunk_gated_delta_rule(
fi_state = initial_state.to(torch.float32)
fi_g = g.to(torch.float32)
fi_beta = beta.to(torch.float32)
return chunk_gated_delta_rule_fi(
output, final_state = chunk_gated_delta_rule_fi(
q=q,
k=k,
v=v,
@@ -145,6 +145,8 @@ def fi_chunk_gated_delta_rule(
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
# Unsqueeze back to 4D (1, L, H, D) to match fla output format
return output.unsqueeze(0), final_state
@CustomOp.register("chunk_gated_delta_rule")