[Hardware/NVIDIA/Modelopt] Fix modelopt forward method for v1 torch.compile (#18101)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety
2025-05-13 19:33:00 -07:00
committed by GitHub
parent 176a95c670
commit 65f0f74b66
2 changed files with 14 additions and 10 deletions

View File

@@ -401,6 +401,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
if self.use_marlin:
prepare_fp4_layer_for_marlin(layer)
@@ -426,11 +427,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
bias=bias)
output_dtype = x.dtype
# for input only the contracting dimension has a constraint.
x_m, _ = x.shape
w_n, _ = layer.weight.shape
output_shape = [x_m, w_n]
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
s_quant = 1 / layer.input_scale
@@ -586,11 +583,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
# GEMM 1
assert torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
"Expected w1_weight_scale_2 to equal w3_weight_scale_2")
"w1_weight_scale_2 must match w3_weight_scale_2")
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
@@ -616,6 +613,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
layer.w13_weight = Parameter(layer.w13_weight.data,
requires_grad=False)
# GEMM 2
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
@@ -633,6 +633,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
@@ -694,7 +695,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
assert not apply_router_weight_on_input, (
"Router weight on input is not "
"supported for ModelOptNvFp4FusedMoE.")
assert expert_map is None, ("Expert Parallelism /expert_map "
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE.")