[Misc] Fix flashinfer related tests (#33462)
Signed-off-by: esmeetu <jasonailu87@gmail.com>
This commit is contained in:
@@ -154,9 +154,10 @@ def convert_to_nvfp4_linear_kernel_format(
         )
         layer.weight = torch.nn.Parameter(weight, requires_grad=False)
         layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-    elif (
-        backend == NvFp4LinearBackend.VLLM_CUTLASS
-        or backend == NvFp4LinearBackend.FLASHINFER_CUTLASS
+    elif backend in (
+        NvFp4LinearBackend.VLLM_CUTLASS,
+        NvFp4LinearBackend.FLASHINFER_CUTLASS,
+        NvFp4LinearBackend.FLASHINFER_CUDNN,
     ):
         weight, weight_scale, weights_padding_cols = prepare_weights_for_nvfp4_cutlass(
             layer.weight.data, layer.weight_scale.data
||||
Reference in New Issue
Block a user