[Bugfix] Fix FlashInfer GDN warmup ValueError on SM90 GPUs (#36876)

This commit is contained in:
Thomas Parnell
2026-03-13 07:09:39 +01:00
committed by GitHub
parent bc2c0c86ef
commit f296a1966d

View File

@@ -137,7 +137,7 @@ def fi_chunk_gated_delta_rule(
fi_state = initial_state.to(torch.float32) fi_state = initial_state.to(torch.float32)
fi_g = g.to(torch.float32) fi_g = g.to(torch.float32)
fi_beta = beta.to(torch.float32) fi_beta = beta.to(torch.float32)
output, final_state = chunk_gated_delta_rule_fi( result = chunk_gated_delta_rule_fi(
q=q, q=q,
k=k, k=k,
v=v, v=v,
@@ -147,8 +147,14 @@ def fi_chunk_gated_delta_rule(
output_final_state=output_final_state, output_final_state=output_final_state,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
) )
# FlashInfer returns an (output, final_state) tuple when
# output_final_state=True, but only the output tensor when it is False,
# so unpacking the result unconditionally as a 2-tuple can raise ValueError.
# Unsqueeze back to 4D (1, L, H, D) to match fla output format # Unsqueeze back to 4D (1, L, H, D) to match fla output format
if output_final_state:
output, final_state = result
return output.unsqueeze(0), final_state return output.unsqueeze(0), final_state
else:
return result.unsqueeze(0), None
@CustomOp.register("chunk_gated_delta_rule") @CustomOp.register("chunk_gated_delta_rule")