[ROCm][Kernel] MoE weights padding (#14454)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-03-24 19:45:30 -04:00
committed by GitHub
parent 8279201ce6
commit f533b5837f
5 changed files with 65 additions and 16 deletions

View File

@@ -5,6 +5,7 @@ from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
from vllm import envs
@@ -96,9 +97,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w13_weight.data),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
import intel_extension_for_pytorch as ipex