[Refactor] Refactor for DeepGemmQuantScaleFMT using cache (#30898)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-12-19 15:50:39 -05:00
committed by GitHub
parent 1ab5213531
commit 3bd8335bd0
2 changed files with 30 additions and 10 deletions

View File

@@ -32,15 +32,34 @@ class DeepGemmQuantScaleFMT(Enum):
# element contains 4 scale values.
UE8M0 = 2
@staticmethod
def from_oracle() -> "DeepGemmQuantScaleFMT":
if not is_deep_gemm_e8m0_used():
return DeepGemmQuantScaleFMT.FLOAT32
return (
DeepGemmQuantScaleFMT.UE8M0
if current_platform.is_device_capability_family(100)
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
@classmethod
def init_oracle_cache(cls) -> None:
    """Decide the scale format once and remember it on the class.

    The decision is stored in the ``_oracle_cache`` class attribute. If that
    attribute already holds a non-``None`` value the call is a no-op, so
    calling this repeatedly is harmless.

    Selection rules:
      * ``FLOAT32`` unless all three hold: ``envs.VLLM_USE_DEEP_GEMM_E8M0``
        is truthy, ``is_deep_gemm_supported()`` returns truthy, and
        ``_fp8_gemm_nt_impl`` is not ``None``.
      * Otherwise ``UE8M0`` when
        ``current_platform.is_device_capability_family(100)`` is truthy,
        else ``FLOAT32_CEIL_UE8M0``.
    """
    if getattr(cls, "_oracle_cache", None) is not None:
        # A decision was already recorded; keep it.
        return

    # Short-circuits left to right, so the support probe only runs when the
    # environment switch is on.
    e8m0_enabled = (
        envs.VLLM_USE_DEEP_GEMM_E8M0
        and is_deep_gemm_supported()
        and _fp8_gemm_nt_impl is not None
    )

    if not e8m0_enabled:
        decision = cls.FLOAT32
    elif current_platform.is_device_capability_family(100):
        decision = cls.UE8M0
    else:
        decision = cls.FLOAT32_CEIL_UE8M0

    cls._oracle_cache = decision  # type: ignore
@classmethod
def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
    """Return the scale format stored in the ``_oracle_cache`` attribute.

    This performs no computation of its own; it only reads the class
    attribute and asserts that it is present and not ``None``.

    Raises:
        AssertionError: if ``_oracle_cache`` is missing or ``None``.
    """
    decision = getattr(cls, "_oracle_cache", None)
    assert decision is not None, (
        "DeepGemmQuantScaleFMT oracle cache not initialized"
    )
    return decision
@functools.cache
@@ -149,6 +168,7 @@ def _lazy_init() -> None:
_transform_sf_into_required_layout_impl = getattr(
_dg, "transform_sf_into_required_layout", None
)
DeepGemmQuantScaleFMT.init_oracle_cache()
def get_num_sms() -> int: