[BugFix] EPLB + B200 + DeepGEMM: Handle column-major scales tensor (#29162)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
committed by GitHub
parent 57430fc95c
commit 3137991f55
@@ -1391,7 +1391,48 @@ class FusedMoE(CustomOp):
        yield param_name

    def get_expert_weights(self) -> Iterable[torch.Tensor]:
        def _maybe_make_contiguous(
            name: str, p: torch.nn.Parameter
        ) -> torch.nn.Parameter:
            """
            In some cases, the last 2 dimensions (the non-expert dimensions)
            of the weight-scale tensor are transposed. This function
            transforms the tensor (a view update) so that it is contiguous().

            Example: a non-contiguous scale tensor `x` of shape (E, 32, 16)
            and stride (512, 1, 32) is transformed to `x_` of shape
            (E, 16, 32) and stride (512, 32, 1).

            Note that we specifically use torch.transpose() so that `x_`
            refers to the same underlying memory. Because `x` and `x_` point
            to the same underlying memory, this transformation is safe in the
            context of EPLB: it is the same memory, only the view is
            different.

            Note: this function handles the "weight_scale" tensors
            specifically. It could, however, be generalized to handle
            similar tensors.
            """
            if p.ndim != 3:
                return p
            if p.is_contiguous():
                # Already contiguous; do nothing.
                return p
            # p is non-contiguous. We only handle the case where the last 2
            # dimensions of the scales tensor are transposed. We can handle
            # other cases when they become relevant.
            is_transposed_12 = p.stride(1) == 1 and p.stride(2) != 1
            if "weight_scale" not in name or not is_transposed_12:
                # Do nothing.
                return p

            # Do not update the layer parameter, as the layer's MoE
            # operations expect the parameter's tensor to keep the same
            # shape / stride. Instead, make a new torch.nn.Parameter that is
            # used only in the context of EPLB.
            return torch.nn.Parameter(
                torch.transpose(p.data, 1, 2), requires_grad=False
            )

        weights = list(self.named_parameters())
        weights = [(name, _maybe_make_contiguous(name, p)) for name, p in weights]

        assert all(
            weight.is_contiguous()
            for name, weight in weights
        )
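For illustration, here is a minimal standalone sketch (not part of the commit) that reproduces the shapes and strides from the docstring's example and confirms that torch.transpose() returns a contiguous view over the same storage. The tensor `x` and the expert count E=4 are hypothetical stand-ins for a column-major scales tensor:

    import torch

    E = 4
    x = torch.randn(E, 16, 32).transpose(1, 2)  # hypothetical column-major scales
    assert x.shape == (E, 32, 16) and x.stride() == (512, 1, 32)
    assert not x.is_contiguous()

    # The stride predicate used by _maybe_make_contiguous flags this layout.
    is_transposed_12 = x.stride(1) == 1 and x.stride(2) != 1
    assert is_transposed_12

    x_ = torch.transpose(x, 1, 2)  # shape (E, 16, 32), stride (512, 32, 1)
    assert x_.is_contiguous()
    assert x_.data_ptr() == x.data_ptr()  # same underlying memory, new view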
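And a hedged sketch of why the shared view matters for EPLB (all names below are hypothetical stand-ins, not vLLM APIs): an in-place expert shuffle performed through the contiguous Parameter is immediately visible through the layer's original, non-contiguous parameter, because no copy was made:

    import torch

    # Layer-owned scales, loaded column-major in the last two dims (hypothetical).
    scale = torch.nn.Parameter(
        torch.randn(4, 16, 32).transpose(1, 2), requires_grad=False
    )

    # Contiguous alias handed to EPLB, mirroring _maybe_make_contiguous.
    eplb_scale = torch.nn.Parameter(
        torch.transpose(scale.data, 1, 2), requires_grad=False
    )
    assert eplb_scale.is_contiguous()
    assert eplb_scale.data_ptr() == scale.data_ptr()

    # Simulate EPLB swapping experts 0 and 1 in place via the contiguous view.
    tmp = eplb_scale.data[0].clone()
    eplb_scale.data[0].copy_(eplb_scale.data[1])
    eplb_scale.data[1].copy_(tmp)

    # The layer's original parameter observes the shuffle: same storage.
    assert torch.equal(scale.data[0].t(), eplb_scale.data[0])

Returning a fresh Parameter instead of mutating the layer's own attribute keeps the MoE kernels seeing exactly the shape and stride they were initialized with, while EPLB's in-place weight movement still lands in the right memory.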