Simplify expression
This commit is contained in:
@@ -49,7 +49,7 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k
|
||||
smem_b_per_stage = block_n * block_k
|
||||
smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0
|
||||
smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0
|
||||
smem_barrier = num_stages * 8 * 2 if not is_wgrad else (num_stages + 1) * 8 * 2
|
||||
smem_barrier = (num_stages + int(is_wgrad)) * 8 * 2
|
||||
|
||||
smem_size = 0
|
||||
smem_size += smem_d
|
||||
|
||||
Reference in New Issue
Block a user