Use Cache Hinting for fused_moe kernel (#15511)
This commit is contained in:
@@ -189,7 +189,11 @@ def fused_moe_kernel_gptq_awq(
|
|||||||
mask=token_mask[:, None] &
|
mask=token_mask[:, None] &
|
||||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||||
other=0.0)
|
other=0.0)
|
||||||
b = tl.load(b_ptrs)
|
b = tl.load(
|
||||||
|
b_ptrs,
|
||||||
|
cache_modifier=".cg",
|
||||||
|
eviction_policy="evict_last",
|
||||||
|
)
|
||||||
if use_int4_w4a16:
|
if use_int4_w4a16:
|
||||||
b = (b >> b_shifter) & 0xF
|
b = (b >> b_shifter) & 0xF
|
||||||
|
|
||||||
@@ -391,9 +395,13 @@ def fused_moe_kernel(
|
|||||||
mask=token_mask[:, None] &
|
mask=token_mask[:, None] &
|
||||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||||
other=0.0)
|
other=0.0)
|
||||||
b = tl.load(b_ptrs,
|
b = tl.load(
|
||||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
b_ptrs,
|
||||||
other=0.0)
|
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||||
|
other=0.0,
|
||||||
|
cache_modifier=".cg",
|
||||||
|
eviction_policy="evict_last",
|
||||||
|
)
|
||||||
# We accumulate along the K dimension.
|
# We accumulate along the K dimension.
|
||||||
if use_int8_w8a16:
|
if use_int8_w8a16:
|
||||||
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
|
||||||
|
|||||||
Reference in New Issue
Block a user