[Feature] Migrate DeepGEMM API from get_m_alignment_for_contiguous_layout to get_mk_alignment_for_contiguous_layout (#26935)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
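A minimal before/after sketch of the call-site change this migration performs. The function names are taken from the diff below; the surrounding variables are illustrative only, not code from the repository:

# Before: DeepGEMM's helper returned a single M-alignment integer, and
# callers expanded it into a [block_m, block_m] block shape themselves.
from deep_gemm import get_m_alignment_for_contiguous_layout

block_m = get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]

# After: vLLM's wrapper returns the [M, K] alignment pair directly,
# so callers no longer build the list by hand.
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

block_size = get_mk_alignment_for_contiguous_layout()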
@@ -22,13 +22,13 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
 )
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)

 dg_available = has_deep_gemm()

-if dg_available:
-    from deep_gemm import get_m_alignment_for_contiguous_layout
-
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)

@@ -218,8 +218,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
     torch.manual_seed(seed)

     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
-    block_m = get_m_alignment_for_contiguous_layout()
-    block_size = [block_m, block_m]
+    block_size = get_mk_alignment_for_contiguous_layout()
     dtype = torch.bfloat16

     a = torch.randn((M, K), dtype=dtype) / 10
@@ -6,14 +6,17 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
 )
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    fp8_m_grouped_gemm_nt_masked,
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)

 logger = init_logger(__name__)

@@ -227,7 +230,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             quant_config: Quantization configuration
         """
         super().__init__(quant_config)
-        assert self.block_shape == deep_gemm_block_shape()
+        assert self.block_shape == get_mk_alignment_for_contiguous_layout()
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers

@@ -8,8 +8,8 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts,
 )
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout


 class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -31,7 +31,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )

         self.batched_deep_gemm_experts = (
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 )
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M,
-    deep_gemm_block_shape,
     deepgemm_moe_permute,
     deepgemm_unpermute_and_reduce,
 )
@@ -28,14 +27,17 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)
 from vllm.utils.functools import run_once

 logger = init_logger(__name__)


 def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]
     return align <= M and N % align == 0 and K % align == 0


@@ -54,7 +56,7 @@ def _valid_deep_gemm(
     M = hidden_states.size(0)
     _, K, N = w2.size()

-    align = deep_gemm_block_shape()[0]
+    align = get_mk_alignment_for_contiguous_layout()[0]

     if not _valid_deep_gemm_shape(M, N, K):
         logger.debug_once(
@@ -124,7 +126,7 @@ def warmup_deepgemm_gg_contiguous_kernels(

     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device

@@ -173,7 +175,7 @@ def warmup_deepgemm_gg_contiguous_kernels(
 class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         super().__init__(quant_config)
-        assert quant_config.block_shape == deep_gemm_block_shape()
+        assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
         assert quant_config.quant_dtype == torch.float8_e4m3fn
         assert not quant_config.per_act_token_quant
         assert not quant_config.per_out_ch_quant
@@ -255,7 +257,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
            M=topk_ids.size(0),
            num_topk=topk_ids.size(1),
            local_num_experts=local_num_experts,
-           alignment=deep_gemm_block_shape()[0],
+           alignment=get_mk_alignment_for_contiguous_layout()[0],
            expert_tokens_meta=expert_tokens_meta,
        )

@@ -364,7 +366,7 @@ def deep_gemm_moe_fp8(
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
-       block_shape=deep_gemm_block_shape(),
+       block_shape=get_mk_alignment_for_contiguous_layout(),
    )

    fn = mk.FusedMoEModularKernel(
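For context, _valid_deep_gemm_shape above only accepts problems that meet DeepGEMM's contiguous-layout alignment. A small worked sketch of that predicate, using 128 as an assumed example alignment (the real value comes from DeepGEMM at runtime via the new wrapper):

# Assumed example: suppose get_mk_alignment_for_contiguous_layout() returns [128, 128].
align = 128

def valid(M: int, N: int, K: int) -> bool:
    # Same check as _valid_deep_gemm_shape in the hunk above.
    return align <= M and N % align == 0 and K % align == 0

print(valid(256, 4096, 7168))  # True: M >= 128 and both N and K divide by 128
print(valid(64, 4096, 7168))   # False: fewer rows than the M alignment
print(valid(256, 4100, 7168))  # False: N is not a multiple of 128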
@@ -5,23 +5,13 @@ Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00b
 and updated to fit vllm needs and terminology.
 """

-import functools
-
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
 from vllm.triton_utils import tl, triton
 from vllm.utils import round_up
-
-
-@functools.cache
-def deep_gemm_block_shape() -> list[int]:
-    # Lazy import to avoid CUDA initialization problems.
-    import deep_gemm as dg
-
-    block = dg.get_m_alignment_for_contiguous_layout()
-    return [block, block]
+from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout


 def expert_num_tokens_round_up_and_sum(
@@ -354,8 +344,7 @@ def deepgemm_moe_permute(
     H = aq.size(1)
     device = aq.device

-    block_m = deep_gemm_block_shape()[0]
-    block_k = deep_gemm_block_shape()[1]
+    block_m, block_k = get_mk_alignment_for_contiguous_layout()

     M_sum = compute_aligned_M(
         M=topk_ids.size(0),
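The module-local deep_gemm_block_shape() helper removed above is subsumed by the shared wrapper in vllm.utils.deep_gemm. A rough sketch of the equivalence; the "_old" name below is only a label for this comparison, and the wrapper body appears later in this diff:

# Old, module-local helper (removed in this commit): lazily imported DeepGEMM
# and duplicated the M alignment into a [block, block] shape.
import functools

@functools.cache
def deep_gemm_block_shape_old() -> list[int]:  # illustrative rename of the removed helper
    import deep_gemm as dg
    block = dg.get_m_alignment_for_contiguous_layout()
    return [block, block]

# New, shared wrapper: callers go through vllm.utils.deep_gemm, which resolves
# the DeepGEMM symbol once via _lazy_init() and caches the [M, K] pair.
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

block_m, block_k = get_mk_alignment_for_contiguous_layout()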
@@ -10,9 +10,11 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm,
     _valid_deep_gemm_shape,
 )
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (
+    get_mk_alignment_for_contiguous_layout,
+    is_deep_gemm_e8m0_used,
+)


 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -28,7 +30,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         self.allow_deep_gemm = (
             allow_deep_gemm
             and self.quant_config.use_fp8_w8a8
-            and self.block_shape == deep_gemm_block_shape()
+            and self.block_shape == get_mk_alignment_for_contiguous_layout()
         )

         self.deep_gemm_expert = (
@@ -12,10 +12,7 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
-from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
-    compute_aligned_M,
-    deep_gemm_block_shape,
-)
+from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
@@ -23,7 +20,11 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
 )
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
-from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous
+from vllm.utils.deep_gemm import (
+    fp8_gemm_nt,
+    get_mk_alignment_for_contiguous_layout,
+    m_grouped_fp8_gemm_nt_contiguous,
+)


 def _generate_optimal_warmup_m_values(
@@ -129,7 +130,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     """
     Return True if the input module/layer could be processed with DeepGEMM.
     """
-    block_size = deep_gemm_block_shape()[0]
+    block_size = get_mk_alignment_for_contiguous_layout()[0]
     if not (
         isinstance(module, LinearBase)
         and isinstance(module.quant_method, Fp8LinearMethod)
@@ -139,7 +140,7 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:

     w, _, block_sizes = _extract_data_from_linear_base_module(module)
     return (
-        block_sizes == deep_gemm_block_shape()
+        block_sizes == get_mk_alignment_for_contiguous_layout()
         and w.ndim == 2
         and w.shape[0] % block_size == 0
         and w.shape[1] % block_size == 0
@@ -155,7 +156,7 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
     if (
         moe_quant_config is None
         or moe_quant_config.quant_dtype != torch.float8_e4m3fn
-        or moe_quant_config.block_shape != deep_gemm_block_shape()
+        or moe_quant_config.block_shape != get_mk_alignment_for_contiguous_layout()
     ):
         return False

@@ -176,7 +177,7 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens:
         return

     n, k = w.size()
-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]

     device = w.device
     a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn)
@@ -229,7 +230,7 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(

     assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts"

-    block_m = deep_gemm_block_shape()[0]
+    block_m = get_mk_alignment_for_contiguous_layout()[0]
     num_experts = w1.size(0)
     device = w1.device

@@ -75,6 +75,7 @@ _fp8_mqa_logits_impl: Callable[..., Any] | None = None
 _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
 _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
 _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
+_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None


 def _lazy_init() -> None:
@@ -83,7 +84,7 @@ def _lazy_init() -> None:
     global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
     global _get_paged_mqa_logits_metadata_impl
     global _get_mn_major_tma_aligned_tensor_impl
+    global _get_mk_alignment_for_contiguous_layout_impl
     # fast path
     if (
         _fp8_gemm_nt_impl is not None
@@ -92,6 +93,7 @@ def _lazy_init() -> None:
         or _fp8_mqa_logits_impl is not None
         or _fp8_paged_mqa_logits_impl is not None
         or _get_paged_mqa_logits_metadata_impl is not None
+        or _get_mk_alignment_for_contiguous_layout_impl is not None
     ):
         return

@@ -118,6 +120,9 @@ def _lazy_init() -> None:
     _get_mn_major_tma_aligned_tensor_impl = getattr(
         _dg, "get_mn_major_tma_aligned_tensor", None
     )
+    _get_mk_alignment_for_contiguous_layout_impl = getattr(
+        _dg, "get_mk_alignment_for_contiguous_layout", None
+    )


 def get_num_sms() -> int:
@@ -126,6 +131,15 @@ def get_num_sms() -> int:
     return int(_dg.get_num_sms())


+@functools.cache
+def get_mk_alignment_for_contiguous_layout() -> list[int]:
+    _lazy_init()
+    if _get_mk_alignment_for_contiguous_layout_impl is None:
+        return _missing()
+    mk_align_size = _get_mk_alignment_for_contiguous_layout_impl()
+    return [mk_align_size, mk_align_size]
+
+
 def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
     """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
     _lazy_init()
@@ -338,4 +352,5 @@ __all__ = [
     "get_num_sms",
     "should_use_deepgemm_for_fp8_linear",
     "get_col_major_tma_aligned_tensor",
+    "get_mk_alignment_for_contiguous_layout",
 ]
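With the wrapper exported above, downstream code can query the DeepGEMM block shape without importing deep_gemm directly. A minimal usage sketch (variable names here are illustrative; when DeepGEMM is unavailable the call falls through to the _missing() path shown in the diff):

from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout

if has_deep_gemm():
    # Returns the [M, K] alignment pair; the result is cached, so repeated
    # lookups do not re-enter DeepGEMM.
    block_m, block_k = get_mk_alignment_for_contiguous_layout()
    print(f"DeepGEMM contiguous-layout alignment: block_m={block_m}, block_k={block_k}")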