[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)
Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
@@ -563,6 +563,10 @@ def cdiv(a: int, b: int) -> int:
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
def round_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``.

    Assumes ``y`` is a positive integer; exact multiples of ``y`` are
    returned unchanged.
    """
    # Adding (y - 1) before the floor division bumps any value that is
    # not already a multiple of y into the next bucket.
    buckets = (x + y - 1) // y
    return buckets * y
|
||||
|
||||
|
||||
def _generate_random_fp8(
|
||||
tensor: torch.Tensor,
|
||||
low: float,
|
||||
@@ -794,6 +798,12 @@ def get_dtype_size(dtype: torch.dtype) -> int:
|
||||
return torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
|
||||
def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
    """Round ``extent`` (a count of ``dtype`` elements) up so that the
    corresponding span of memory is a whole multiple of 256 bytes.

    NOTE(review): this assumes the element size of ``dtype`` evenly
    divides 256 (true for the usual power-of-two-sized dtypes) — confirm
    before using with exotic element sizes.
    """
    bytes_per_element = get_dtype_size(dtype)
    elements_per_256b = 256 // bytes_per_element
    return round_up(extent, elements_per_256b)
|
||||
|
||||
|
||||
# `collections` helpers
|
||||
def is_list_of(
|
||||
value: object,
|
||||
|
||||
Reference in New Issue
Block a user