Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -137,6 +137,11 @@ def rocm_unquantized_gemm_impl(
|
||||
|
||||
import math
|
||||
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
use_skinny_reduce_counting = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx950()
|
||||
@@ -155,11 +160,6 @@ def rocm_unquantized_gemm_impl(
|
||||
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
|
||||
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||
|
||||
if use_aiter_triton_gemm(n, m, k, x.dtype):
|
||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
||||
|
||||
return gemm_a16w16(x, weight, bias)
|
||||
|
||||
use_skinny = (
|
||||
envs.VLLM_ROCM_USE_SKINNY_GEMM
|
||||
and on_gfx9()
|
||||
|
||||
Reference in New Issue
Block a user