[ROCm][Bugfix] Disable hip sampler to fix deepseek's accuracy issue on ROCm (#32413)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone
2026-01-16 00:35:47 +08:00
committed by GitHub
parent 130d6c9514
commit 77c16df31d

View File

@@ -174,6 +174,8 @@ class TopKTopPSampler(nn.Module):
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# FIXME: Fix aiter_sampler's accuracy issue and remove this flag
DISABLE_AITER_SAMPLER = True
"""Optimized ROCm/aiter path (same structure as forward_cuda)."""
if (k is None and p is None) or generators:
if generators:
@@ -186,6 +188,8 @@ class TopKTopPSampler(nn.Module):
"processed_logits",
"processed_logprobs",
), "aiter sampler does not support returning logits/logprobs."
if DISABLE_AITER_SAMPLER:
return self.forward_native(logits, generators, k, p)
return self.aiter_sample(logits, k, p, generators), None
def aiter_sample(