[Quant] Make static quant support all group shapes (#30833)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-01-09 15:49:27 -05:00
committed by GitHub
parent f9e2a75a1e
commit 0a0aa07747
7 changed files with 338 additions and 46 deletions

View File

@@ -2,6 +2,7 @@
#include <optional>
#include <torch/library.h>
#include <tuple>
#include "core/scalar_type.hpp"
@@ -346,8 +347,9 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
void static_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);

View File

@@ -4,28 +4,77 @@
#include "quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include <tuple>
namespace vllm {
template <typename scalar_t, typename fp8_type>
__global__ void scaled_fp8_quant_kernel_strided(
// STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel)
// STRIDE_J_ZERO: true if scale_stride_j == 0 (per-tensor or per-token)
template <typename scalar_t, typename fp8_type, bool STRIDE_I_ZERO,
bool STRIDE_J_ZERO>
__global__ void scaled_fp8_quant_kernel_strided_group_shape(
fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
int64_t out_row_stride) {
const int64_t token_idx = blockIdx.x; // one token per block
int64_t out_row_stride, int group_m, int group_n, int64_t scale_stride_i,
int64_t scale_stride_j) {
const int64_t token_idx = blockIdx.x;
const int tid = threadIdx.x;
const scalar_t* token_in = input + token_idx * in_row_stride;
fp8_type* token_out = out + token_idx * out_row_stride;
const float inv_scale = 1.0f / (*scale);
// Precompute row-level base offset for scale access (compile-time eliminated
// when STRIDE_I_ZERO)
const int64_t scale_row_base =
STRIDE_I_ZERO ? 0
: static_cast<int>(token_idx) / group_m * scale_stride_i;
vectorize_with_alignment<16>(
token_in, token_out, hidden_size, tid, blockDim.x,
[=] __device__(fp8_type & dst, const scalar_t& src) {
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
inv_scale);
});
auto get_inv_scale = [&](int gj) {
return 1.0f / scale[scale_row_base + gj * scale_stride_j];
};
int cached_gj = -1;
float cached_inv_scale = 0.0f;
auto get_inv_scale_cached = [&](int gj) {
if (gj != cached_gj) {
cached_inv_scale = 1.0f / scale[scale_row_base + gj * scale_stride_j];
cached_gj = gj;
}
return cached_inv_scale;
};
constexpr int VEC_SIZE = 16; // FP8 so vectorize to 128 bits
auto scaled_fp8_conversion_vectorized = [&](const scalar_t* in, fp8_type* out,
int size, float inv_scale) {
vectorize_with_alignment<VEC_SIZE>(
in, out, size, tid, blockDim.x,
[=] __device__(fp8_type & dst, const scalar_t& src) {
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
inv_scale);
});
};
if (STRIDE_J_ZERO && hidden_size % VEC_SIZE == 0) {
// Per-tensor or per-token: single scale per row, vectorize full row
scaled_fp8_conversion_vectorized(token_in, token_out, hidden_size,
get_inv_scale(0));
} else if (group_n % VEC_SIZE == 0) {
// Multiple column groups with vectorization
const int num_groups_n = hidden_size / group_n;
for (int gj = 0; gj < num_groups_n; gj++) {
scaled_fp8_conversion_vectorized(token_in + gj * group_n,
token_out + gj * group_n, group_n,
get_inv_scale(gj));
}
} else {
// Scalar path for small column groups (group_n < VEC_SIZE)
for (int n = tid; n < hidden_size; n += blockDim.x) {
const int gj = n / group_n;
token_out[n] = scaled_fp8_conversion<true, fp8_type>(
static_cast<float>(token_in[n]), get_inv_scale_cached(gj));
}
}
}
template <typename scalar_t, typename fp8_type>
@@ -133,17 +182,116 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
} // namespace vllm
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale, // various shapes
std::optional<std::tuple<int64_t, int64_t>>
opt_group_shape) // optional explicit (group_m, group_n)
{
TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous");
TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous");
const int hidden_size = input.size(-1);
const int num_tokens = input.numel() / hidden_size;
const int hidden_size = input.size(-1); // N (columns)
const int num_tokens = input.numel() / hidden_size; // M (rows)
// Determine group_m, group_n, and scale strides from scale shape
// Scale indexing: scale[gi * scale_stride_j + gj * scale_stride_i]
// where gi = m / group_m, gj = n / group_n
int group_m, group_n;
int64_t scale_stride_i, scale_stride_j;
if (scale.dim() == 0 || scale.numel() == 1) {
// Per-tensor: one scale for the entire tensor
group_m = num_tokens;
group_n = hidden_size;
scale_stride_i = 0;
scale_stride_j = 0;
} else if (scale.dim() == 1) {
// 1D scale: require explicit group_shape to disambiguate per-channel vs
// per-token (avoids edge case where num_tokens == hidden_size)
TORCH_CHECK(opt_group_shape.has_value(),
"1D scale requires explicit group_shape to disambiguate "
"per-channel vs per-token quantization. "
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
"-1) for per-token.");
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
// Validate the explicit group shape matches the 1D scale
const int64_t scale_len = scale.numel();
const int64_t expected_scale_m = num_tokens / group_m;
const int64_t expected_scale_n = hidden_size / group_n;
const int64_t expected_scale_numel = expected_scale_m * expected_scale_n;
TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
scale_len, ") does not match expected size (",
expected_scale_numel, ") for group_shape (", opt_group_m, ", ",
opt_group_n, ") with input shape (", num_tokens, ", ",
hidden_size, ")");
// For 1D scale, determine strides based on which dim is trivial
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
// where gi = m / group_m (row group), gj = n / group_n (col group)
if (expected_scale_m == 1) {
// Per-channel style: one scale in M dim, scale varies along N
// gi = 0 always, gj varies, so stride_1 traverses the scale
scale_stride_i = 0;
scale_stride_j = scale.stride(0);
} else if (expected_scale_n == 1) {
// Per-token style: one scale in N dim, scale varies along M
// gj = 0 always, gi varies, so stride_0 traverses the scale
scale_stride_i = scale.stride(0);
scale_stride_j = 0;
} else {
TORCH_CHECK(
false,
"1D scale can only be used when one of the scale dimensions is 1. "
"For 2D group scaling, use a 2D scale tensor.");
}
} else if (scale.dim() == 2) {
// 2D scale: infer group sizes from scale dimensions (or use explicit if
// provided)
const int64_t scale_size_0 = scale.size(0);
const int64_t scale_size_1 = scale.size(1);
TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
") must be divisible by scale.size(0) (", scale_size_0, ")");
TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size,
") must be divisible by scale.size(1) (", scale_size_1, ")");
// Infer from 2D scale shape
int inferred_group_m = num_tokens / scale_size_0;
int inferred_group_n = hidden_size / scale_size_1;
// Use explicit if provided, otherwise use inferred
if (opt_group_shape.has_value()) {
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
// Validate explicit matches inferred
TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n,
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
") does not match inferred group shape (", inferred_group_m,
", ", inferred_group_n, ") from 2D scale tensor shape (",
scale_size_0, ", ", scale_size_1, ")");
} else {
group_m = inferred_group_m;
group_n = inferred_group_n;
}
scale_stride_i = scale.stride(0);
scale_stride_j = scale.stride(1);
} else {
TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
scale.dim(), "D");
}
const int block_size = 256;
dim3 grid(num_tokens);
dim3 block(block_size);
@@ -153,15 +301,23 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Dispatch to template-specialized kernel based on stride pattern
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride);
VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
vllm::scaled_fp8_quant_kernel_strided_group_shape<
scalar_t, fp8_t, S0_ZERO, S1_ZERO>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride, group_m, group_n, scale_stride_i,
scale_stride_j);
});
});
});
});
}

