[ROCm] Small functional changes for gptoss (#25201)

Signed-off-by: jpvillam <jpvillam@amd.com>
Co-authored-by: jpvillam <jpvillam@amd.com>
This commit is contained in:
Juan Villamizar
2025-09-23 18:39:50 -05:00
committed by GitHub
parent 5e25b12236
commit bde2a1a8a4
3 changed files with 26 additions and 6 deletions

View File

@@ -212,12 +212,15 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
elif current_platform.is_rocm() or (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
hidden_size = round_up(hidden_size, 128)
elif current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64)