[ROCm] Fix MoE kernel test failures on gfx950 (#37833)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
This commit is contained in:
@@ -32,6 +32,14 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8Static128BlockSym,
|
||||
kFp8StaticChannelSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.utils.import_utils import (
|
||||
has_aiter,
|
||||
has_deep_ep,
|
||||
@@ -152,6 +160,39 @@ class Config:
|
||||
|
||||
return vllm_config, env_dict
|
||||
|
||||
def fe_supports_quant_scheme(self) -> bool:
|
||||
"""Check if the fused experts class supports this quant config.
|
||||
See https://github.com/ROCm/aiter/issues/2419 for AITER gaps."""
|
||||
if self.quant_config is None or self.quant_dtype is None:
|
||||
return True
|
||||
if self.quant_dtype != torch.float8_e4m3fn:
|
||||
return True
|
||||
# Derive QuantKeys from test config
|
||||
if self.quant_block_shape is not None:
|
||||
w_key = kFp8Static128BlockSym
|
||||
a_key = kFp8Dynamic128Sym
|
||||
elif self.is_per_out_ch_quant:
|
||||
w_key = kFp8StaticChannelSym
|
||||
a_key = (
|
||||
kFp8DynamicTokenSym
|
||||
if self.is_per_act_token_quant
|
||||
else kFp8StaticTensorSym
|
||||
)
|
||||
else:
|
||||
w_key = kFp8StaticTensorSym
|
||||
a_key = (
|
||||
kFp8DynamicTensorSym
|
||||
if self.is_per_act_token_quant
|
||||
else kFp8StaticTensorSym
|
||||
)
|
||||
fe_cls = self.fused_experts_type
|
||||
if hasattr(fe_cls, "_supports_quant_scheme"):
|
||||
try:
|
||||
return fe_cls._supports_quant_scheme(w_key, a_key)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return True
|
||||
|
||||
def is_fp8_block_quantized(self):
|
||||
return (
|
||||
self.quant_dtype == torch.float8_e4m3fn
|
||||
@@ -253,6 +294,15 @@ class Config:
|
||||
f"{self.fe_supported_types()}."
|
||||
)
|
||||
|
||||
# Check quant scheme compatibility with fused experts class
|
||||
if not self.fe_supports_quant_scheme():
|
||||
return False, (
|
||||
f"FE {self.fused_experts_type.__name__} does not support "
|
||||
f"quant scheme (per_out_ch={self.is_per_out_ch_quant}, "
|
||||
f"per_act_token={self.is_per_act_token_quant}, "
|
||||
f"block={self.quant_block_shape})"
|
||||
)
|
||||
|
||||
# Check block quantization support
|
||||
is_block_quantized = self.quant_block_shape is not None
|
||||
if is_block_quantized and self.quant_dtype is None:
|
||||
|
||||
Reference in New Issue
Block a user