[Kernel] some optimizations for dense marlin and moe marlin (#16850)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
Jinzhen Lin
2025-05-06 00:39:30 +08:00
committed by GitHub
parent f62cad6431
commit 1d0c9d6b2d
26 changed files with 3512 additions and 3268 deletions

View File

@@ -6,6 +6,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
_SCALAR_TYPES_ID_MAP = {}
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
@@ -158,6 +160,8 @@ class ScalarType:
assert offset <= 64, \
f"ScalarType fields too big {offset} to fit into an int64"
_SCALAR_TYPES_ID_MAP[val] = self
return val
@property
@@ -295,6 +299,13 @@ class ScalarType:
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def from_id(cls, scalar_type_id: int):
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
raise ValueError(
f"scalar_type_id {scalar_type_id} doesn't exists.")
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is: