Use aiter triton fused_add_rmsnorm_pad for gpt-oss (#30976)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-01-28 14:47:47 -06:00
committed by GitHub
parent 3e440786af
commit 59bcc5b6f2
9 changed files with 327 additions and 11 deletions

View File

@@ -137,6 +137,11 @@ def rocm_unquantized_gemm_impl(
import math
if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias)
use_skinny_reduce_counting = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx950()
@@ -155,11 +160,6 @@ def rocm_unquantized_gemm_impl(
out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
return out.reshape(*x.shape[:-1], weight.shape[0])
if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
return gemm_a16w16(x, weight, bias)
use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx9()

View File

@@ -187,7 +187,7 @@ class MLPBlock(torch.nn.Module):
)
else:
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)
x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)