[Bugfix] Fix FlashInfer GDN warmup ValueError on SM90 GPUs (#36876)
This commit is contained in:
@@ -137,7 +137,7 @@ def fi_chunk_gated_delta_rule(
|
|||||||
fi_state = initial_state.to(torch.float32)
|
fi_state = initial_state.to(torch.float32)
|
||||||
fi_g = g.to(torch.float32)
|
fi_g = g.to(torch.float32)
|
||||||
fi_beta = beta.to(torch.float32)
|
fi_beta = beta.to(torch.float32)
|
||||||
output, final_state = chunk_gated_delta_rule_fi(
|
result = chunk_gated_delta_rule_fi(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
@@ -147,8 +147,14 @@ def fi_chunk_gated_delta_rule(
|
|||||||
output_final_state=output_final_state,
|
output_final_state=output_final_state,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
|
# FlashInfer returns (output, state) when output_final_state=True,
|
||||||
|
# or just output when output_final_state=False.
|
||||||
# Unsqueeze back to 4D (1, L, H, D) to match fla output format
|
# Unsqueeze back to 4D (1, L, H, D) to match fla output format
|
||||||
return output.unsqueeze(0), final_state
|
if output_final_state:
|
||||||
|
output, final_state = result
|
||||||
|
return output.unsqueeze(0), final_state
|
||||||
|
else:
|
||||||
|
return result.unsqueeze(0), None
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("chunk_gated_delta_rule")
|
@CustomOp.register("chunk_gated_delta_rule")
|
||||||
|
|||||||
Reference in New Issue
Block a user