View File

@@ -599,9 +599,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor.
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
// scaling. Optional group_m/group_n specify the group shape explicitly;
// required for 1D scales to disambiguate per-channel vs per-token.
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
"(int, int)? group_shape=None) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.

View File

@@ -11,6 +11,10 @@ from tests.kernels.quant_utils import (
ref_dynamic_per_token_quant,
)
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.bfloat16, torch.float]
@@ -21,10 +25,18 @@ SEEDS = [0]
def opcheck_fp8_quant(
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
output,
input,
scale=None,
scale_ub=None,
use_per_token_if_dynamic=False,
group_shape=None,
):
if scale is not None:
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
opcheck(
torch.ops._C.static_scaled_fp8_quant,
(output, input, scale, group_shape),
)
elif use_per_token_if_dynamic:
scale = torch.empty(
(input.shape[0], 1), device=input.device, dtype=torch.float32
@@ -118,3 +130,92 @@ def test_fp8_quant_large(seed: int) -> None:
ops_out = ops_out.to(dtype=dtype)
torch.testing.assert_close(ref_out, ops_out)
# Test static FP8 quantization with 2D group scales
GROUP_SHAPES_2D = [
(-1, -1), # Per-tensor
(-1, 1), # Per-channel
(1, -1), # Per-token
(-1, 128), # Per-head quantization
(1, 128), # DeepSeek-style per-token-per-group (group_m=1, group_n=128)
(128, 128), # DeepSeek-style block quantization
(1, 64), # Smaller group size
(1, 16), # Small group (scalar path in kernel)
(4, 256), # Non-trivial both dimensions
]
# Use sizes divisible by all group shapes
NUM_TOKENS_GROUP = [128, 512]
HIDDEN_SIZES_GROUP = [256, 1024, 2048]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("group_shape", GROUP_SHAPES_2D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_static_fp8_quant_group_2d(
num_tokens: int,
hidden_size: int,
group_shape: tuple[int, int],
dtype: torch.dtype,
seed: int,
) -> None:
"""Test static FP8 quantization with 2D group scales using scaled_quantize."""
# Normalize group_shape (-1 means full extent)
norm_group_m = num_tokens if group_shape[0] == -1 else group_shape[0]
norm_group_n = hidden_size if group_shape[1] == -1 else group_shape[1]
# Skip if sizes are not divisible by group shape
if num_tokens % norm_group_m != 0 or hidden_size % norm_group_n != 0:
pytest.skip(
f"Skipping: ({num_tokens}, {hidden_size}) not divisible by "
f"group_shape ({group_shape[0]}, {group_shape[1]})"
)
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale = scaled_quantize(
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
)
ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)
torch.testing.assert_close(scale, ops_scale)
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
opcheck_fp8_quant(ops_out, x, scale=scale)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("group_shape", [(1, -1), (-1, 1)]) # per-token, per-channel
@torch.inference_mode()
def test_static_fp8_quant_1d_scale(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
seed: int,
group_shape: tuple[int, int],
) -> None:
"""Test static FP8 quantization with 1D scale (per-token or per-channel)."""
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale_2d = scaled_quantize(
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
)
# Flatten scale to 1D for testing 1D scale path
scale_1d = scale_2d.flatten()
ops_out, ops_scale = ops.scaled_fp8_quant(
x, scale=scale_1d, group_shape=group_shape
)
torch.testing.assert_close(scale_1d, ops_scale)
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
opcheck_fp8_quant(ops_out, x, scale=scale_1d, group_shape=group_shape)

View File

@@ -1752,6 +1752,7 @@ def scaled_fp8_quant(
scale_ub: torch.Tensor | None = None,
use_per_token_if_dynamic: bool = False,
output: torch.Tensor | None = None,
group_shape: tuple[int, int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -1763,14 +1764,23 @@ def scaled_fp8_quant(
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
scale: Optional scaling factor for the FP8 quantization. Supports:
- 0D or [1]: per-tensor scaling
- 1D: requires explicit group_shape to disambiguate per-channel
vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
- 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
group_shape: Optional tuple (group_m, group_n) specifying the group
shape for static quantization. Use -1 for "full extent" (e.g.,
(-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
Required for 1D scales; optional for 2D scales.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
@@ -1799,8 +1809,7 @@ def scaled_fp8_quant(
scale = torch.empty(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)
return output, scale

View File

@@ -10,6 +10,7 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
group_broadcast,
)
from vllm.platforms import current_platform
@@ -22,7 +23,7 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
@CustomOp.register("quant_fp8")
class QuantFP8(CustomOp):
"""
Quantize input tensor to FP8 (per-tensor, per-token, or per-group).
Quantize input tensor to FP8 (per-tensor, per-token, per-channel, or per-group).
This CustomOp supports both static and dynamic quantization.
"""
@@ -57,14 +58,14 @@ class QuantFP8(CustomOp):
self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant:
assert not static, "Group quantization only supports dynamic mode"
self.group_size = group_shape.col
else:
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
assert not static or group_shape == GroupShape.PER_TENSOR, (
"Only per-tensor scales supported for static quantization."
)
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
if not static:
assert group_shape in (GroupShape.PER_TOKEN, GroupShape.PER_TENSOR), (
"Only per-token or per-tensor scales are supported for dynamic "
"non-group quantization."
)
def forward_cuda(
self,
@@ -72,8 +73,8 @@ class QuantFP8(CustomOp):
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
if self.is_group_quant and not self.static:
assert scale is None, "Dynamic group quantization does not use scale"
from vllm.model_executor.layers.quantization.utils import fp8_utils
return fp8_utils.per_token_group_quant_fp8(
@@ -90,12 +91,14 @@ class QuantFP8(CustomOp):
and self.group_shape == GroupShape.PER_TOKEN
and scale_ub.numel() == 1
)
return ops.scaled_fp8_quant(
x,
scale,
num_token_padding=self.num_token_padding,
scale_ub=scale_ub,
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
group_shape=self.group_shape if self.static else None,
)
def forward_hip(
@@ -131,8 +134,8 @@ class QuantFP8(CustomOp):
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
):
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
if self.is_group_quant and not self.static:
assert scale is None, "Dynamic group quantization does not use scale"
return self._quantize_group_native(x)
assert (scale is not None) == self.static
@@ -155,7 +158,10 @@ class QuantFP8(CustomOp):
# Even for dynamic per-token scales,
# reciprocal performs slightly better than division
out = x.to(torch.float32) * scale.reciprocal()
out = (
x.to(torch.float32)
* group_broadcast(scale.to(torch.float32), x.shape[-2:]).reciprocal()
)
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
# This currently generates an extra Triton kernel in compilation.

View File

@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# with an extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t, shape):
for i, s in enumerate(shape):
if t.shape[i] != s and t.shape[i] != 1:
assert s % t.shape[i] == 0
# If tensor has fewer dimensions than target shape, treat missing
# dimensions as size 1 (standard PyTorch broadcasting behavior)
t_dim_size = t.shape[i] if i < t.ndim else 1
if t_dim_size != s and t_dim_size != 1:
assert s % t_dim_size == 0
t = (
t.unsqueeze(i + 1)
.expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
.expand(*t.shape[: i + 1], s // t_dim_size, *t.shape[i + 1 :])
.flatten(i, i + 1)
)
return t
@@ -180,7 +183,16 @@ def scaled_quantize(
x: torch.Tensor,
group_shape: GroupShape,
quant_dtype: torch.dtype,
compute_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor to quantize
group_shape: Shape of quantization groups
quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
compute_dtype: Optional dtype for intermediate computations.
If None, uses input dtype. Use torch.float32 for higher precision.
"""
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, (
"currently `scaled_quantize` only supports floating point dtypes "
@@ -189,11 +201,14 @@ def scaled_quantize(
finfo = torch.finfo(quant_dtype)
# Convert to compute dtype if specified
x_compute = x if compute_dtype is None else x.to(compute_dtype)
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
assert x.ndim == 2
assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
x_blkd = x_compute.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)