diff --git a/tests/unit/test_part_a_decode_diagnostics.py b/tests/unit/test_part_a_decode_diagnostics.py index 67769d8a..4b0e6109 100644 --- a/tests/unit/test_part_a_decode_diagnostics.py +++ b/tests/unit/test_part_a_decode_diagnostics.py @@ -42,7 +42,7 @@ def main(): from single_shot_inference import ( load_all_weights, make_nvfp4_linear, get_nvfp4_weight, rmsnorm, unweighted_rmsnorm, _apply_rope, build_rope_cache, - KVCache, Compressor, Indexer, forward_layer, moe_forward, + KVCache, Compressor, Indexer, forward_layer, forward_attention, moe_forward, _load_moe_weights_stacked, _load_shared_expert_weights, _cache_layer_weights_no_experts, )