[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Author:    Lucas Wilkinson
Date:      2025-02-04 21:22:24 -05:00
Committed: GitHub
Parent:    233df6f5c4
Commit:    75e94309e8

10 changed files with 429 additions and 34 deletions


@@ -563,6 +563,10 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+def round_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
+
 def _generate_random_fp8(
     tensor: torch.Tensor,
     low: float,
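
The new round_up helper is a standard ceiling-to-multiple: it rounds x up to the nearest multiple of y. A minimal usage sketch (the values are illustrative and not taken from the diff):

    def round_up(x: int, y: int) -> int:
        return ((x + y - 1) // y) * y

    assert round_up(576, 128) == 640   # pad 576 elements up to a multiple of 128
    assert round_up(512, 128) == 512   # already a multiple, unchanged
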
@@ -794,6 +798,12 @@ def get_dtype_size(dtype: torch.dtype) -> int:
     return torch.tensor([], dtype=dtype).element_size()
 
 
+def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
+    dtype_size = get_dtype_size(dtype)
+    eles_per_256bytes = 256 // dtype_size
+    return round_up(extent, eles_per_256bytes)
+
+
 # `collections` helpers
 def is_list_of(
     value: object,
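
Taken together, these helpers pad an element extent so that its byte size becomes a multiple of 256, which is the alignment this PR targets for KV caches on CUDA devices. Below is a minimal sketch of how a caller could use align_to_256bytes when allocating a cache; the shapes, dtype, device selection, and the storage/kv_cache names are assumptions for illustration, not code from this commit:

    import torch


    def get_dtype_size(dtype: torch.dtype) -> int:
        return torch.tensor([], dtype=dtype).element_size()


    def round_up(x: int, y: int) -> int:
        return ((x + y - 1) // y) * y


    def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
        # Pad `extent` (in elements) so extent * element_size is a multiple of 256 bytes.
        dtype_size = get_dtype_size(dtype)
        eles_per_256bytes = 256 // dtype_size
        return round_up(extent, eles_per_256bytes)


    # Hypothetical shapes for illustration; 576 is an MLA-style latent head size.
    num_blocks, block_size, head_size = 1024, 16, 576
    dtype = torch.bfloat16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    padded = align_to_256bytes(head_size, dtype)   # 576 -> 640 elements for a 2-byte dtype
    storage = torch.empty(num_blocks, block_size, padded, dtype=dtype, device=device)
    # Narrow back to the logical head size; given an aligned base allocation, each
    # innermost row of the view then starts on a 256-byte boundary.
    kv_cache = storage[..., :head_size]

Reading through the narrowed view keeps the logical head size unchanged for callers while the rows of the underlying buffer stay padded to the 256-byte boundary.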