[Misc] Fix flashinfer related tests (#33462)

Signed-off-by: esmeetu <jasonailu87@gmail.com>
This commit is contained in:
Roy Wang
2026-02-01 05:10:24 +08:00
committed by GitHub
parent 1e86c802d4
commit 63c0889416
5 changed files with 9 additions and 8 deletions

View File

@@ -154,9 +154,10 @@ def convert_to_nvfp4_linear_kernel_format(
)
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
elif (
backend == NvFp4LinearBackend.VLLM_CUTLASS
or backend == NvFp4LinearBackend.FLASHINFER_CUTLASS
elif backend in (
NvFp4LinearBackend.VLLM_CUTLASS,
NvFp4LinearBackend.FLASHINFER_CUTLASS,
NvFp4LinearBackend.FLASHINFER_CUDNN,
):
weight, weight_scale, weights_padding_cols = prepare_weights_for_nvfp4_cutlass(
layer.weight.data, layer.weight_scale.data